OpenDAS / FastMoE · Commits

Commit da11cb76, authored Feb 19, 2021 by Rick Ho
parent afd43f51

broadcast distributed parameters
Showing 1 changed file, with 29 additions and 0 deletions.

fmoe/distributed.py (+29, -0)
```diff
@@ -74,7 +74,36 @@ class DistributedGroupedDataParallel(nn.Module):
                     g.copy_(s)
         self.allreduce_params = allreduce_params
+        self._sync_params()
+
+    def _sync_params(self):
+        groups = dict()
+        for p in self.module.parameters():
+            if not p.requires_grad or p.grad is None:
+                continue
+            if hasattr(p, 'dp_comm'):
+                dp_comm = p.dp_comm
+            else:
+                dp_comm = 'dp'
+            group_key = (dp_comm, p.dtype)
+            if group_key not in groups:
+                groups[group_key] = [p]
+            else:
+                groups[group_key].append(p)
+        for (dp_comm, dtype), group in groups.items():
+            if dp_comm not in self.comms:
+                continue
+            comm = self.comms[dp_comm]
+            datas = [p.data for p in group]
+            coalesced = _flatten_dense_tensors(datas)
+            if fp32_allreduce and dtype != torch.float32:
+                coalesced = coalesced.float()
+            torch.distributed.broadcast(coalesced, 0, group=comm)
+            torch.cuda.synchronize()
+            synced = _unflatten_dense_tensors(coalesced, datas)
+            for d, s in zip(datas, synced):
+                d.copy_(s)
 
     def forward(self, *args, **kwargs):
         r'''
         Directly call the module's forward function.
```
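The added `_sync_params` method ensures all replicas start from identical weights: trainable parameters are grouped by their communicator tag (`dp_comm`, defaulting to `'dp'`) and dtype, each group is flattened into a single contiguous buffer, that buffer is broadcast from rank 0 over the matching process group, and the synchronized values are copied back into the individual parameter tensors. Below is a minimal, self-contained sketch of the same flatten / broadcast / unflatten pattern outside the class. The function name `sync_params_from_rank0` and the `comms` argument are illustrative, not FastMoE API, and the `fp32_allreduce` check that appears in the diff is omitted here because `_sync_params` takes no such argument.

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def sync_params_from_rank0(module, comms):
    """Broadcast every trainable parameter of `module` from rank 0.

    `comms` maps a tag (e.g. 'dp') to a torch.distributed process group;
    a parameter selects its group via an optional `dp_comm` attribute.
    """
    groups = {}
    for p in module.parameters():
        if not p.requires_grad:  # the diff additionally skips params whose .grad is None
            continue
        dp_comm = getattr(p, 'dp_comm', 'dp')
        # Coalesce per (communicator, dtype) so each group shares one flat buffer.
        groups.setdefault((dp_comm, p.dtype), []).append(p)

    for (dp_comm, _dtype), group in groups.items():
        if dp_comm not in comms:
            continue
        datas = [p.data for p in group]
        coalesced = _flatten_dense_tensors(datas)             # one contiguous tensor
        dist.broadcast(coalesced, src=0, group=comms[dp_comm])
        for d, s in zip(datas, _unflatten_dense_tensors(coalesced, datas)):
            d.copy_(s)                                        # write rank-0 values back
```

Broadcasting one coalesced buffer per group, rather than issuing one broadcast per parameter, keeps the number of collective calls small; it is the same coalescing pattern the existing `allreduce_params` closure applies to gradients.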