Commit 41c98511 authored by Michael Carilli

Removing gradient_average_split_factor

parent 4a9c2a53
@@ -168,8 +168,7 @@ class DistributedDataParallel(Module):
                  retain_allreduce_buffers=False,
                  allreduce_always_fp32=False,
                  gradient_average=True,
-                 gradient_predivide_factor=1.0,
-                 gradient_average_split_factor=None):
+                 gradient_predivide_factor=1.0):
         super(DistributedDataParallel, self).__init__()
 
         # Backward/forward compatibility around
@@ -190,10 +189,6 @@ class DistributedDataParallel(Module):
         if shared_param is not None:
             raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
 
-        if gradient_average_split_factor is not None:
-            print("Warning: gradient_average_split_factor has been renamed to gradient_predivide_factor. For now, gradient_average_split_factor will also work, but please update to gradient_predivide_factor instead.")
-            self.gradient_predivide_factor = gradient_average_split_factor
-
         self.world_size = float(dist.get_world_size())
 
         self.retain_allreduce_buffers = retain_allreduce_buffers
@@ -394,7 +389,8 @@ class DistributedDataParallel(Module):
         dist.all_reduce(tensor_to_allreduce)
 
         if self.gradient_average:
-            tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
+            if self.gradient_predivide_factor != self.world_size:
+                tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
 
         if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
             tensor.copy_(tensor_to_allreduce)
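For context, here is a minimal standalone sketch of how `gradient_predivide_factor` interacts with gradient averaging and why the post-allreduce multiply can be skipped when the factor equals the world size. This is not the apex implementation; the helper name `allreduce_and_average` and its signature are made up for illustration, assuming a torch.distributed process group has already been initialized.

```python
import torch.distributed as dist

def allreduce_and_average(grad, gradient_predivide_factor=1.0,
                          gradient_average=True, allreduce_always_fp32=False):
    # Average a gradient tensor (torch.Tensor) across all ranks.
    # Dividing by gradient_predivide_factor before the all-reduce helps avoid
    # overflow when reducing in FP16; multiplying by
    # gradient_predivide_factor / world_size afterwards restores a plain
    # average over world_size.
    world_size = float(dist.get_world_size())

    tensor_to_allreduce = grad.float() if allreduce_always_fp32 else grad

    if gradient_predivide_factor != 1.0:
        tensor_to_allreduce.div_(gradient_predivide_factor)

    dist.all_reduce(tensor_to_allreduce)

    if gradient_average:
        # As in this commit, skip the multiply when it would be a no-op,
        # i.e. when the predivide factor already equals world_size
        # (the scale factor would be exactly 1.0).
        if gradient_predivide_factor != world_size:
            tensor_to_allreduce.mul_(gradient_predivide_factor / world_size)

    if allreduce_always_fp32 and tensor_to_allreduce is not grad:
        grad.copy_(tensor_to_allreduce)

    return grad
```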