Unverified Commit 7fcc8911 authored by hamlet, committed by GitHub

Fix ZeRO stage 2 cpu_offload when some trainable model parameters are skipped in training (#861)

* Fix ZeRO stage 2 cpu_offload when some trainable model parameters are skipped in training, as in https://github.com/microsoft/DeepSpeed/issues/707



When some trainable model parameters are skipped in training, the backward hooks registered for them in self.create_reduce_and_remove_grad_hooks() never run, so those parameters have no entry in norm_for_param_grads (see the sketch below).
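
For context (not part of the commit), here is a minimal, self-contained sketch of the failure mode: a module with a trainable parameter that forward() never uses, so its gradient hook never fires and the norm bookkeeping dict never receives an entry for it. The module, names, and hook below are hypothetical stand-ins that only loosely mimic the per-parameter hooks ZeRO stage 2 registers; this is not DeepSpeed code.

```python
import torch
import torch.nn as nn

class SkipsOneParam(nn.Module):
    """Toy model: `unused` is trainable but never touched by forward()."""
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)   # trainable, but skipped in training

    def forward(self, x):
        return self.used(x)             # self.unused never participates

model = SkipsOneParam()
norm_for_param_grads = {}               # stand-in for self.norm_for_param_grads

# Register a per-parameter gradient hook that records the gradient norm,
# loosely mimicking create_reduce_and_remove_grad_hooks().
for name, p in model.named_parameters():
    p.register_hook(lambda grad, name=name: norm_for_param_grads.__setitem__(name, grad.norm()))

model(torch.randn(2, 4)).sum().backward()

# Only the parameters of `used` ever get a hook call; `unused.*` has no entry,
# so an unconditional norm_for_param_grads[param_id] lookup raises KeyError.
print(sorted(norm_for_param_grads))     # ['used.bias', 'used.weight']
```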

* Trim space

* Trim space
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
parent 39013dd2
```diff
@@ -883,8 +883,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
         for p in params:
             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                 param_id = self.get_param_id(p)
-                param_norm = self.norm_for_param_grads[param_id]
-                total_norm += param_norm.item()**2
+                # as some models have trainable parameters that are skipped in training,
+                # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run,
+                # so they have no entry in norm_for_param_grads
+                if param_id in self.norm_for_param_grads:
+                    param_norm = self.norm_for_param_grads[param_id]
+                    total_norm += param_norm.item()**2

         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
```
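
After the guard, the per-rank total is still summed across model-parallel GPUs exactly as before; the only behavioral change is that a parameter with no recorded gradient norm contributes nothing, i.e. it is treated as having a zero gradient. A minimal sketch of that arithmetic (not the DeepSpeed source), using plain floats instead of CUDA tensors and hypothetical names:

```python
import math

def total_grad_norm(param_ids, norm_for_param_grads):
    """Sum of squared per-parameter gradient norms, then the square root."""
    total_norm = 0.0
    for param_id in param_ids:
        if param_id in norm_for_param_grads:   # skipped parameters have no entry
            total_norm += norm_for_param_grads[param_id] ** 2
    return math.sqrt(total_norm)

# Example: param 2 was skipped in training, so it has no entry and adds 0.
print(total_grad_norm([0, 1, 2], {0: 3.0, 1: 4.0}))  # 5.0
```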