Commit 866454e0 authored by Christian Sarofeen's avatar Christian Sarofeen
Browse files

Distributed fix for non-all reduce call.

parent 2fa4dbaf
......@@ -26,7 +26,9 @@ def flat_dist_call(tensors, call, extra_args=None):
call(coalesced, *extra_args)
else:
call(coalesced)
coalesced /= dist.get_world_size()
if call is dist.all_reduce:
coalesced /= dist.get_world_size()
for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
buf.copy_(synced)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment