OpenDAS / FastMoE

Commit d83234b0, authored Feb 04, 2021 by Rick Ho
Parent: 67c667f2

use parallel label in gate
Showing 4 changed files with 13 additions and 9 deletions (+13 -9):

    fmoe/distributed.py   +8 -8
    fmoe/layers.py        +2 -0
    fmoe/megatron.py      +2 -0
    setup.py              +1 -1
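Taken together, the changes below rename the per-parameter label that DistributedGroupedDataParallel reads from 'parallel_method' to 'dp_comm' and apply that label to the gate: fmoe/layers.py tags every gate parameter with 'world', fmoe/megatron.py additionally marks those parameters as shared for the Megatron integration, and setup.py bumps the package version. A minimal sketch of the label's contract as it appears in the diffs (the parameter here is an arbitrary stand-in, not FastMoE code):

import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(4))        # any parameter handled by the wrapper
setattr(p, 'dp_comm', 'world')          # label applied to the gate in fmoe/layers.py
dp_comm = getattr(p, 'dp_comm', 'dp')   # the wrapper falls back to 'dp' when unlabeled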
fmoe/distributed.py

...
@@ -29,20 +29,20 @@ class DistributedGroupedDataParallel(nn.Module):
         for p in self.module.parameters():
             if not p.requires_grad or p.grad is None:
                 continue
-            if hasattr(p, 'parallel_method'):
-                pm = p.parallel_method
+            if hasattr(p, 'dp_comm'):
+                dp_comm = p.dp_comm
             else:
-                pm = 'dp'
-            group_key = (pm, p.dtype)
+                dp_comm = 'dp'
+            group_key = (dp_comm, p.dtype)
             if group_key not in groups:
                 groups[group_key] = [p]
             else:
                 groups[group_key].append(p)
-        for pm, dtype in groups:
-            if pm not in self.comms:
+        for dp_comm, dtype in groups:
+            if dp_comm not in self.comms:
                 continue
-            group = groups[pm, dtype]
-            comm = self.comms[pm]
+            group = groups[dp_comm, dtype]
+            comm = self.comms[dp_comm]
             grads = [p.grad.data for p in group]
             coalesced = _flatten_dense_tensors(grads)
             if fp32_allreduce and dtype != torch.float32:
...
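For illustration, the grouping logic above buckets gradients by (dp_comm, dtype) and flattens each bucket before handing it to the matching communicator in self.comms. A stand-alone sketch of that bucketing step, using a toy model instead of FastMoE and printing bucket sizes where the real code would all-reduce:

import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for p in model[1].parameters():
    setattr(p, 'dp_comm', 'world')            # label one sub-module, as the commit does for the gate

model(torch.randn(4, 8)).sum().backward()     # populate .grad on every parameter

groups = {}
for p in model.parameters():
    if not p.requires_grad or p.grad is None:
        continue
    dp_comm = getattr(p, 'dp_comm', 'dp')     # same fallback as the else-branch above
    groups.setdefault((dp_comm, p.dtype), []).append(p)

for (dp_comm, dtype), group in groups.items():
    coalesced = _flatten_dense_tensors([p.grad.data for p in group])
    # DistributedGroupedDataParallel would now all-reduce `coalesced`
    # over self.comms[dp_comm]; here we only report the bucket size
    print(dp_comm, dtype, coalesced.numel())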
fmoe/layers.py

...
@@ -92,6 +92,8 @@ class FMoETransformerMLP(nn.Module):
         self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
         self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
+        for p in self.gate.parameters():
+            setattr(p, 'dp_comm', 'world')
         self.layer_norm = nn.LayerNorm(d_model)
         self.bias = torch.nn.parameter.Parameter(
...
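The two added lines label every gate parameter with 'world', presumably because the gate is replicated on every worker (unlike the experts), so its gradients should be averaged over the full world group rather than only the local data-parallel group. A stand-in module (not FastMoE) showing the labeling and how to inspect which parameters carry it:

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate = nn.Linear(16, 4)          # plays the role of the gate
        self.htoh4 = nn.Linear(16, 64)        # plays the role of an expert layer
        for p in self.gate.parameters():
            setattr(p, 'dp_comm', 'world')    # same labeling as fmoe/layers.py

m = Toy()
for name, p in m.named_parameters():
    print(name, getattr(p, 'dp_comm', 'dp'))  # gate.* -> 'world', htoh4.* -> 'dp'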
fmoe/megatron.py

...
@@ -18,6 +18,8 @@ def create_moe_mlp(args, model_parallel_rank, group):
         model_parallel_rank=model_parallel_rank,
         mp_group=group,
     )
+    for p in fmoe.gate.parameters():
+        setattr(p, 'shared', True)
     return fmoe
...
setup.py

...
@@ -29,7 +29,7 @@ if __name__ == '__main__':
                 }
             )
         ],
-        version='0.0.1',
+        version='0.0.2',
         cmdclass={
             'build_ext': BuildExtension
         })