OpenDAS / FastMoE · Commits

Commit 67c667f2, authored Feb 04, 2021 by Rick Ho
Commit message: ddp module for sophisticated hybrid parallelism
Parent: ea66e5e5
Showing 2 changed files, with 82 additions and 3 deletions (+82 / -3):

  fmoe/distributed.py   +64 / -0
  fmoe/megatron.py      +18 / -3
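For orientation before the diff: the new DistributedGroupedDataParallel module below expects pre-built torch.distributed process groups for the model-parallel and data-parallel dimensions. A minimal, hedged sketch of how such groups could be constructed is shown here; the helper name build_groups and the 2D rank layout are illustrative assumptions and not part of this commit (in the second file, Megatron's mpu utilities supply the real groups).

    # Sketch only: build one model-parallel (mp) and one data-parallel (dp)
    # group per rank, assuming world_size is divisible by model_parallel_size.
    import torch.distributed as dist

    def build_groups(model_parallel_size=2):
        dist.init_process_group(backend='nccl')
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        mp_group = None
        dp_group = None
        # Ranks holding different slices of the same replica form an mp group.
        for i in range(world_size // model_parallel_size):
            ranks = list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
            group = dist.new_group(ranks)
            if rank in ranks:
                mp_group = group
        # Ranks holding the same slice across replicas form a dp group.
        for i in range(model_parallel_size):
            ranks = list(range(i, world_size, model_parallel_size))
            group = dist.new_group(ranks)
            if rank in ranks:
                dp_group = group
        return mp_group, dp_group

Note that dist.new_group must be called collectively on every rank in the same order, which is why the sketch creates all groups on all ranks and keeps only the ones the current rank belongs to.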
fmoe/distributed.py (new file, mode 100644) @ 67c667f2
import torch
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


class DistributedGroupedDataParallel(nn.Module):
    def __init__(self, module, mp_group=None, dp_group=None,
                 world_group=None, auto_allreduce=False):
        assert not auto_allreduce, 'Automatic all-reduce is not implemented yet'
        super(DistributedGroupedDataParallel, self).__init__()
        self.module = module
        self.comms = dict()
        if mp_group is not None:
            self.comms['mp'] = mp_group
        if dp_group is not None:
            self.comms['dp'] = dp_group
        else:
            self.comms['dp'] = torch.distributed.distributed_c10d._default_pg
        if world_group is None:
            self.comms['world'] = torch.distributed.distributed_c10d._default_pg
        else:
            self.comms['world'] = world_group

        def allreduce_params(no_scale=False, reduce_after=False,
                             fp32_allreduce=False):
            groups = dict()
            for p in self.module.parameters():
                if not p.requires_grad or p.grad is None:
                    continue
                if hasattr(p, 'parallel_method'):
                    pm = p.parallel_method
                else:
                    pm = 'dp'
                group_key = (pm, p.dtype)
                if group_key not in groups:
                    groups[group_key] = [p]
                else:
                    groups[group_key].append(p)
            for pm, dtype in groups:
                if pm not in self.comms:
                    continue
                group = groups[pm, dtype]
                comm = self.comms[pm]
                grads = [p.grad.data for p in group]
                coalesced = _flatten_dense_tensors(grads)
                if fp32_allreduce and dtype != torch.float32:
                    coalesced = coalesced.float()
                if not no_scale and not reduce_after:
                    coalesced /= comm.size()
                torch.distributed.all_reduce(coalesced, group=comm)
                torch.cuda.synchronize()
                if not no_scale and reduce_after:
                    coalesced /= comm.size()
                synced = _unflatten_dense_tensors(coalesced, grads)
                for g, s in zip(grads, synced):
                    g.copy_(s)
        self.allreduce_params = allreduce_params

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
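The wrapper buckets gradients by the parameter attribute parallel_method (defaulting to 'dp') and dtype, flattens each bucket, and all-reduces it over the matching process group; reduction is manual via allreduce_params(), since auto_allreduce is not yet implemented. A minimal usage sketch under assumptions follows: the toy two-layer model, the mp_group/dp_group variables, and the 'mp' tag on the first layer are illustrative, not part of this commit.

    import torch
    import torch.nn as nn
    from fmoe.distributed import DistributedGroupedDataParallel

    # Assumed: torch.distributed is initialized, and mp_group/dp_group were
    # created beforehand (e.g. with the build_groups sketch above).
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4)).cuda()
    for p in model[0].parameters():
        p.parallel_method = 'mp'   # these gradients are reduced over comms['mp']
    # Parameters without the attribute fall back to the 'dp' group.

    model = DistributedGroupedDataParallel(model, mp_group=mp_group, dp_group=dp_group)

    x = torch.randn(8, 16).cuda()
    model(x).sum().backward()
    model.allreduce_params()       # explicit call after backward; no automatic hook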
fmoe/megatron.py @ 67c667f2
 from .layers import FMoETransformerMLP
+from .distributed import DistributedGroupedDataParallel


 def create_moe_mlp(args, model_parallel_rank, group):
     assert (
         args.seq_length * args.batch_size % args.model_parallel_size == 0
     ), "Batch size x sequence length should be multiple of mp size"
-    if args.model_parallel_size == 1:
+    if not args.distributed_experts:
         world_size = 1
     else:
         world_size = args.world_size
@@ -21,7 +21,7 @@ def create_moe_mlp(args, model_parallel_rank, group):
     return fmoe


-def fmoefy(model, num_experts=None):
+def fmoefy(model, num_experts=None, distributed_experts=True):
     from megatron import get_args
     from megatron import mpu
     args = get_args()
@@ -30,8 +30,23 @@ def fmoefy(model, num_experts=None):
     assert (
         'num_experts' in args
     ), 'num_experts should be specified in arguments or fmoefy function'
+    # Set distributed_experts to None to use default setting in args
+    if distributed_experts is not None:
+        args.distributed_experts = distributed_experts
     for l in model.language_model.transformer.layers:
         l.mlp = create_moe_mlp(args,
                                mpu.get_model_parallel_rank(),
                                mpu.get_model_parallel_group())
     return model
+
+
+class DistributedDataParallel(DistributedGroupedDataParallel):
+    def __init__(self, module):
+        from megatron import mpu
+        super(DistributedDataParallel, self).__init__(
+            module,
+            mp_group=mpu.get_model_parallel_group(),
+            dp_group=mpu.get_data_parallel_group()
+        )
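For the Megatron-LM side, a rough usage sketch under assumptions: it presumes a standard Megatron training script where initialization has already run (so get_args and mpu are usable), model is the GPT model being trained, and the num_experts value is arbitrary.

    from fmoe.megatron import fmoefy, DistributedDataParallel

    # Assumed: Megatron-LM has been initialized and `model` is its GPT model.
    model = fmoefy(model, num_experts=4)    # swap each transformer MLP for an FMoE MLP
    model = DistributedDataParallel(model)  # mp/dp groups taken from megatron.mpu

    # ... after each backward pass, reduce gradients per parallel group:
    model.allreduce_params()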