Commit b8965a78 authored by Michael Carilli's avatar Michael Carilli
Browse files

Option to elide unflattening copy

parent 887a50bd
...@@ -443,6 +443,8 @@ class DistributedDataParallel(Module): ...@@ -443,6 +443,8 @@ class DistributedDataParallel(Module):
raise RuntimeError("The backward pass is attempting to replace an already-filled " raise RuntimeError("The backward pass is attempting to replace an already-filled "
"allreduce buffer. This is almost certainly an error.") "allreduce buffer. This is almost certainly an error.")
self.allreduce_buffers[bucket_idx] = allreduced self.allreduce_buffers[bucket_idx] = allreduced
for view, grad in zip(unflatten(allreduced, bucket), bucket):
grad.data = view
else: else:
if multi_tensor_applier.available: if multi_tensor_applier.available:
multi_tensor_applier( multi_tensor_applier(
...@@ -456,7 +458,10 @@ class DistributedDataParallel(Module): ...@@ -456,7 +458,10 @@ class DistributedDataParallel(Module):
def allreduce_fallback(self): def allreduce_fallback(self):
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None] if self.retain_allreduce_buffers:
grads = [param.grad for param in self.module.parameters() if param.grad is not None]
else:
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
split_buckets = split_half_float_double(grads) split_buckets = split_half_float_double(grads)
...@@ -482,7 +487,11 @@ class DistributedDataParallel(Module): ...@@ -482,7 +487,11 @@ class DistributedDataParallel(Module):
raise RuntimeError("The backward pass is attempting to replace an already-filled " raise RuntimeError("The backward pass is attempting to replace an already-filled "
"bucket slot. This is almost certainly an error.") "bucket slot. This is almost certainly an error.")
self.buckets[bucket_idx][bucket_loc] = param.grad.data if self.retain_allreduce_buffers:
self.buckets[bucket_idx][bucket_loc] = param.grad
else:
self.buckets[bucket_idx][bucket_loc] = param.grad.data
self.buckets_ready_size[bucket_idx] += 1 self.buckets_ready_size[bucket_idx] += 1
if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]: if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment