OpenDAS / FastMoE

Commit dd68fd78, authored Jun 10, 2022 by Rick Ho
update ddp with get first rank and tests
Parent: 670e1407
Showing 3 changed files with 16 additions and 3 deletions:

    fmoe/distributed.py       +3  -2
    fmoe/utils.py             +11 -0
    tests/test_numerical.py   +2  -1
fmoe/distributed.py

@@ -4,7 +4,7 @@ Supportive modules to conduct distributed training
 import torch
 import torch.nn as nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from .utils import get_torch_default_comm
+from .utils import get_torch_default_comm, get_rank_0_in_comm


 class DistributedGroupedDataParallel(nn.Module):
@@ -97,7 +97,8 @@ class DistributedGroupedDataParallel(nn.Module):
             comm = self.comms[dp_comm]
             datas = [p.data for p in group]
             coalesced = _flatten_dense_tensors(datas)
-            torch.distributed.broadcast(coalesced, 0, group=comm)
+            torch.distributed.broadcast(coalesced,
+                    get_rank_0_in_comm(comm), group=comm)
             torch.cuda.synchronize()
             synced = _unflatten_dense_tensors(coalesced, datas)
             for d, s in zip(datas, synced):
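Background for this hunk (not part of the commit itself): torch.distributed.broadcast takes its src argument as a global rank, so the previous hard-coded 0 is only correct when global rank 0 happens to be a member of the data-parallel communicator. For a sub-group that excludes rank 0, the root must be the global rank of the group's first member, which is what the new get_rank_0_in_comm call resolves. Below is a minimal sketch of the same flatten/broadcast/unflatten pattern; sync_group_params and the group composition are illustrative assumptions, not code from the repository:

    import torch
    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
    from fmoe.utils import get_rank_0_in_comm

    def sync_group_params(params, comm):
        # Flatten all parameters, broadcast them from the group's first
        # member, then copy the synced values back -- the pattern used by
        # DistributedGroupedDataParallel above, with the broadcast root
        # looked up instead of assumed to be global rank 0.
        datas = [p.data for p in params]
        coalesced = _flatten_dense_tensors(datas)
        dist.broadcast(coalesced, get_rank_0_in_comm(comm), group=comm)
        for d, s in zip(datas, _unflatten_dense_tensors(coalesced, datas)):
            d.copy_(s)

    # Assumed layout: global ranks {1, 3} form a data-parallel group, so the
    # correct broadcast root for that group is 1, not 0.
    # dp_group = dist.new_group([1, 3])
    # sync_group_params(model.parameters(), dp_group)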
fmoe/utils.py

 r"""
 Utils to play with PyTorch.
 """
+import torch
 import torch.distributed as dist
@@ -28,3 +29,13 @@ def get_torch_default_comm():
         except Exception as _:
             pass
     raise RuntimeError("Unsupported PyTorch version")
+
+
+def get_rank_0_in_comm(comm):
+    world_size = dist.get_world_size(comm)
+    x = torch.tensor([dist.get_rank()], dtype=torch.int64, device='cuda')
+    ys = [torch.empty_like(x) for _ in range(world_size)]
+    dist.all_gather(ys, x, group=comm)
+    root_rank = ys[0].item()
+    return root_rank
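The helper works by all-gathering each member's global rank: the output list is ordered by rank within the group, so ys[0] is the global rank of the member whose in-group rank is 0, and every member computes the same value. A hedged usage sketch follows; the 8-process layout and the odd-rank group are assumptions for illustration only:

    import torch
    import torch.distributed as dist
    from fmoe.utils import get_rank_0_in_comm

    # Assumed layout: an 8-process job where the odd global ranks form one
    # data-parallel group. Inside the helper, all_gather returns the members'
    # global ranks in group-rank order, i.e. [1, 3, 5, 7], so the root is 1.
    odd_group = dist.new_group([1, 3, 5, 7])
    if dist.get_rank() % 2 == 1:
        root = get_rank_0_in_comm(odd_group)       # -> 1 on every odd rank
        buf = torch.zeros(4, device='cuda')
        dist.broadcast(buf, root, group=odd_group)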
tests/test_numerical.py

@@ -71,7 +71,7 @@ class MyMoE(FMoE):
             d_model=d_model,
             gate=NaiveGate,
             world_size=world_size,
-            mp_group=mp_group,
+            slice_group=mp_group,
             top_k=top_k,
         )
         self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@@ -344,6 +344,7 @@ def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
     model = MyModule().cuda()
     model_ddp = LocalDDP(deepcopy(model),
             mp_group=mp_group, dp_group=dp_group, world_group=world_group)
+    model = deepcopy(model_ddp.module)
     model.set_comm()
     model_ddp.module.set_comm()
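The added deepcopy presumably re-captures the reference model from the already-wrapped copy, so that any parameter synchronization performed by the wrapper (the broadcast shown in fmoe/distributed.py above) is reflected in both models before their results are compared. A hedged sketch of that pattern, assuming LocalDDP is the test's alias for DistributedGroupedDataParallel; make_model_pair is an illustrative helper, not code from the repository:

    from copy import deepcopy
    from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP

    def make_model_pair(module, mp_group, dp_group, world_group):
        # Wrap first, then take the standalone reference copy from the
        # wrapper, so both models start from identical (synchronized) weights.
        model_ddp = LocalDDP(deepcopy(module),
                mp_group=mp_group, dp_group=dp_group, world_group=world_group)
        model_ref = deepcopy(model_ddp.module)
        return model_ref, model_ddp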