Commit 25ac9897, authored Apr 23, 2019 by Michael Carilli

Moving flat allreduce buffer creation to main stream

Parent: b8965a78
Showing 1 changed file with 74 additions and 50 deletions.
apex/parallel/distributed.py (+74, -50)
@@ -222,6 +222,8 @@ class DistributedDataParallel(Module):
         self.delay_allreduce = delay_allreduce
         self.message_size = message_size
 
+        self.main_stream = torch.cuda.current_stream()
+
         self.bucket_streams = []
         self.bucket_events = []
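The helpers referenced later in this diff (_stream_this_bucket, _event_this_bucket) are not part of the hunks shown here. As a rough, hypothetical sketch of the bookkeeping this state supports, assuming one side stream and one event per bucket index (the real apex helpers may select and cache these differently):

import torch

class BucketStreamRegistry:
    """Hypothetical per-bucket stream/event bookkeeping (illustration only, not apex API)."""

    def __init__(self):
        self.bucket_streams = []
        self.bucket_events = []

    def stream_this_bucket(self, bucket_idx):
        # Lazily create one CUDA side stream per bucket index.
        while len(self.bucket_streams) <= bucket_idx:
            self.bucket_streams.append(torch.cuda.Stream())
        return self.bucket_streams[bucket_idx]

    def event_this_bucket(self, bucket_idx):
        # One reusable event per bucket, used to order work across streams.
        while len(self.bucket_events) <= bucket_idx:
            self.bucket_events.append(torch.cuda.Event())
        return self.bucket_events[bucket_idx]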
@@ -411,9 +413,21 @@ class DistributedDataParallel(Module):
             return self.bucket_events[0]
 
-    def allreduce_bucket(self, bucket, bucket_idx):
+    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
         tensor = flatten(bucket)
 
+        if force_default_stream:
+            bucket_stream = self.main_stream
+        else:
+            bucket_stream = self._stream_this_bucket(bucket_idx)
+            bucket_event = self._event_this_bucket(bucket_idx)
+            torch.cuda.current_stream().record_event(bucket_event)
+            bucket_stream.wait_event(bucket_event)
+
+        with torch.cuda.stream(bucket_stream):
+            # self.main_stream.wait_stream(torch.cuda.current_stream())
+            # torch.cuda.synchronize()
 
             tensor_to_allreduce = tensor
 
             if self.allreduce_always_fp32:
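The new branch is the standard CUDA event handoff: the flattened buffer is produced on the current (main) stream, an event records that point, and the bucket stream waits on the event before the allreduce is enqueued. A minimal, self-contained sketch of the same pattern, with an in-place op standing in for dist.all_reduce (assumes a CUDA device; names here are illustrative, not apex API):

import torch

def run_on_side_stream(buf, side_stream):
    # Order the side stream after everything already queued on the current stream.
    ready = torch.cuda.Event()
    torch.cuda.current_stream().record_event(ready)
    side_stream.wait_event(ready)

    # Work enqueued here runs on side_stream, overlapping later main-stream work.
    with torch.cuda.stream(side_stream):
        buf.mul_(0.5)  # stand-in for dist.all_reduce(buf)

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    x = torch.ones(1 << 20, device="cuda")
    run_on_side_stream(x, side)
    # The default stream must wait before it is safe to read the result.
    torch.cuda.current_stream().wait_stream(side)
    print(x[0].item())  # 0.5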
@@ -433,11 +447,30 @@ class DistributedDataParallel(Module):
             if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
                 tensor.copy_(tensor_to_allreduce)
 
+            if not self.retain_allreduce_buffers:
+                if multi_tensor_applier.available:
+                    multi_tensor_applier(
+                        self.multi_tensor_scale,
+                        self._overflow_buf,
+                        [unflatten(tensor, bucket), bucket],
+                        1.0)
+                else:
+                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
+                        buf.copy_(synced)
+
+        # Any subsequent operations that we do on tensor after allreduce_bucket returns must
+        # be synced on bucket_stream anyway.
+        # Also, we maintain a live reference to the returned tensor in allreduce_buffers.
+        # But this doesn't hurt.
+        tensor.record_stream(bucket_stream)
+
+        # torch.cuda.synchronize()
+
         return tensor
 
-    def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
-        allreduced = self.allreduce_bucket(bucket, bucket_idx)
+    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
+        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
         if self.retain_allreduce_buffers:
             if self.allreduce_buffers[bucket_idx] is not None:
                 raise RuntimeError("The backward pass is attempting to replace an already-filled "
@@ -445,19 +478,15 @@ class DistributedDataParallel(Module):
             self.allreduce_buffers[bucket_idx] = allreduced
             for view, grad in zip(unflatten(allreduced, bucket), bucket):
                 grad.data = view
-        else:
-            if multi_tensor_applier.available:
-                multi_tensor_applier(
-                    self.multi_tensor_scale,
-                    self._overflow_buf,
-                    [unflatten(allreduced, bucket), bucket],
-                    1.0)
-            else:
-                for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-                    buf.copy_(synced)
+        # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
+        #     buf.copy_(synced)
 
     def allreduce_fallback(self):
+        for stream, event in zip(self.bucket_streams, self.bucket_events):
+            stream.record_event(event)
+            torch.cuda.current_stream().wait_event(event)
+
         if self.retain_allreduce_buffers:
             grads = [param.grad for param in self.module.parameters() if param.grad is not None]
         else:
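allreduce_fallback now begins by draining every bucket stream back into the current stream, since the fallback path flattens and reduces everything on the main stream and must not race with bucket-stream work from earlier buckets. A small sketch of that drain loop in isolation (illustrative names; fresh events are created here, whereas the diff reuses the per-bucket events):

import torch

def drain_side_streams(side_streams):
    # Make the current (main) stream wait for all work queued on each side stream.
    # This is a device-side dependency; the host is not blocked.
    current = torch.cuda.current_stream()
    for stream in side_streams:
        done = torch.cuda.Event()
        stream.record_event(done)
        current.wait_event(done)

if torch.cuda.is_available():
    streams = [torch.cuda.Stream() for _ in range(3)]
    x = torch.zeros(3, device="cuda")

    # Let each side stream start only after x has been initialized on the default stream.
    ready = torch.cuda.Event()
    torch.cuda.current_stream().record_event(ready)
    for i, s in enumerate(streams):
        s.wait_event(ready)
        with torch.cuda.stream(s):
            x[i] = float(i + 1)  # per-stream work on disjoint elements
        x.record_stream(s)

    drain_side_streams(streams)
    print(x)  # safe to read on the default stream: tensor([1., 2., 3.], device='cuda:0')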
@@ -472,7 +501,7 @@ class DistributedDataParallel(Module):
         self.allreduce_buffers = [None for _ in range(len(split_buckets))]
 
         for i, bucket in enumerate(split_buckets):
-            allreduced = self.allreduce_maybe_retain(bucket, i)
+            allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
 
     def comm_ready_buckets(self, param):
@@ -496,11 +525,6 @@ class DistributedDataParallel(Module):
         if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
             if bucket_idx == self.next_bucket:
-                bucket_stream = self._stream_this_bucket(bucket_idx)
-                bucket_event = self._event_this_bucket(bucket_idx)
-                torch.cuda.current_stream().record_event(bucket_event)
-                bucket_stream.wait_event(bucket_event)
-                with torch.cuda.stream(bucket_stream):
                 self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
 
                 self.next_bucket += 1