OpenDAS / apex / Commits / 25ac9897

Commit 25ac9897 authored Apr 23, 2019 by Michael Carilli
Parent: b8965a78

    Moving flat allreduce buffer creation to main stream

Showing 1 changed file with 74 additions and 50 deletions.

apex/parallel/distributed.py (+74, -50)
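The diff below moves the per-bucket stream and event handling into allreduce_bucket and adds a force_default_stream switch, so the flat fallback buffer is created and reduced on the main stream. As a reading aid, here is a minimal sketch (not apex's code; the names are illustrative, and it assumes a CUDA device plus an initialized torch.distributed process group) of the stream/event handoff pattern the new code relies on: record an event on the producing stream, make the communication stream wait on it, run the collective under torch.cuda.stream, and call record_stream so the caching allocator keeps the buffer alive.

import torch
import torch.distributed as dist

def allreduce_on_side_stream(flat_grads, comm_stream):
    # Wait until the stream that produced flat_grads (the current stream) has
    # finished writing it before the communication stream reads it.
    ready = torch.cuda.current_stream().record_event()
    comm_stream.wait_event(ready)

    with torch.cuda.stream(comm_stream):
        dist.all_reduce(flat_grads)
        # Tell the caching allocator that flat_grads is in use on comm_stream, so its
        # memory is not recycled before the allreduce actually finishes.
        flat_grads.record_stream(comm_stream)

    return flat_grads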
@@ -222,6 +222,8 @@ class DistributedDataParallel(Module):
         self.delay_allreduce = delay_allreduce
         self.message_size = message_size
 
+        self.main_stream = torch.cuda.current_stream()
+
         self.bucket_streams = []
         self.bucket_events = []
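One note on the added self.main_stream line: torch.cuda.current_stream() returns whichever stream is active at the moment it is called, so capturing it in __init__ pins the stream the module was constructed on as the stream used later when force_default_stream is requested. A tiny illustration (assumes a CUDA device; not part of the commit):

import torch

main_stream = torch.cuda.current_stream()   # the stream active right now
side_stream = torch.cuda.Stream()

with torch.cuda.stream(side_stream):
    # Inside the context manager the "current" stream changes...
    assert torch.cuda.current_stream() == side_stream

# ...but the captured handle still refers to the original stream.
assert torch.cuda.current_stream() == main_stream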
@@ -411,33 +413,64 @@ class DistributedDataParallel(Module):
         return self.bucket_events[0]
 
-    def allreduce_bucket(self, bucket, bucket_idx):
+    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
         tensor = flatten(bucket)
 
-        tensor_to_allreduce = tensor
+        if force_default_stream:
+            bucket_stream = self.main_stream
+        else:
+            bucket_stream = self._stream_this_bucket(bucket_idx)
+            bucket_event = self._event_this_bucket(bucket_idx)
+            torch.cuda.current_stream().record_event(bucket_event)
+            bucket_stream.wait_event(bucket_event)
 
-        if self.allreduce_always_fp32:
-            tensor_to_allreduce = tensor.float()
+        with torch.cuda.stream(bucket_stream):
+            # self.main_stream.wait_stream(torch.cuda.current_stream())
+            # torch.cuda.synchronize()
 
-        if self.gradient_predivide_factor != 1.0:
-            tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
+            tensor_to_allreduce = tensor
 
-        if self.allreduce_different_streams and self.bucket_pgs:
-            dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
-        else:
-            dist.all_reduce(tensor_to_allreduce)
+            if self.allreduce_always_fp32:
+                tensor_to_allreduce = tensor.float()
 
-        if self.gradient_average:
-            tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
+            if self.gradient_predivide_factor != 1.0:
+                tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
 
-        if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
-            tensor.copy_(tensor_to_allreduce)
+            if self.allreduce_different_streams and self.bucket_pgs:
+                dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
+            else:
+                dist.all_reduce(tensor_to_allreduce)
+
+            if self.gradient_average:
+                tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
+
+            if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
+                tensor.copy_(tensor_to_allreduce)
+
+            if not self.retain_allreduce_buffers:
+                if multi_tensor_applier.available:
+                    multi_tensor_applier(self.multi_tensor_scale,
+                                         self._overflow_buf,
+                                         [unflatten(tensor, bucket), bucket],
+                                         1.0)
+                else:
+                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
+                        buf.copy_(synced)
+
+            # Any subsequent operations that we do on tensor after allreduce_bucket returns must
+            # be synced on bucket_stream anyway.
+            # Also, we maintain a live reference to the returned tensor in allreduce_buffers.
+            # But this doesn't hurt.
+            tensor.record_stream(bucket_stream)
+
+        # torch.cuda.synchronize()
 
         return tensor
 
-    def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
-        allreduced = self.allreduce_bucket(bucket, bucket_idx)
+    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
+        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
         if self.retain_allreduce_buffers:
             if self.allreduce_buffers[bucket_idx] is not None:
                 raise RuntimeError("The backward pass is attempting to replace an already-filled "
@@ -445,19 +478,15 @@ class DistributedDataParallel(Module):
             self.allreduce_buffers[bucket_idx] = allreduced
             for view, grad in zip(unflatten(allreduced, bucket), bucket):
                 grad.data = view
-        else:
-            # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-            if multi_tensor_applier.available:
-                # buf.copy_(synced)
-                multi_tensor_applier(self.multi_tensor_scale,
-                                     self._overflow_buf,
-                                     [unflatten(allreduced, bucket), bucket],
-                                     1.0)
-            else:
-                for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-                    buf.copy_(synced)
 
     def allreduce_fallback(self):
+        for stream, event in zip(self.bucket_streams, self.bucket_events):
+            stream.record_event(event)
+            torch.cuda.current_stream().wait_event(event)
+
         if self.retain_allreduce_buffers:
             grads = [param.grad for param in self.module.parameters() if param.grad is not None]
         else:
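The three lines added to allreduce_fallback above make the current stream wait on an event recorded by every bucket stream, so the flat fallback allreduce cannot start while per-bucket work is still in flight. A standalone sketch of just that synchronization (assumes a CUDA device; the stream and event count is illustrative):

import torch

bucket_streams = [torch.cuda.Stream() for _ in range(3)]
bucket_events = [torch.cuda.Event() for _ in range(3)]

for stream, event in zip(bucket_streams, bucket_events):
    stream.record_event(event)                      # mark how far each bucket stream has gotten
    torch.cuda.current_stream().wait_event(event)   # the main stream waits for that point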
@@ -472,7 +501,7 @@ class DistributedDataParallel(Module):
             self.allreduce_buffers = [None for _ in range(len(split_buckets))]
 
         for i, bucket in enumerate(split_buckets):
-            allreduced = self.allreduce_maybe_retain(bucket, i)
+            allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
 
     def comm_ready_buckets(self, param):
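Each iteration above hands a split bucket to allreduce_maybe_retain, which flattens it and reduces it as one buffer, now forced onto the default/main stream. A minimal sketch of that flatten / all_reduce / copy-back round trip; it assumes an initialized torch.distributed process group, and that the flatten/unflatten used in the diff correspond to the torch._utils helpers (an assumption, since the imports are outside this hunk):

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten

def allreduce_flat(bucket):
    tensor = flatten(bucket)          # one contiguous buffer for the whole bucket
    dist.all_reduce(tensor)           # a single collective instead of one per gradient
    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
        buf.copy_(synced)             # scatter the reduced values back into the grads
    return tensor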
@@ -496,29 +525,24 @@ class DistributedDataParallel(Module):
         if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
             if bucket_idx == self.next_bucket:
-                bucket_stream = self._stream_this_bucket(bucket_idx)
-                bucket_event = self._event_this_bucket(bucket_idx)
-                torch.cuda.current_stream().record_event(bucket_event)
-                bucket_stream.wait_event(bucket_event)
-                with torch.cuda.stream(bucket_stream):
-                    self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
-
-                    self.next_bucket += 1
-
-                    # Reversing upstream's logic here, because we constructed our buckets based on
-                    # the order things were received during backward.
-                    if len(self.ready_buckets_not_reduced) > 0:
-                        sorted_todo = sorted(self.ready_buckets_not_reduced)
-                        for i in sorted_todo:
-                            # Nothing can be reduced now
-                            if i > self.next_bucket:
-                                break
-                            elif i == self.next_bucket:
-                                self.allreduce_maybe_retain(self.buckets[i], i)
-                                self.ready_buckets_not_reduced.remove(i)
-                                self.next_bucket += 1
-                            else:
-                                raise ValueError("i should always be >= next_bucket")
+                self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
+
+                self.next_bucket += 1
+
+                # Reversing upstream's logic here, because we constructed our buckets based on
+                # the order things were received during backward.
+                if len(self.ready_buckets_not_reduced) > 0:
+                    sorted_todo = sorted(self.ready_buckets_not_reduced)
+                    for i in sorted_todo:
+                        # Nothing can be reduced now
+                        if i > self.next_bucket:
+                            break
+                        elif i == self.next_bucket:
+                            self.allreduce_maybe_retain(self.buckets[i], i)
+                            self.ready_buckets_not_reduced.remove(i)
+                            self.next_bucket += 1
+                        else:
+                            raise ValueError("i should always be >= next_bucket")
             else:
                 self.ready_buckets_not_reduced.add(bucket_idx)
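The comm_ready_buckets hunk above keeps buckets reduced strictly in order: a bucket that fills early is stashed in ready_buckets_not_reduced and drained only once next_bucket catches up to it. A toy, CUDA-free re-enactment of that bookkeeping (not apex's code):

def on_bucket_full(bucket_idx, state, reduce_fn):
    if bucket_idx == state["next_bucket"]:
        reduce_fn(bucket_idx)
        state["next_bucket"] += 1
        # Buckets that filled early may now be reducible, but only in order.
        for i in sorted(state["ready"]):
            if i > state["next_bucket"]:
                break
            elif i == state["next_bucket"]:
                reduce_fn(i)
                state["ready"].remove(i)
                state["next_bucket"] += 1
            else:
                raise ValueError("i should always be >= next_bucket")
    else:
        state["ready"].add(bucket_idx)

state = {"next_bucket": 0, "ready": set()}
order = []
for idx in (2, 1, 0, 3):          # buckets happen to fill out of order
    on_bucket_full(idx, state, order.append)
assert order == [0, 1, 2, 3]      # they are still reduced in order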