Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
9a01f8fa
Unverified
Commit
9a01f8fa
authored
Jun 12, 2022
by
Rick Ho
Committed by
GitHub
Jun 12, 2022
Browse files
Merge pull request #120 from laekov/ddp-bcast-fix
Fix Broadcast rank bug in DGDP
parents
670e1407
dd68fd78
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
3 deletions
+16
-3
fmoe/distributed.py
fmoe/distributed.py
+3
-2
fmoe/utils.py
fmoe/utils.py
+11
-0
tests/test_numerical.py
tests/test_numerical.py
+2
-1
No files found.
fmoe/distributed.py
View file @
9a01f8fa
...
...
@@ -4,7 +4,7 @@ Supportive modules to conduct distributed training
import
torch
import
torch.nn
as
nn
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
from
.utils
import
get_torch_default_comm
from
.utils
import
get_torch_default_comm
,
get_rank_0_in_comm
class
DistributedGroupedDataParallel
(
nn
.
Module
):
...
...
@@ -97,7 +97,8 @@ class DistributedGroupedDataParallel(nn.Module):
comm
=
self
.
comms
[
dp_comm
]
datas
=
[
p
.
data
for
p
in
group
]
coalesced
=
_flatten_dense_tensors
(
datas
)
torch
.
distributed
.
broadcast
(
coalesced
,
0
,
group
=
comm
)
torch
.
distributed
.
broadcast
(
coalesced
,
get_rank_0_in_comm
(
comm
),
group
=
comm
)
torch
.
cuda
.
synchronize
()
synced
=
_unflatten_dense_tensors
(
coalesced
,
datas
)
for
d
,
s
in
zip
(
datas
,
synced
):
...
...
fmoe/utils.py
View file @
9a01f8fa
r
"""
Utils to play with PyTorch.
"""
import
torch
import
torch.distributed
as
dist
...
...
@@ -28,3 +29,13 @@ def get_torch_default_comm():
except
Exception
as
_
:
pass
raise
RuntimeError
(
"Unsupported PyTorch version"
)
def get_rank_0_in_comm(comm):
    r"""Return the global rank of the process whose local rank is 0 in `comm`.

    `torch.distributed.broadcast` expects `src` to be a *global* rank, while a
    sub-group's root is naturally identified by its *local* rank 0.  This
    helper resolves that mismatch by all-gathering every member's global rank
    over `comm` and reading out the entry contributed by local rank 0.

    Must be called collectively by every member of `comm`.  Uses a CUDA
    tensor, so the group's backend has to support CUDA collectives
    (e.g. NCCL).
    """
    n_members = dist.get_world_size(comm)
    # Each member contributes its own global rank.
    my_rank = torch.tensor([dist.get_rank()], dtype=torch.int64, device='cuda')
    gathered = [torch.empty_like(my_rank) for _ in range(n_members)]
    dist.all_gather(gathered, my_rank, group=comm)
    # all_gather orders results by local rank, so index 0 is the group root.
    return gathered[0].item()
tests/test_numerical.py
View file @
9a01f8fa
...
...
@@ -71,7 +71,7 @@ class MyMoE(FMoE):
d_model
=
d_model
,
gate
=
NaiveGate
,
world_size
=
world_size
,
mp
_group
=
mp_group
,
slice
_group
=
mp_group
,
top_k
=
top_k
,
)
self
.
experts
=
_Expert
(
num_expert
,
d_model
,
d_hidden
,
activation
)
...
...
@@ -344,6 +344,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
model
=
MyModule
().
cuda
()
model_ddp
=
LocalDDP
(
deepcopy
(
model
),
mp_group
=
mp_group
,
dp_group
=
dp_group
,
world_group
=
world_group
)
model
=
deepcopy
(
model_ddp
.
module
)
model
.
set_comm
()
model_ddp
.
module
.
set_comm
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment