Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
b8965a78
Commit
b8965a78
authored
Apr 17, 2019
by
Michael Carilli
Browse files
Option to elide unflattening copy
parent
887a50bd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
2 deletions
+11
-2
apex/parallel/distributed.py
apex/parallel/distributed.py
+11
-2
No files found.
apex/parallel/distributed.py
View file @
b8965a78
...
@@ -443,6 +443,8 @@ class DistributedDataParallel(Module):
...
@@ -443,6 +443,8 @@ class DistributedDataParallel(Module):
raise
RuntimeError
(
"The backward pass is attempting to replace an already-filled "
raise
RuntimeError
(
"The backward pass is attempting to replace an already-filled "
"allreduce buffer. This is almost certainly an error."
)
"allreduce buffer. This is almost certainly an error."
)
self
.
allreduce_buffers
[
bucket_idx
]
=
allreduced
self
.
allreduce_buffers
[
bucket_idx
]
=
allreduced
for
view
,
grad
in
zip
(
unflatten
(
allreduced
,
bucket
),
bucket
):
grad
.
data
=
view
else
:
else
:
if
multi_tensor_applier
.
available
:
if
multi_tensor_applier
.
available
:
multi_tensor_applier
(
multi_tensor_applier
(
...
@@ -456,7 +458,10 @@ class DistributedDataParallel(Module):
...
@@ -456,7 +458,10 @@ class DistributedDataParallel(Module):
def
allreduce_fallback
(
self
):
def
allreduce_fallback
(
self
):
grads
=
[
param
.
grad
.
data
for
param
in
self
.
module
.
parameters
()
if
param
.
grad
is
not
None
]
if
self
.
retain_allreduce_buffers
:
grads
=
[
param
.
grad
for
param
in
self
.
module
.
parameters
()
if
param
.
grad
is
not
None
]
else
:
grads
=
[
param
.
grad
.
data
for
param
in
self
.
module
.
parameters
()
if
param
.
grad
is
not
None
]
split_buckets
=
split_half_float_double
(
grads
)
split_buckets
=
split_half_float_double
(
grads
)
...
@@ -482,7 +487,11 @@ class DistributedDataParallel(Module):
...
@@ -482,7 +487,11 @@ class DistributedDataParallel(Module):
raise
RuntimeError
(
"The backward pass is attempting to replace an already-filled "
raise
RuntimeError
(
"The backward pass is attempting to replace an already-filled "
"bucket slot. This is almost certainly an error."
)
"bucket slot. This is almost certainly an error."
)
self
.
buckets
[
bucket_idx
][
bucket_loc
]
=
param
.
grad
.
data
if
self
.
retain_allreduce_buffers
:
self
.
buckets
[
bucket_idx
][
bucket_loc
]
=
param
.
grad
else
:
self
.
buckets
[
bucket_idx
][
bucket_loc
]
=
param
.
grad
.
data
self
.
buckets_ready_size
[
bucket_idx
]
+=
1
self
.
buckets_ready_size
[
bucket_idx
]
+=
1
if
self
.
buckets_ready_size
[
bucket_idx
]
==
self
.
bucket_sizes
[
bucket_idx
]:
if
self
.
buckets_ready_size
[
bucket_idx
]
==
self
.
bucket_sizes
[
bucket_idx
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment