OpenDAS / apex
Commit ed14f39c, authored Sep 05, 2018 by Michael Carilli

Fixing needs_refresh logic to allow multiple forwards between each backward

parent 586c507e

Showing 1 changed file with 13 additions and 16 deletions

apex/parallel/distributed.py (+13 -16)

@@ -133,7 +133,7 @@ class DistributedDataParallel(Module):
         self.shared_param = shared_param
         self.message_size = message_size
-        #reference to last iterations parameters to see if anything has changed
+        # reference to last iterations parameters to see if anything has changed
         self.param_refs = []
         self.reduction_stream = torch.cuda.Stream()
@@ -162,13 +162,13 @@ class DistributedDataParallel(Module):
     def create_hooks(self):
-        #all reduce gradient hook
+        # all reduce gradient hook
         def allreduce_params():
             if not self.needs_reduction:
                 return
             self.needs_reduction = False
-            #parameter ordering refresh
+            # parameter ordering refresh
             if self.needs_refresh and not self.shared_param:
                 t_record = torch.cuda.IntTensor(self.record)
                 dist.broadcast(t_record, 0)
@@ -267,21 +267,18 @@ class DistributedDataParallel(Module):
         param_list = [param for param in self.module.parameters() if param.requires_grad]
-        #Force needs_refresh to True if there are shared params
-        #this will force it to always, only call flush_buckets which is safe
-        #for shared parameters in the model.
-        #Parentheses are not necessary for correct order of operations, but make the intent clearer.
-        if (not self.param_refs) or self.shared_param:
+        # Conditions under which to refresh self.record
+        # Forward has the authority to set needs_refresh to True, but only allreduce_params
+        # in backward has the authority to set needs_refresh to False.
+        # Parentheses are not necessary for correct order of operations, but make the intent clearer.
+        if (not self.param_refs) or self.shared_param or (len(param_list) != len(self.param_refs)) or any([param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]):
             self.needs_refresh = True
-        else:
-            self.needs_refresh = (
-                (len(param_list) != len(self.param_refs)) or any(
-                [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]))
-        if self.needs_refresh:
-            self.record = []
+        if self.needs_refresh:
+            self.record = []
         self.param_state = [0 for i in range(len(param_list))]
         self.param_refs = param_list
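
The behavioral change in the last hunk can be illustrated with a small, torch-free sketch. Everything here is hypothetical and for illustration only: `RefreshTracker`, its `sticky` switch, and the `backward` method are not part of apex. They stand in for the flag handling in `forward` and `allreduce_params`, under the assumption (taken from the new comment) that `allreduce_params` is the only place that lowers `needs_refresh`.

```python
class RefreshTracker:
    """Toy model of the needs_refresh bookkeeping (hypothetical, not the apex class)."""

    def __init__(self, sticky):
        self.sticky = sticky        # True: logic after this commit; False: logic before it
        self.shared_param = False
        self.param_refs = []
        self.needs_refresh = False

    def forward(self, param_list):
        # Same conditions as the diff: first call, shared params,
        # a different number of params, or any param replaced by a new object.
        changed = (
            (not self.param_refs)
            or self.shared_param
            or (len(param_list) != len(self.param_refs))
            or any(p1 is not p2 for p1, p2 in zip(param_list, self.param_refs))
        )
        if self.sticky:
            # After the commit: forward may only raise the flag.
            if changed:
                self.needs_refresh = True
        else:
            # Before the commit: forward overwrote the flag on every call.
            self.needs_refresh = changed
        self.param_refs = param_list

    def backward(self):
        # Stand-in for allreduce_params (assumed to be the only place the flag is lowered).
        refreshed = self.needs_refresh
        self.needs_refresh = False
        return refreshed


params = [object(), object()]   # stand-ins for parameter tensors

old = RefreshTracker(sticky=False)
old.forward(params)
old.forward(params)             # second forward before any backward
print(old.needs_refresh)        # False: the pending refresh was lost

new = RefreshTracker(sticky=True)
new.forward(params)
new.forward(params)
print(new.needs_refresh)        # True: still pending for backward
print(new.backward())           # True: backward sees the pending refresh
print(new.needs_refresh)        # False: lowered only by backward
```

With the pre-commit logic, a second forward with unchanged parameters resets the flag before backward has acted on it. With the post-commit logic, the flag stays raised until backward consumes it, which is what allows multiple forwards between each backward.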