Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
ee88bb19
Unverified
Commit
ee88bb19
authored
Oct 16, 2020
by
msbaines
Committed by
GitHub
Oct 16, 2020
Browse files
[feat] moe: annotate expert params (#140)
The expert annotation is used by clip_grads and DDP.
parent
d99c445a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
0 deletions
+13
-0
fairscale/nn/moe/moelayer.py
fairscale/nn/moe/moelayer.py
+2
-0
tests/nn/moe/test_moelayer.py
tests/nn/moe/test_moelayer.py
+11
-0
No files found.
fairscale/nn/moe/moelayer.py
View file @
ee88bb19
...
@@ -60,6 +60,8 @@ class MOELayer(Base):
...
@@ -60,6 +60,8 @@ class MOELayer(Base):
self
.
gate
=
gate
self
.
gate
=
gate
self
.
expert
=
expert
self
.
expert
=
expert
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
for
p
in
expert
.
parameters
():
p
.
expert
=
True
# type: ignore
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"gsec,gsm->egcm"
,
dispatch_mask
.
float
(),
input
)
dispatched_input
=
torch
.
einsum
(
"gsec,gsm->egcm"
,
dispatch_mask
.
float
(),
input
)
...
...
tests/nn/moe/test_moelayer.py
View file @
ee88bb19
...
@@ -45,6 +45,17 @@ def test_create(device):
...
@@ -45,6 +45,17 @@ def test_create(device):
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_expert_params
(
device
):
model_dim
=
8
num_experts
=
4
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
)
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
for
p
in
expert
.
parameters
():
assert
p
.
expert
is
True
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
])
def
test_forward
(
device
):
def
test_forward
(
device
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment