Commit 7aad54f7 authored by Michael Carilli

Updating explanation for record_stream

parent 25ac9897
...
@@ -458,10 +458,9 @@ class DistributedDataParallel(Module):
         for buf, synced in zip(bucket, unflatten(tensor, bucket)):
             buf.copy_(synced)
-        # Any subsequent operations that we do on tensor after allreduce_bucket returns must
-        # be synced on bucket_stream anyway.
-        # Also, we maintain a live reference to the returned tensor in allreduce_buffers.
-        # But this doesn't hurt.
+        # I think we actually do need this here. After allreduce_bucket returns, tensor will
+        # eventually go out of scope and die, at which point it could otherwise be freed for
+        # further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.
         tensor.record_stream(bucket_stream)
         # torch.cuda.synchronize()
...
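The hazard the new comment describes comes from PyTorch's stream-aware caching allocator: once a tensor's last Python reference dies, its block is considered free for reuse based only on the stream it was allocated on, even if another stream still has pending work on it. record_stream registers that cross-stream use so the allocator waits before recycling the block. Below is a minimal standalone sketch of the pattern, assuming a CUDA device; the function name allreduce_bucket_sketch and the div_ call standing in for the real allreduce/div work are illustrative, not from the Apex source.

import torch

bucket_stream = torch.cuda.Stream()

def allreduce_bucket_sketch(bucket):
    main_stream = torch.cuda.current_stream()
    # The flattened tensor is allocated by the caching allocator on the
    # main (current) stream.
    tensor = torch.cat([t.view(-1) for t in bucket])
    with torch.cuda.stream(bucket_stream):
        # Order bucket_stream after the allocation/flatten work above.
        bucket_stream.wait_stream(main_stream)
        # Stand-in for the real async work done on bucket_stream
        # (in Apex: the allreduce/div/unflatten).
        tensor.div_(2)
    # Once the caller drops its reference, tensor's block would otherwise be
    # marked free with respect to the main stream and could be handed to a
    # new allocation while bucket_stream is still using it. record_stream
    # makes the allocator wait for bucket_stream's pending work before reuse.
    tensor.record_stream(bucket_stream)
    return tensor

if torch.cuda.is_available():
    grads = [torch.randn(1024, device="cuda") for _ in range(4)]
    flat = allreduce_bucket_sketch(grads)

The commented-out torch.cuda.synchronize() in the diff is the blunt alternative: a full device synchronization also closes this race, but it stalls the host and every stream, whereas record_stream only delays the allocator's reuse of this one block.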