OpenDAS / apex
Commit 47da14a0, authored Jul 02, 2019 by Michael Carilli

cosmetic

parent 8a32e428
Showing 1 changed file with 9 additions and 10 deletions

apex/parallel/distributed.py (+9, -10)
@@ -231,7 +231,7 @@ class DistributedDataParallel(Module):
        self.module = module

        self._disable_allreduce = False

        if self._backend == self.backend_enum_holder.NCCL:
            for param in self.module.parameters():
                assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
@@ -277,9 +277,9 @@ class DistributedDataParallel(Module):
    def disable_allreduce(self):
        self._disable_allreduce = True

    # Broadcast rank 0's bucket structure across all processes, and have all processes
    # regenerate their bucket structures to match.
    def sync_bucket_structure(self):
        # Append leftover buckets
        for tmp_bucket in self.tmp_buckets:
@@ -356,7 +356,6 @@ class DistributedDataParallel(Module):
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]

                    def allreduce_hook(*unused):
                        if self.prof:
                            torch.cuda.nvtx.range_push("allreduce_hook")
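For readers unfamiliar with the pattern in the hunk above: a backward hook is registered on each parameter's gradient-accumulator node, so per-parameter work (bucketing, scheduling allreduces) can run as soon as that parameter's gradient is produced. A minimal standalone sketch of the same PyTorch mechanism, independent of apex (the Linear model and the print statement are placeholders, not anything from this commit):

import torch

# Register a hook on each parameter's gradient-accumulator node, reached via
# param.expand_as(param).grad_fn.next_functions[0][0], so a callback fires when
# that parameter's gradient has been accumulated during backward.
model = torch.nn.Linear(4, 2)
grad_accs = []  # keep the accumulator nodes alive, as apex does with self.grad_accs

for param in model.parameters():
    if param.requires_grad:
        param_tmp = param.expand_as(param)
        grad_acc = param_tmp.grad_fn.next_functions[0][0]

        def make_hook(p):
            def hook(*unused):
                print("gradient ready for parameter with shape", tuple(p.shape))
            return hook

        grad_acc.register_hook(make_hook(param))
        grad_accs.append(grad_acc)

loss = model(torch.randn(3, 4)).sum()
loss.backward()  # one "gradient ready" message per parameter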
@@ -371,8 +370,8 @@ class DistributedDataParallel(Module):
                                    # Float, half, and double tensors are grouped into buckets separately.
                                    current_type = self.param_type_to_tmp_i[param.type()]

                                    self.tmp_buckets[current_type].append(active_i)

                                    ship_tmp_bucket = False
                                    if self.custom_allreduce_triggers:
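The "Float, half, and double tensors are grouped into buckets separately" comment refers to keying parameters by their tensor type, so each flattened communication buffer holds a single dtype. A rough, self-contained illustration of that grouping (this is not apex's actual bucket bookkeeping, and the tiny mixed-precision model here is made up):

import torch
from collections import defaultdict

# Group parameters by tensor type, roughly what param_type_to_tmp_i / tmp_buckets
# accomplish above: one bucket list per dtype, flattened before communication.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2).half())

buckets = defaultdict(list)
for param in model.parameters():
    buckets[param.type()].append(param)  # keys like 'torch.FloatTensor', 'torch.HalfTensor'

for type_name, params in buckets.items():
    flat = torch.cat([p.detach().reshape(-1) for p in params])
    print(type_name, "bucket holds", flat.numel(), "elements")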
@@ -389,20 +388,20 @@ class DistributedDataParallel(Module):
                                        self.active_i_buckets.append(self.tmp_buckets[current_type])
                                        self.tmp_buckets[current_type] = []
                                        self.tmp_numels[current_type] = 0

                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(allreduce_params)
                                    self.callback_queued = True
                            else:
                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                                    self.callback_queued = True

                                self.comm_ready_buckets(param)

                        if self.prof:
                            torch.cuda.nvtx.range_pop()

                    grad_acc.register_hook(allreduce_hook)
                    self.grad_accs.append(grad_acc)
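The Variable._execution_engine.queue_callback calls in the last hunk schedule a function to run once after the current backward pass finishes; that is how allreduce_params or the overlapping-backward epilogue is triggered exactly once per iteration. A small sketch of that underscore-prefixed (internal, subject to change) API, with a toy tensor and a print in place of the real epilogue:

import torch
from torch.autograd import Variable

def epilogue():
    # Stand-in for apex's allreduce_params / overlapping_backward_epilogue.
    print("backward finished; flush remaining buckets here")

callback_queued = False

def grad_hook(grad):
    # Queue the epilogue from inside backward, and only once, as allreduce_hook does.
    global callback_queued
    if not callback_queued:
        Variable._execution_engine.queue_callback(epilogue)
        callback_queued = True
    return grad

x = torch.randn(8, requires_grad=True)
x.register_hook(grad_hook)
(x * 2).sum().backward()  # prints the epilogue message after all gradients are computed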