Unverified commit a75d971b authored by Olatunji Ruwase, committed by GitHub

ZeRO Stage 2: Clear reduced gradients (#856)



* Ensure gradients of other partitions are cleared after reduction

* Remove redundant code
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 46018859
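
The change targets a race that exists when overlap_comm is enabled: gradient reduction runs on a separate CUDA stream, so gradients belonging to other partitions cannot be freed the moment they are bucketed. Instead they are remembered in previous_reduced_grads and only released after the next torch.cuda.synchronize(). Below is a minimal, self-contained sketch of that deferred-clearing pattern; the class name DeferredGradClearer and its methods mark_reduced / clear_after_sync are illustrative stand-ins, not DeepSpeed's API.

class DeferredGradClearer:
    """Sketch of the deferred-clearing pattern (illustrative, not DeepSpeed API)."""

    def __init__(self, overlap_comm):
        self.overlap_comm = overlap_comm
        self.previous_reduced_grads = None  # params whose grads still await clearing

    def mark_reduced(self, param):
        # Grads of params owned by other partitions are not needed locally, but
        # with overlap_comm the reduction may still be in flight on a side stream,
        # so record the param instead of clearing its grad immediately.
        if self.overlap_comm:
            if self.previous_reduced_grads is None:
                self.previous_reduced_grads = []
            self.previous_reduced_grads.append(param)
        else:
            param.grad = None

    def clear_after_sync(self):
        # Call only after torch.cuda.synchronize(): the previous reduction has
        # finished, so the recorded grads can be dropped safely.
        if self.previous_reduced_grads is not None:
            for param in self.previous_reduced_grads:
                param.grad = None
            self.previous_reduced_grads = None

The commit applies this ordering in two places: both the gradient-partition epilogue and allreduce_and_copy call torch.cuda.synchronize() first and _clear_previous_reduced_grads() immediately afterwards.
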
@@ -37,7 +37,7 @@ def split_half_float_double(tensors):
     ]
     buckets = []
     for i, dtype in enumerate(dtypes):
-        bucket = [t for t in tensors if t is not None and t.type() == dtype]
+        bucket = [t for t in tensors if t.type() == dtype]
         if bucket:
             buckets.append(bucket)
     return buckets
@@ -477,6 +477,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
         if self.overlap_comm:
             torch.cuda.synchronize()
+            # It is safe to clear previously reduced grads of other partitions
+            self._clear_previous_reduced_grads()
 
         if self.cpu_offload is False:
             for i, _ in enumerate(self.fp16_groups):
@@ -638,6 +640,9 @@ class FP16_DeepSpeedZeroOptimizer(object):
             param.grad.data = new_grad_tensor.data.view_as(param.grad)
 
         self.elements_in_ipg_bucket += param.numel()
+
+        assert param.grad is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"
+
         self.grads_in_ipg_bucket.append(param.grad)
         self.params_in_ipg_bucket.append((i, param, param_id))
@@ -965,7 +970,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
                 if not self.is_param_in_current_partition[param_id]:
                     if self.overlap_comm and self.contiguous_gradients is False:
-                        # Clear the previous grads during the next reduction
+                        # Clear grads of other partitions during the next reduction
                         # to avoid clearing them before the reduction is complete.
                         if self.previous_reduced_grads is None:
                             self.previous_reduced_grads = []
@@ -1078,16 +1083,18 @@ class FP16_DeepSpeedZeroOptimizer(object):
         return tensor
 
-    #if rank is specified do a reduction instead of an allreduce
-    def allreduce_and_copy(self, small_bucket, rank=None, log=None):
-        if self.overlap_comm:
-            torch.cuda.synchronize()
+    def _clear_previous_reduced_grads(self):
+        if self.previous_reduced_grads is not None:
+            # previous_reduced_grads has the previous reduced grads,
+            # now it is safe to clear.
+            for param in self.previous_reduced_grads:
+                param.grad = None
+            self.previous_reduced_grads = None
+
+    #if rank is specified do a reduction instead of an allreduce
+    def allreduce_and_copy(self, small_bucket, rank=None, log=None):
+        if self.overlap_comm:
+            torch.cuda.synchronize()
+            # It is safe to clear the previously reduced grads of other partitions
+            self._clear_previous_reduced_grads()
             stream = self.reduction_stream
         else:
             stream = torch.cuda.current_stream()
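
A brief reading of why the `t is not None` filter could be dropped from split_half_float_double (my inference from the hunks, not stated in the commit message): gradients are now asserted non-None before they are appended to grads_in_ipg_bucket, so every tensor that later reaches the dtype bucketing is guaranteed to exist. A small runnable sketch of that invariant follows; split_by_dtype is a hypothetical stand-in, and it uses CPU tensor type strings so it runs without a GPU.

import torch

def split_by_dtype(tensors):
    # Hypothetical stand-in for the patched helper: callers guarantee that no
    # tensor is None, so the dtype bucketing needs no None check.
    dtypes = ["torch.HalfTensor", "torch.FloatTensor", "torch.DoubleTensor"]
    buckets = []
    for dtype in dtypes:
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets

grads = [torch.ones(4), torch.ones(2, dtype=torch.float64)]
for g in grads:
    assert g is not None, "invalid to bucket a None gradient"  # upstream guarantee
print([len(b) for b in split_by_dtype(grads)])  # -> [1, 1]
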