Commit 3202f237 authored by Jared Casper

Merge branch 'lmcafee/zerograd-fix' into 'main'

fixed zero_grad for fp32_from_float16_groups

See merge request ADLR/megatron-lm!283
parents c1075275 4e64903d
@@ -282,9 +282,14 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_from_fp32_groups."""
+        float16_groups & fp32_from_fp32_groups. We additionally zero
+        fp32_from_float16_groups as a memory optimization to reduce
+        fragmentation; in the case of set_to_none==True, the space
+        used by this field can be safely deallocated at this point."""
         for group in self.float16_groups:
             _zero_grad_group_helper(group, set_to_none)
+        for group in self.fp32_from_float16_groups:
+            _zero_grad_group_helper(group, set_to_none)
         for group in self.fp32_from_fp32_groups:
             _zero_grad_group_helper(group, set_to_none)
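The hunk calls _zero_grad_group_helper, which is defined elsewhere in this file and not shown here. A minimal sketch of what such a helper does, assumed from the calling code above and mirroring torch.optim.Optimizer.zero_grad semantics:

def _zero_grad_group_helper(group, set_to_none):
    """Zero the gradients of a group of parameters.

    Sketch assumed from the calling code above; mirrors
    torch.optim.Optimizer.zero_grad. With set_to_none=True the .grad
    tensors are dropped entirely so the allocator can reclaim their
    memory, which is why also covering fp32_from_float16_groups reduces
    fragmentation; otherwise the gradients are zeroed in place.
    """
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                # Detach the grad from any autograd graph before
                # zeroing it in place.
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()

Under set_to_none=True this drops the optimizer's fp32 gradient copies as well as the model's float16 gradients, matching the docstring's note that the space can be safely deallocated at this point.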