OpenDAS / apex / Commits / 25ac9897
"...text-generation-inference.git" did not exist on "32a253063dae768e71a0b0aa099cfbbe962032d1"
Commit 25ac9897 authored Apr 23, 2019 by Michael Carilli

Moving flat allreduce buffer creation to main stream

Parent: b8965a78
Showing 1 changed file with 74 additions and 50 deletions

apex/parallel/distributed.py (+74 -50)
@@ -222,6 +222,8 @@ class DistributedDataParallel(Module):
         self.delay_allreduce = delay_allreduce
         self.message_size = message_size
 
+        self.main_stream = torch.cuda.current_stream()
+
         self.bucket_streams = []
         self.bucket_events = []
@@ -411,33 +413,64 @@ class DistributedDataParallel(Module):
         return self.bucket_events[0]
 
-    def allreduce_bucket(self, bucket, bucket_idx):
+    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
         tensor = flatten(bucket)
 
-        tensor_to_allreduce = tensor
-
-        # self.main_stream.wait_stream(torch.cuda.current_stream())
-        # torch.cuda.synchronize()
-
-        if self.allreduce_always_fp32:
-            tensor_to_allreduce = tensor.float()
-
-        if self.gradient_predivide_factor != 1.0:
-            tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
-
-        if self.allreduce_different_streams and self.bucket_pgs:
-            dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
-        else:
-            dist.all_reduce(tensor_to_allreduce)
-
-        if self.gradient_average:
-            tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
-
-        if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
-            tensor.copy_(tensor_to_allreduce)
-
-        # torch.cuda.synchronize()
-
-        return tensor
-
-    def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
-        allreduced = self.allreduce_bucket(bucket, bucket_idx)
+        if force_default_stream:
+            bucket_stream = self.main_stream
+        else:
+            bucket_stream = self._stream_this_bucket(bucket_idx)
+            bucket_event = self._event_this_bucket(bucket_idx)
+            torch.cuda.current_stream().record_event(bucket_event)
+            bucket_stream.wait_event(bucket_event)
+
+        with torch.cuda.stream(bucket_stream):
+            tensor_to_allreduce = tensor
+
+            if self.allreduce_always_fp32:
+                tensor_to_allreduce = tensor.float()
+
+            if self.gradient_predivide_factor != 1.0:
+                tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
+
+            if self.allreduce_different_streams and self.bucket_pgs:
+                dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
+            else:
+                dist.all_reduce(tensor_to_allreduce)
+
+            if self.gradient_average:
+                tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
+
+            if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
+                tensor.copy_(tensor_to_allreduce)
+
+            if not self.retain_allreduce_buffers:
+                if multi_tensor_applier.available:
+                    multi_tensor_applier(
+                        self.multi_tensor_scale,
+                        self._overflow_buf,
+                        [unflatten(tensor, bucket), bucket],
+                        1.0)
+                else:
+                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
+                        buf.copy_(synced)
+
+            # Any subsequent operations that we do on tensor after allreduce_bucket returns must
+            # be synced on bucket_stream anyway.
+            # Also, we maintain a live reference to the returned tensor in allreduce_buffers.
+            # But this doesn't hurt.
+            tensor.record_stream(bucket_stream)
+
+        return tensor
+
+    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
+        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
         if self.retain_allreduce_buffers:
             if self.allreduce_buffers[bucket_idx] is not None:
                 raise RuntimeError("The backward pass is attempting to replace an already-filled "
@@ -445,19 +478,15 @@ class DistributedDataParallel(Module):
             self.allreduce_buffers[bucket_idx] = allreduced
             for view, grad in zip(unflatten(allreduced, bucket), bucket):
                 grad.data = view
-        else:
-            if multi_tensor_applier.available:
-                multi_tensor_applier(
-                    self.multi_tensor_scale,
-                    self._overflow_buf,
-                    [unflatten(allreduced, bucket), bucket],
-                    1.0)
-            else:
-                for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-                    buf.copy_(synced)
+        # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
+        #     buf.copy_(synced)
 
     def allreduce_fallback(self):
+        for stream, event in zip(self.bucket_streams, self.bucket_events):
+            stream.record_event(event)
+            torch.cuda.current_stream().wait_event(event)
+
         if self.retain_allreduce_buffers:
             grads = [param.grad for param in self.module.parameters() if param.grad is not None]
         else:
@@ -472,7 +501,7 @@ class DistributedDataParallel(Module):
             self.allreduce_buffers = [None for _ in range(len(split_buckets))]
 
         for i, bucket in enumerate(split_buckets):
-            allreduced = self.allreduce_maybe_retain(bucket, i)
+            allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
 
     def comm_ready_buckets(self, param):
@@ -496,29 +525,24 @@ class DistributedDataParallel(Module):
         if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
             if bucket_idx == self.next_bucket:
-                bucket_stream = self._stream_this_bucket(bucket_idx)
-                bucket_event = self._event_this_bucket(bucket_idx)
-                torch.cuda.current_stream().record_event(bucket_event)
-                bucket_stream.wait_event(bucket_event)
-                with torch.cuda.stream(bucket_stream):
-                    self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
+                self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
 
                 self.next_bucket += 1
 
                 # Reversing upstream's logic here, because we constructed our buckets based on
                 # the order things were received during backward.
                 if len(self.ready_buckets_not_reduced) > 0:
                     sorted_todo = sorted(self.ready_buckets_not_reduced)
                     for i in sorted_todo:
                         # Nothing can be reduced now
                         if i > self.next_bucket:
                             break
                         elif i == self.next_bucket:
                             self.allreduce_maybe_retain(self.buckets[i], i)
                             self.ready_buckets_not_reduced.remove(i)
                             self.next_bucket += 1
                         else:
                             raise ValueError("i should always be >= next_bucket")
             else:
                 self.ready_buckets_not_reduced.add(bucket_idx)
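The core pattern in this diff is a standard CUDA stream handoff: the flat allreduce buffer is created (flattened) on the main/current stream, an event recorded on that stream orders the bucket stream behind the flatten, the allreduce and copy-back run under with torch.cuda.stream(bucket_stream), and tensor.record_stream(bucket_stream) keeps the caching allocator from recycling the flat buffer while the bucket stream is still using it. Below is a minimal, self-contained sketch of that handoff, not apex code; the helper name reduce_on_side_stream and the mul_(0.5) stand-in for dist.all_reduce plus averaging are illustrative assumptions.

import torch

def reduce_on_side_stream(bucket, side_stream, event, force_default_stream=False):
    # Flatten on the current ("main") stream, mirroring how the commit creates
    # the flat allreduce buffer on the main stream.
    flat = torch.cat([t.reshape(-1) for t in bucket])

    if force_default_stream:
        # Stay on the stream that produced `flat`; no extra ordering needed.
        stream = torch.cuda.current_stream()
    else:
        stream = side_stream
        # Order the side stream after the flatten that just ran on the current stream.
        torch.cuda.current_stream().record_event(event)
        stream.wait_event(event)

    with torch.cuda.stream(stream):
        flat.mul_(0.5)  # stand-in for dist.all_reduce(flat) plus gradient averaging
        # Tell the caching allocator that `flat` is in use on `stream`, so its
        # memory is not reused before the side stream's pending work finishes.
        flat.record_stream(stream)

    return flat

if torch.cuda.is_available():
    grads = [torch.randn(1024, device="cuda") for _ in range(4)]
    out = reduce_on_side_stream(grads, torch.cuda.Stream(), torch.cuda.Event())
    torch.cuda.synchronize()
    print(out.shape)

The same ordering appears in the new allreduce_fallback: each bucket stream records an event and the current stream waits on all of them before the fallback flat allreduce proceeds on the default stream (force_default_stream=True).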