Commit 2c3cb9fc authored by Lawrence McAfee

many edits; working towards first draft.

parent 867105c2
@@ -323,6 +323,22 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._copy_model_params_to_main_params()

+    # >>>
+    # def zero_grad(self, set_to_none=True):
+    #     """We only need to zero the model related parameters, i.e.,
+    #     float16_groups & fp32_from_fp32_groups. We additionally zero
+    #     fp32_from_float16_groups as a memory optimization to reduce
+    #     fragmentation; in the case of set_to_none==True, the space
+    #     used by this field can be safely deallocated at this point."""
+    #     for group in self.float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_fp32_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    # <<<

     def _unscale_main_grads_and_check_for_nan(self):
         # Collect main grads.
@@ -552,18 +568,20 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-    def zero_grad(self, set_to_none=True):
-        """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_from_fp32_groups. We additionally zero
-        fp32_from_float16_groups as a memory optimization to reduce
-        fragmentation; in the case of set_to_none==True, the space
-        used by this field can be safely deallocated at this point."""
-        for group in self.float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_fp32_groups:
-            _zero_grad_group_helper(group, set_to_none)
+    # >>>
+    # def zero_grad(self, set_to_none=True):
+    #     """We only need to zero the model related parameters, i.e.,
+    #     float16_groups & fp32_from_fp32_groups. We additionally zero
+    #     fp32_from_float16_groups as a memory optimization to reduce
+    #     fragmentation; in the case of set_to_none==True, the space
+    #     used by this field can be safely deallocated at this point."""
+    #     for group in self.float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_fp32_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    # <<<

     def _collect_main_grad_data_for_unscaling(self):
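For context, the `_zero_grad_group_helper` called in both blocks is defined elsewhere in megatron/optimizer/optimizer.py and is not shown in this diff; a minimal sketch of its logic, assuming it follows the usual PyTorch grad-zeroing pattern (details may differ in this revision):

    def _zero_grad_group_helper(group, set_to_none):
        """Zero the gradient of every parameter in the group. With
        set_to_none=True the .grad tensors are dropped entirely, so
        their memory can be deallocated, as the docstring above notes."""
        for param in group:
            if param.grad is not None:
                if set_to_none:
                    param.grad = None
                else:
                    # Keep the grad tensor but detach it from the
                    # autograd graph and zero it in place.
                    if param.grad.grad_fn is not None:
                        param.grad.detach_()
                    else:
                        param.grad.requires_grad_(False)
                    param.grad.zero_()

This is why the docstring describes zeroing fp32_from_float16_groups as a memory optimization: with set_to_none=True, dropping those fp32 gradient copies lets the allocator reclaim the space rather than hold zeroed tensors between iterations.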