Commit efa3cbcf authored by Lawrence McAfee

partially cleaned optimizer.py.

parent e6120623
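
This commit strips temporary debugging scaffolding from megatron/optimizer/optimizer.py: the `lutil`/`pax` import and the `DEBUG_ITERATION` constant, the commented-out `pax` and `debug_model`/`debug_main` blocks, and the `ITERATION` argument that had been threaded through `clip_grad_norm`, `gather_model_params`, `step`, `_copy_model_grads_to_main_grads`, and `_copy_main_params_to_model_params`. For orientation, a minimal sketch of the cleaned-up gradient-clipping path is shown below; it mirrors the hunks that follow, with the surrounding `MegatronOptimizer` class and module imports elided (`clip_grad_norm_fp32` comes from `.clip_grads`).

    # Sketch only: MegatronOptimizer.clip_grad_norm after this commit,
    # with the ITERATION plumbing removed (see the first hunks below).
    def clip_grad_norm(self, clip_grad):
        params = self.get_parameters()
        return clip_grad_norm_fp32(
            params, clip_grad,
            model_parallel_group=self.get_model_parallel_group())
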
@@ -31,12 +31,6 @@ from megatron.utils import unwrap_model
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
-# >>>
-from lutil import pax, tp
-DEBUG_ITERATION = 1 # 10
-# <<<
 def _zero_grad_group_helper(group, set_to_none):
     """Zero out the gradient for a group of parameters.
@@ -110,12 +104,11 @@ class MegatronOptimizer(ABC):
         return mpu.get_model_parallel_group()
-    def clip_grad_norm(self, clip_grad, ITERATION):
+    def clip_grad_norm(self, clip_grad):
         params = self.get_parameters()
         return clip_grad_norm_fp32(
             params, clip_grad,
-            model_parallel_group=self.get_model_parallel_group(),
-            ITERATION = ITERATION)
+            model_parallel_group=self.get_model_parallel_group())
     def count_zeros(self):
@@ -187,7 +180,7 @@ class MegatronOptimizer(ABC):
     def step(self, args, timers):
         pass
-    def gather_model_params(self, args, timers, ITERATION):
+    def gather_model_params(self, args, timers):
         '''For the case of a non-distributed-optimizer, there is nothing to
         do here.'''
         pass
@@ -239,9 +232,6 @@ class MegatronOptimizer(ABC):
             torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
     def allreduce_embedding_grads(self, args):
-        # >>>
-        # return # ** .. TEMPORARY .. **
-        # <<<
         self.allreduce_word_embedding_grads(args)
         self.allreduce_position_embedding_grads(args)
@@ -260,7 +250,6 @@ class MegatronOptimizer(ABC):
         timers('backward-embedding-all-reduce').stop()
-# class BaseFloat16Optimizer(MegatronOptimizer):
 class MixedPrecisionOptimizer(MegatronOptimizer):
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
@@ -275,6 +264,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self.bf16 = bf16
         self.grad_scaler = grad_scaler
+        # None grad scaler is only supported for bf16.
         if self.grad_scaler is None:
             assert self.bf16, 'fp16 expects a grad scaler.'
@@ -313,7 +303,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Collect main grads.
         main_grads = self._collect_main_grad_data_for_unscaling()
-        # pax(1, {"main_grads": main_grads})
         # Reset found inf.
         self.found_inf.fill_(0.0)
@@ -330,25 +319,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Check for nan.
         found_inf_flag = (self.found_inf.item() > 0)
-        # >>>
-        # if self.grad_scaler.scale <= 131072:
-        #     pax(0, {
-        #         # "grad_scaler" : self.grad_scaler,
-        #         # "found_inf_flag" : found_inf_flag,
-        #         "model_params" : [
-        #             p
-        #             for m in self.models
-        #             for p in m.parameters()
-        #         ],
-        #         "model_grads" : [
-        #             p.main_grad
-        #             for m in self.models
-        #             for p in m.parameters()
-        #         ],
-        #         # "main_grads" : main_grads,
-        #     })
-        # <<<
         return found_inf_flag
     # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
@@ -409,16 +379,11 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
     # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
     @torch.no_grad()
-    def step(self, args, timers, ITERATION):
-        # >>>
-        # self.debug_model(ITERATION, "before copy grad.", 0)
-        # self.debug_main(ITERATION, "before copy grad.", 0)
-        # <<<
+    def step(self, args, timers):
         # Copy gradients from model params to main params.
         timers('optimizer-copy-to-main-grad').start()
-        self._copy_model_grads_to_main_grads(ITERATION)
+        self._copy_model_grads_to_main_grads()
         timers('optimizer-copy-to-main-grad').stop()
         # Do unscale, check for inf, and update grad scaler only for
@@ -430,10 +395,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
             found_inf_flag = self._unscale_main_grads_and_check_for_nan()
             timers('optimizer-unscale-and-check-inf').stop()
-            # >>>
-            # <<<
             # We are done with scaling gradients
             # so we can update the loss scale.
             self.grad_scaler.update(found_inf_flag)
@@ -446,7 +407,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         timers('optimizer-clip-main-grad').start()
         grad_norm = None
         if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad, ITERATION)
+            grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()
         # count the zeros in the grads
@@ -458,20 +419,13 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Update params from main params.
         timers('optimizer-copy-main-to-model-params').start()
-        self._copy_main_params_to_model_params(ITERATION)
+        self._copy_main_params_to_model_params()
         timers('optimizer-copy-main-to-model-params').stop()
-        # >>>
-        # self.debug_model(ITERATION, "after copy param.", 0)
-        # self.debug_main(ITERATION, "after copy param.", 0)
-        # <<<
         # Successful update.
         return True, grad_norm, num_zeros_in_grad
-# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
-# class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
 class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
     """Float16 optimizer for fp16 and bf16 data types.
@@ -613,7 +567,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         return model_data, main_data
-    def _copy_model_grads_to_main_grads(self, ITERATION):
+    def _copy_model_grads_to_main_grads(self):
         # This only needs to be done for the float16 group.
         for model_group, main_group in zip(self.float16_groups,
                                            self.fp32_from_float16_groups):
@@ -645,7 +599,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                     model_param.main_grad = None
-    def _copy_main_params_to_model_params(self, ITERATION):
+    def _copy_main_params_to_model_params(self):
         # Only needed for the float16 params.
         model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
@@ -728,7 +682,7 @@ class FP32Optimizer(MegatronOptimizer):
     @torch.no_grad()
-    def step(self, args, timers, ITERATION):
+    def step(self, args, timers):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""
@@ -747,7 +701,7 @@ class FP32Optimizer(MegatronOptimizer):
         # Clip gradients.
         grad_norm = None
         if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad, ITERATION)
+            grad_norm = self.clip_grad_norm(self.clip_grad)
         # count the zeros in the grads
         num_zeros_in_grad = self.count_zeros() if \