OpenDAS / apex

Commit 47da14a0

cosmetic

Authored Jul 02, 2019 by Michael Carilli
Parent: 8a32e428
Showing 1 changed file with 9 additions and 10 deletions:

apex/parallel/distributed.py (+9, -10)
...
@@ -231,7 +231,7 @@ class DistributedDataParallel(Module):
        self.module = module

        self._disable_allreduce = False

        if self._backend == self.backend_enum_holder.NCCL:
            for param in self.module.parameters():
                assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
...
@@ -277,9 +277,9 @@ class DistributedDataParallel(Module):
    def disable_allreduce(self):
        self._disable_allreduce = True

    # Broadcast rank 0's bucket structure across all processes, and have all processes
    # regenerate their bucket structures to match.
    def sync_bucket_structure(self):
        # Append leftover buckets
        for tmp_bucket in self.tmp_buckets:
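
The comment above describes a broadcast-and-rebuild pattern: rank 0's bucket structure becomes canonical and every other rank regenerates its own to match. A sketch of the same idea using torch.distributed.broadcast_object_list, assuming an initialized process group (the helper name is hypothetical, and apex's own mechanism may differ, e.g. broadcasting tensors of indices rather than Python objects):

    import torch.distributed as dist

    def sync_structure_from_rank0(local_structure):
        # Rank 0's structure wins; every other rank replaces its own copy.
        obj = [local_structure if dist.get_rank() == 0 else None]
        dist.broadcast_object_list(obj, src=0)
        return obj[0]

    # e.g. a bucket structure as a list of lists of parameter indices
    buckets = sync_structure_from_rank0([[0, 1], [2, 3, 4]])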
...
@@ -356,7 +356,6 @@ class DistributedDataParallel(Module):
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]

                    def allreduce_hook(*unused):
                        if self.prof:
                            torch.cuda.nvtx.range_push("allreduce_hook")
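
The grad_acc line above is the standard autograd trick for reaching a parameter's gradient-accumulator node, so that a hook fires as soon as that parameter's gradient is ready during backward. A stripped-down sketch of the pattern outside apex:

    import torch

    w = torch.nn.Parameter(torch.randn(3))

    # expand_as gives a view whose grad_fn's next_functions lead back to
    # w's AccumulateGrad node.
    w_tmp = w.expand_as(w)
    grad_acc = w_tmp.grad_fn.next_functions[0][0]

    def hook(*unused):
        # Runs once w's gradient has been accumulated in this backward pass.
        print("grad ready:", w.grad)

    grad_acc.register_hook(hook)
    (w * 2).sum().backward()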
...
@@ -371,8 +370,8 @@ class DistributedDataParallel(Module):
                                # Float, half, and double tensors are grouped into buckets separately.
                                current_type = self.param_type_to_tmp_i[param.type()]

                                self.tmp_buckets[current_type].append(active_i)

                                ship_tmp_bucket = False
                                if self.custom_allreduce_triggers:
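
Grouping float, half, and double parameters into separate buckets matters because each bucket is eventually flattened into a single tensor of one dtype for the allreduce. A toy sketch of per-type bucketing (bucket_by_dtype is a hypothetical helper, not apex API):

    from collections import defaultdict
    import torch

    def bucket_by_dtype(params):
        # One running bucket per tensor type, like param_type_to_tmp_i / tmp_buckets.
        buckets = defaultdict(list)
        for i, p in enumerate(params):
            buckets[p.type()].append(i)
        return dict(buckets)

    params = [torch.randn(2), torch.randn(2).half(), torch.randn(2)]
    print(bucket_by_dtype(params))
    # -> {'torch.FloatTensor': [0, 2], 'torch.HalfTensor': [1]}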
...
@@ -389,20 +388,20 @@ class DistributedDataParallel(Module):
                                    self.active_i_buckets.append(self.tmp_buckets[current_type])
                                    self.tmp_buckets[current_type] = []
                                    self.tmp_numels[current_type] = 0

                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(allreduce_params)
                                    self.callback_queued = True
                            else:
                                if not self.callback_queued:
                                    Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                                    self.callback_queued = True

                                self.comm_ready_buckets(param)

                        if self.prof:
                            torch.cuda.nvtx.range_pop()

                    grad_acc.register_hook(allreduce_hook)
                    self.grad_accs.append(grad_acc)
...
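
Variable._execution_engine.queue_callback, used twice above, schedules a function to run once the current backward pass finishes; that is how the final allreduce or the overlapping-backward epilogue runs exactly once per iteration no matter how many hooks fire. A minimal demonstration of the mechanism (note this is an internal, undocumented autograd API):

    import torch
    from torch.autograd import Variable

    x = torch.randn(4, requires_grad=True)

    def end_of_backward():
        # Runs once, after every autograd node in this backward has executed.
        print("backward finished; x.grad populated:", x.grad is not None)

    def tensor_hook(*unused):
        # Queue from inside the backward pass, as allreduce_hook does above.
        Variable._execution_engine.queue_callback(end_of_backward)

    x.register_hook(tensor_hook)
    (x * 3).sum().backward()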