Unverified Commit a90a32d7 authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

Bug fix for sparse grads (#208)

parent 3ce531c9
...@@ -866,9 +866,8 @@ class DeepSpeedLight(Module): ...@@ -866,9 +866,8 @@ class DeepSpeedLight(Module):
for param_name, param in self.module.named_parameters(): for param_name, param in self.module.named_parameters():
if param.grad is not None: if param.grad is not None:
grad_data = param.grad.data grad_data = param.grad.data
param_name_root = param_name.split('.', 1)[0]
if self.sparse_gradients_enabled( if self.sparse_gradients_enabled(
) and param_name_root in self.csr_tensor_module_names: ) and param_name in self.csr_tensor_module_names:
grads.append(CSRTensor(grad_data)) grads.append(CSRTensor(grad_data))
else: else:
grads.append(grad_data) grads.append(grad_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment