Commit 87b8b9dc authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'rc-debug-underflow' into 'main'

Rc debug underflow

See merge request ADLR/megatron-lm!246
parents 83d26f03 57437cb1
......@@ -308,6 +308,8 @@ def _add_logging_args(parser):
group.add_argument('--log-params-norm', action='store_true',
help='If set, calculate and log parameters norm.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
help='Report to tensorboard interval.')
group.add_argument('--tensorboard-queue-size', type=int, default=1000,
......
......@@ -84,7 +84,7 @@ def get_megatron_optimizer(model):
hysteresis=args.hysteresis)
# Megatron optimizer.
return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
args.clip_grad)
args.clip_grad, args.log_num_zeros_in_grad)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad)
return FP32Optimizer(optimizer, args.clip_grad, args.log_num_zeros_in_grad)
......@@ -118,3 +118,31 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
clip_coeff)
return total_norm
def count_zeros_fp32(parameters):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros = 0.0
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grad = param.grad.detach()
num_zeros = grad.numel() - torch.count_nonzero(grad)
total_num_zeros = num_zeros + total_num_zeros
# Sum across all model-parallel GPUs.
torch.distributed.all_reduce(total_num_zeros,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group())
total_num_zeros = total_num_zeros.item()
return total_num_zeros
......@@ -27,7 +27,7 @@ from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from .clip_grads import clip_grad_norm_fp32
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
def _zero_grad_group_helper(group, set_to_none):
......@@ -65,13 +65,21 @@ class MegatronOptimizer(ABC):
self.optimizer = optimizer
assert self.optimizer, 'no optimizer is provided.'
def clip_grad_norm(self, clip_grad):
def get_parameters(self):
params = []
for param_group in self.optimizer.param_groups:
for param in param_group['params']:
params.append(param)
return params
def clip_grad_norm(self, clip_grad):
params = self.get_parameters()
return clip_grad_norm_fp32(params, clip_grad)
def count_zeros(self):
params = self.get_parameters()
return count_zeros_fp32(params)
@abstractmethod
def zero_grad(self, set_to_none=True):
pass
......@@ -131,11 +139,12 @@ class MegatronOptimizer(ABC):
class FP16OptimizerWithFP16Params(MegatronOptimizer):
def __init__(self, optimizer, grad_scaler, clip_grad):
def __init__(self, optimizer, grad_scaler, clip_grad, log_num_zeros_in_grad):
super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
self.grad_scaler = grad_scaler
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
# Tensor used to determine if a nan/if has happend.
# Any non-zero value indicates inf/nan.
......@@ -289,7 +298,6 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
def reload_model_params(self):
self._copy_model_params_to_main_params()
@torch.no_grad()
def step(self):
......@@ -311,7 +319,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
# If we found inf/nan, skip the update.
if found_inf_flag:
return False, None
return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
......@@ -320,6 +328,9 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
# Step the optimizer.
self.optimizer.step()
......@@ -329,7 +340,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
timers('optimizer-copy-main-to-model-params').stop()
# Successful update.
return True, grad_norm
return True, grad_norm, num_zeros_in_grad
def state_dict(self):
......@@ -370,10 +381,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
class FP32Optimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad):
super(FP32Optimizer, self).__init__(optimizer)
self.clip_grad = clip_grad
self.log_num_zeros_in_grad = log_num_zeros_in_grad
self._scale = torch.cuda.FloatTensor([1.0])
......@@ -398,11 +410,14 @@ class FP32Optimizer(MegatronOptimizer):
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None
# Update parameters.
self.optimizer.step()
# No overflow for FP32 optimizer.
return True, grad_norm
return True, grad_norm, num_zeros_in_grad
def reload_model_params(self):
......
......@@ -378,7 +378,7 @@ def train_step(forward_step_func, data_iterator,
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm = optimizer.step()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# Update learning rate.
......@@ -397,13 +397,13 @@ def train_step(forward_step_func, data_iterator,
for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced]
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
return loss_reduced, skipped_iter, grad_norm
return {}, skipped_iter, grad_norm
return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
return {}, skipped_iter, grad_norm, num_zeros_in_grad
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter,
grad_norm, params_norm):
grad_norm, params_norm, num_zeros_in_grad):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
......@@ -492,6 +492,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
args.consumed_train_samples)
if num_zeros_in_grad is not None:
writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
args.consumed_train_samples)
if params_norm is not None:
writer.add_scalar('params-norm', params_norm, iteration)
writer.add_scalar('params-norm vs samples', params_norm,
......@@ -526,6 +530,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' loss scale: {:.1f} |'.format(loss_scale)
if grad_norm is not None:
log_string += ' grad norm: {:.3f} |'.format(grad_norm)
if num_zeros_in_grad is not None:
log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
if params_norm is not None:
log_string += ' params norm: {:.3f} |'.format(params_norm)
log_string += ' number of skipped iterations: {:3d} |'.format(
......@@ -581,11 +587,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
loss_dict, skipped_iter, grad_norm = train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler)
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler)
iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
......@@ -600,7 +607,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm)
grad_norm, params_norm, num_zeros_in_grad)
# Autoresume
if args.adlr_autoresume and \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment