Unverified Commit 6cb332f1 authored by Samyam Rajbhandari, committed by GitHub

CSR+FP32 fix (#206)

1) CSR parameter names should end with .weight.
2) When the basic optimizer is used directly, DeepSpeed should handle zero_grad itself. Letting the basic optimizer perform zero_grad left residual gradients in the embedding layer, for reasons that are not yet understood.
parent a0cd61e8
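The first fix matters because named_parameters() reports an embedding weight as "<module name>.weight", so a lookup set keyed by the bare module name never matches the parameter names that the gradient conversion is presumably checked against later. A minimal plain-PyTorch sketch of that mismatch (the Net, embed, and proj names are made up for illustration, not taken from the commit):

```python
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)  # gradients we would like to treat as sparse/CSR
        self.proj = torch.nn.Linear(4, 2)

net = Net()

# Before the fix: the set stores bare module names, e.g. {"embed"}.
csr_module_names = {name for name, m in net.named_modules()
                    if isinstance(m, torch.nn.Embedding)}

# Parameter names carry the ".weight" suffix.
param_names = [name for name, _ in net.named_parameters()]
print(param_names)                         # ['embed.weight', 'proj.weight', 'proj.bias']
print("embed" in param_names)              # False -> the embedding would never match
print("embed" + ".weight" in param_names)  # True  -> matches after the fix
```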
@@ -170,7 +170,7 @@ class DeepSpeedLight(Module):
         if self.sparse_gradients_enabled():
             for name, module in self.module.named_modules():
                 if isinstance(module, torch.nn.Embedding):
-                    self.csr_tensor_module_names.add(name)
+                    self.csr_tensor_module_names.add(name + ".weight")
                     logging.info("Will convert {} to sparse (csr) "
                                  "tensor during training".format(name))
@@ -695,6 +695,13 @@ class DeepSpeedLight(Module):
         return (self.micro_steps + 1) % \
             self.gradient_accumulation_steps() == 0
 
+    def zero_grad(self):
+        """
+        Zero parameter grads.
+        """
+        for param_name, param in self.module.named_parameters():
+            param.grad = None
+
     def step(self):
         r"""Execute the weight update step after forward and backward propagation on effective_train_batch
         """
@@ -708,6 +715,12 @@
         if self.is_gradient_accumulation_boundary():
             self.optimizer.step()
+            # zero_grad in the basic optimizer could be unreliable and may not
+            # exhibit the behaviour that we want
+            if not self.zero_optimization() and not self.fp16_enabled():
+                self.zero_grad()
+            else:
+                self.optimizer.zero_grad()
             # Check overflow here since in DS fp16 optimizer, the overflow is updated in above step() function.
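To illustrate the second fix, here is a minimal plain-PyTorch sketch (the toy Embedding model and SGD optimizer are stand-ins, not DeepSpeed code) of the two ways of clearing gradients: the engine-level zero_grad() added above sets param.grad = None, dropping the old gradient tensor entirely, whereas optimizer.zero_grad() zeroes the existing tensors in place.

```python
import torch

model = torch.nn.Embedding(10, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.tensor([1, 2, 3])).sum()
loss.backward()
optimizer.step()

# What the new engine-level zero_grad() does: detach the gradient tensor
# completely, so the next backward() allocates a fresh one.
for _, param in model.named_parameters():
    param.grad = None

# What the basic optimizer's zero_grad() does instead: keep the same tensor
# and fill it with zeros in place.
# optimizer.zero_grad()
```

Setting the grads to None removes any dependence on the basic optimizer's own bookkeeping, which is presumably why the residual-gradient issue in the embedding layer disappears with this change.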