OpenDAS / FastMoE · Commits

Commit 56c1bd63, authored Feb 04, 2021 by Rick Ho
Parent: d83234b0

fix no grad after all-gather bug
Showing 3 changed files, with 32 additions and 20 deletions.
fmoe/functions.py   +16  -0
fmoe/layers.py      +12  -13
fmoe/megatron.py    +4   -7
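For context on the fix: torch.distributed.all_gather is a plain c10d collective that writes into pre-allocated buffers outside of autograd, so the gathered result is detached from the graph and no gradient flows back through the expert outputs. A minimal single-process sketch of the symptom (world_size=1, gloo backend; the init address/port are arbitrary):

    import torch
    import torch.distributed as dist

    # Single-process group, purely for illustration.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)

    x = torch.randn(4, 8, requires_grad=True)
    buf = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(buf, x)            # the collective runs outside autograd
    gathered = torch.cat(buf, dim=0)

    print(gathered.requires_grad)      # False -- the graph is cut here

The commit wraps this collective in a custom autograd Function so that the gather stays on the graph and the backward pass hands each rank its local slice of the gradient.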
fmoe/functions.py  (+16 -0)

@@ -143,3 +143,19 @@ class MOEGather(Function):
         else:
             global_grad_out_buf = grad_out_buf
         return global_grad_out_buf, None, None, None, None, None
+
+
+class AllGather(Function):
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
+        torch.distributed.all_gather(tensor_list, inp, group=group)
+        torch.cuda.synchronize()
+        output = torch.cat(tensor_list, dim=0)
+        ctx.args = rank, inp.shape[0]
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        rank, dim0 = ctx.args
+        return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None
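A hedged usage sketch of the new autograd-aware wrapper (single process, world_size=1, gloo backend; a CUDA build is still assumed because forward() calls torch.cuda.synchronize(); the init address/port are arbitrary):

    import torch
    import torch.distributed as dist
    from fmoe.functions import AllGather

    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=0, world_size=1)

    x = torch.randn(4, 8, requires_grad=True)
    y = AllGather.apply(x, dist.get_rank(), dist.get_world_size(), dist.group.WORLD)

    print(y.requires_grad)             # True -- the gather is now part of the graph
    y.sum().backward()
    print(x.grad.shape)                # torch.Size([4, 8]); backward returns only the local slice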
fmoe/layers.py  (+12 -13)

@@ -69,8 +69,6 @@ class FMoETransformerMLP(nn.Module):
         d_model=1024,
         d_hidden=4096,
         world_size=1,
-        model_parallel_size=1,
-        model_parallel_rank=1,
         mp_group=None,
         activation=torch.nn.functional.gelu,
         top_k=2,
@@ -81,9 +79,13 @@ class FMoETransformerMLP(nn.Module):
         self.d_model = d_model
         self.d_hidden = d_hidden
         self.world_size = world_size
-        self.model_parallel_size = model_parallel_size
-        self.model_parallel_rank = model_parallel_rank
         self.mp_group = mp_group
+        if mp_group is None:
+            self.mp_size = 1
+            self.mp_rank = 0
+        else:
+            self.mp_size = mp_group.size()
+            self.mp_rank = mp_group.rank()
         self.activation = activation
         self.pre_lnorm = pre_lnorm
         self.top_k = top_k
@@ -104,10 +106,10 @@ class FMoETransformerMLP(nn.Module):
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
-        if self.model_parallel_size > 1:
+        if self.mp_size > 1:
             B: int = inp.shape[0]
-            local_batch_size = B // self.model_parallel_size
-            batch_start = local_batch_size * self.model_parallel_rank
+            local_batch_size = B // self.mp_size
+            batch_start = local_batch_size * self.mp_rank
             batch_end = min(batch_start + local_batch_size, B)
             inp = inp[batch_start:batch_end]
@@ -138,11 +140,8 @@ class FMoETransformerMLP(nn.Module):
         if not self.pre_lnorm:
             output = self.layer_norm(output)
-        if self.model_parallel_size > 1:
-            world_size = self.model_parallel_size
-            tensor_list = [torch.empty_like(output) for _ in range(world_size)]
-            torch.distributed.all_gather(tensor_list, output, group=self.mp_group)
-            output = torch.cat(tensor_list, dim=1)
+        if self.mp_size > 1:
+            output = AllGather.apply(output, self.mp_rank, self.mp_size, self.mp_group)
         return output.reshape(original_shape), self.bias
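A minimal construction sketch under the post-commit signature; only the arguments visible in this diff are shown, and everything else (e.g. the number of experts) is assumed to keep its default:

    import torch
    from fmoe.layers import FMoETransformerMLP

    mlp = FMoETransformerMLP(
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,                 # replaces model_parallel_size / model_parallel_rank
        activation=torch.nn.functional.gelu,
        top_k=2,
    )
    # With mp_group=None the layer falls back to mp_size=1, mp_rank=0, so the
    # single-group path is unchanged; with a real ProcessGroup it reads size()/rank().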
fmoe/megatron.py  (+4 -7)

 from .layers import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel


-def create_moe_mlp(args, model_parallel_rank, group):
+def create_moe_mlp(args, group):
     assert (
         args.seq_length * args.batch_size % args.model_parallel_size == 0
     ), "Batch size x sequence length should be multiple of mp size"
@@ -14,9 +15,7 @@ def create_moe_mlp(args, model_parallel_rank, group):
         d_model=args.hidden_size,
         d_hidden=args.hidden_size * 4,
         world_size=world_size,
-        model_parallel_size=args.model_parallel_size,
-        model_parallel_rank=model_parallel_rank,
-        mp_group=group,
+        mp_group=group)
     for p in fmoe.gate.parameters():
         setattr(p, 'shared', True)
@@ -38,9 +37,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
     args.distributed_experts = distributed_experts
     for l in model.language_model.transformer.layers:
-        l.mlp = create_moe_mlp(args, mpu.get_model_parallel_rank(),
-                mpu.get_model_parallel_group())
+        l.mlp = create_moe_mlp(args, mpu.get_model_parallel_group())
     return model
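The extra model_parallel_rank argument can be dropped because a torch.distributed ProcessGroup, which is what mpu.get_model_parallel_group() returns in Megatron-LM, already knows its own size and rank; FMoETransformerMLP now reads both from mp_group directly. A self-contained illustration of that mechanism (single process, arbitrary init address/port):

    import torch.distributed as dist

    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29502",
                            rank=0, world_size=1)
    group = dist.new_group(ranks=[0])   # stand-in for mpu.get_model_parallel_group()

    print(group.size(), group.rank())   # 1 0 -- exactly what the layer now derives itself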