Commit efa3cbcf authored by Lawrence McAfee

partially cleaned optimizer.py.

parent e6120623
@@ -31,12 +31,6 @@ from megatron.utils import unwrap_model
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
-# >>>
-from lutil import pax, tp
-DEBUG_ITERATION = 1 # 10
-# <<<
 def _zero_grad_group_helper(group, set_to_none):
     """Zero out the gradient for a group of parameters.
@@ -110,12 +104,11 @@ class MegatronOptimizer(ABC):
         return mpu.get_model_parallel_group()
-    def clip_grad_norm(self, clip_grad, ITERATION):
+    def clip_grad_norm(self, clip_grad):
         params = self.get_parameters()
         return clip_grad_norm_fp32(
             params, clip_grad,
-            model_parallel_group=self.get_model_parallel_group(),
-            ITERATION = ITERATION)
+            model_parallel_group=self.get_model_parallel_group())
     def count_zeros(self):
@@ -187,7 +180,7 @@ class MegatronOptimizer(ABC):
     def step(self, args, timers):
         pass
-    def gather_model_params(self, args, timers, ITERATION):
+    def gather_model_params(self, args, timers):
        '''For the case of a non-distributed-optimizer, there is nothing to
        do here.'''
        pass
@@ -239,9 +232,6 @@ class MegatronOptimizer(ABC):
             torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
     def allreduce_embedding_grads(self, args):
-        # >>>
-        # return # ** .. TEMPORARY .. **
-        # <<<
         self.allreduce_word_embedding_grads(args)
         self.allreduce_position_embedding_grads(args)
@@ -260,7 +250,6 @@ class MegatronOptimizer(ABC):
         timers('backward-embedding-all-reduce').stop()
-# class BaseFloat16Optimizer(MegatronOptimizer):
 class MixedPrecisionOptimizer(MegatronOptimizer):
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
@@ -275,6 +264,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self.bf16 = bf16
         self.grad_scaler = grad_scaler
         # None grad scaler is only supported for bf16.
         if self.grad_scaler is None:
             assert self.bf16, 'fp16 expects a grad scaler.'
@@ -313,7 +303,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Collect main grads.
         main_grads = self._collect_main_grad_data_for_unscaling()
-        # pax(1, {"main_grads": main_grads})
         # Reset found inf.
         self.found_inf.fill_(0.0)
@@ -330,25 +319,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Check for nan.
         found_inf_flag = (self.found_inf.item() > 0)
-        # >>>
-        # if self.grad_scaler.scale <= 131072:
-        #     pax(0, {
-        #         # "grad_scaler" : self.grad_scaler,
-        #         # "found_inf_flag" : found_inf_flag,
-        #         "model_params" : [
-        #             p
-        #             for m in self.models
-        #             for p in m.parameters()
-        #         ],
-        #         "model_grads" : [
-        #             p.main_grad
-        #             for m in self.models
-        #             for p in m.parameters()
-        #         ],
-        #         # "main_grads" : main_grads,
-        #     })
-        # <<<
         return found_inf_flag
     # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
@@ -409,16 +379,11 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
     # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
     @torch.no_grad()
-    def step(self, args, timers, ITERATION):
+    def step(self, args, timers):
-        # >>>
-        # self.debug_model(ITERATION, "before copy grad.", 0)
-        # self.debug_main(ITERATION, "before copy grad.", 0)
-        # <<<
         # Copy gradients from model params to main params.
         timers('optimizer-copy-to-main-grad').start()
-        self._copy_model_grads_to_main_grads(ITERATION)
+        self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()
         # Do unscale, check for inf, and update grad scaler only for
@@ -430,10 +395,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
             found_inf_flag = self._unscale_main_grads_and_check_for_nan()
             timers('optimizer-unscale-and-check-inf').stop()
-            # >>>
-            # <<<
             # We are done with scaling gradients
             # so we can update the loss scale.
             self.grad_scaler.update(found_inf_flag)
@@ -446,7 +407,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         timers('optimizer-clip-main-grad').start()
         grad_norm = None
         if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad, ITERATION)
+            grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()
         # count the zeros in the grads
@@ -458,20 +419,13 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Update params from main params.
         timers('optimizer-copy-main-to-model-params').start()
-        self._copy_main_params_to_model_params(ITERATION)
+        self._copy_main_params_to_model_params()
         timers('optimizer-copy-main-to-model-params').stop()
-        # >>>
-        # self.debug_model(ITERATION, "after copy param.", 0)
-        # self.debug_main(ITERATION, "after copy param.", 0)
-        # <<<
         # Successful update.
         return True, grad_norm, num_zeros_in_grad
-# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
-# class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
     """Float16 optimizer for fp16 and bf16 data types.
@@ -613,7 +567,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         return model_data, main_data
-    def _copy_model_grads_to_main_grads(self, ITERATION):
+    def _copy_model_grads_to_main_grads(self):
         # This only needs to be done for the float16 group.
         for model_group, main_group in zip(self.float16_groups,
                                            self.fp32_from_float16_groups):
@@ -645,7 +599,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                 model_param.main_grad = None
-    def _copy_main_params_to_model_params(self, ITERATION):
+    def _copy_main_params_to_model_params(self):
         # Only needed for the float16 params.
         model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
@@ -728,7 +682,7 @@ class FP32Optimizer(MegatronOptimizer):
     @torch.no_grad()
-    def step(self, args, timers, ITERATION):
+    def step(self, args, timers):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""
@@ -747,7 +701,7 @@ class FP32Optimizer(MegatronOptimizer):
         # Clip gradients.
         grad_norm = None
         if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad, ITERATION)
+            grad_norm = self.clip_grad_norm(self.clip_grad)
         # count the zeros in the grads
         num_zeros_in_grad = self.count_zeros() if \
...
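
The net effect of the cleanup is that the debug-only ITERATION argument is no longer threaded through clip_grad_norm(), step(), gather_model_params(), and the grad/param copy helpers. Below is a minimal, self-contained sketch of the resulting call shape, in plain PyTorch rather than the Megatron classes: ToyFP32Optimizer is a hypothetical stand-in for illustration only, and torch.nn.utils.clip_grad_norm_ substitutes here for clip_grad_norm_fp32 (which additionally takes a model_parallel_group). The point is the signatures: the trainer no longer passes an iteration counter into the optimizer.

    import torch

    class ToyFP32Optimizer:
        """Hypothetical stand-in mirroring the post-cleanup interface:
        step() and clip_grad_norm() take no ITERATION debug argument."""

        def __init__(self, optimizer, clip_grad):
            self.optimizer = optimizer
            self.clip_grad = clip_grad

        def get_parameters(self):
            # Flatten all parameters across the wrapped optimizer's groups.
            return [p for group in self.optimizer.param_groups
                    for p in group['params']]

        def clip_grad_norm(self, clip_grad):
            # Only the parameters and the clip value are needed;
            # torch's utility stands in for clip_grad_norm_fp32.
            return torch.nn.utils.clip_grad_norm_(self.get_parameters(), clip_grad)

        @torch.no_grad()
        def step(self):
            # Clip gradients (if requested), then step the base optimizer,
            # mirroring the (success, grad_norm, num_zeros) return shape.
            grad_norm = None
            if self.clip_grad > 0.0:
                grad_norm = self.clip_grad_norm(self.clip_grad)
            self.optimizer.step()
            return True, grad_norm, None

    # Usage: no iteration counter is threaded into step().
    model = torch.nn.Linear(4, 2)
    opt = ToyFP32Optimizer(torch.optim.SGD(model.parameters(), lr=0.1),
                           clip_grad=1.0)
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    ok, grad_norm, _ = opt.step()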