Commit 00764415 authored by mohammad's avatar mohammad
Browse files

added grad norm to logging and tensorboard

parent 1aa2e08a
...@@ -70,7 +70,7 @@ class MegatronOptimizer(ABC): ...@@ -70,7 +70,7 @@ class MegatronOptimizer(ABC):
for param_group in self.optimizer.param_groups: for param_group in self.optimizer.param_groups:
for param in param_group['params']: for param in param_group['params']:
params.append(param) params.append(param)
clip_grad_norm_fp32(params, clip_grad) return clip_grad_norm_fp32(params, clip_grad)
@abstractmethod @abstractmethod
def zero_grad(self, set_to_none=True): def zero_grad(self, set_to_none=True):
...@@ -311,11 +311,13 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer): ...@@ -311,11 +311,13 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
# If we found inf/nan, skip the update. # If we found inf/nan, skip the update.
if found_inf_flag: if found_inf_flag:
return False return False, None
# Clip the main gradients. # Clip the main gradients.
timers('optimizer-clip-main-grad').start() timers('optimizer-clip-main-grad').start()
self.clip_grad_norm(self.clip_grad) grad_norm = None
if self.clip_grad > 0.0:
grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop() timers('optimizer-clip-main-grad').stop()
# Step the optimizer. # Step the optimizer.
...@@ -327,7 +329,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer): ...@@ -327,7 +329,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
timers('optimizer-copy-main-to-model-params').stop() timers('optimizer-copy-main-to-model-params').stop()
# Successful update. # Successful update.
return True return True, grad_norm
def state_dict(self): def state_dict(self):
...@@ -392,14 +394,15 @@ class FP32Optimizer(MegatronOptimizer): ...@@ -392,14 +394,15 @@ class FP32Optimizer(MegatronOptimizer):
Always return successful since there is no overflow.""" Always return successful since there is no overflow."""
# Clip gradients. # Clip gradients.
grad_norm = None
if self.clip_grad > 0.0: if self.clip_grad > 0.0:
self.clip_grad_norm(self.clip_grad) grad_norm = self.clip_grad_norm(self.clip_grad)
# Update parameters. # Update parameters.
self.optimizer.step() self.optimizer.step()
# No overflow for FP32 optimizer. # No overflow for FP32 optimizer.
return True return True, grad_norm
def reload_model_params(self): def reload_model_params(self):
......
...@@ -617,7 +617,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -617,7 +617,7 @@ def train_step(forward_step_func, data_iterator,
# Update parameters. # Update parameters.
timers('optimizer').start() timers('optimizer').start()
update_successfull = optimizer.step() update_successfull, grad_norm = optimizer.step()
timers('optimizer').stop() timers('optimizer').stop()
# Update learning rate. # Update learning rate.
...@@ -636,12 +636,12 @@ def train_step(forward_step_func, data_iterator, ...@@ -636,12 +636,12 @@ def train_step(forward_step_func, data_iterator,
for key in losses_reduced[0]: for key in losses_reduced[0]:
losses_reduced_for_key = [x[key] for x in losses_reduced] losses_reduced_for_key = [x[key] for x in losses_reduced]
loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
return loss_reduced, skipped_iter return loss_reduced, skipped_iter, grad_norm
return {}, skipped_iter return {}, skipped_iter, grad_norm
def training_log(loss_dict, total_loss_dict, learning_rate, iteration, def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, skipped_iter): loss_scale, report_memory_flag, skipped_iter, grad_norm):
"""Log training information such as losses, timing, ....""" """Log training information such as losses, timing, ...."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -721,6 +721,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -721,6 +721,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
writer.add_scalar('loss-scale', loss_scale, iteration) writer.add_scalar('loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale vs samples', loss_scale, writer.add_scalar('loss-scale vs samples', loss_scale,
args.consumed_train_samples) args.consumed_train_samples)
if grad_norm is not None:
writer.add_scalar('grad-norm', grad_norm, iteration)
writer.add_scalar('grad-norm vs samples', grad_norm,
args.consumed_train_samples)
timers.write(timers_to_log, writer, iteration, timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations) normalizer=total_iterations)
...@@ -747,6 +751,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -747,6 +751,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
log_string += ' {}: {:.6E} |'.format(key, avg) log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = torch.cuda.FloatTensor([0.0]) total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
log_string += ' loss scale: {:.1f} |'.format(loss_scale) log_string += ' loss scale: {:.1f} |'.format(loss_scale)
if grad_norm is not None:
log_string += ' grad norm: {:.3f} |'.format(grad_norm)
log_string += ' number of skipped iterations: {:3d} |'.format( log_string += ' number of skipped iterations: {:3d} |'.format(
total_loss_dict[skipped_iters_key]) total_loss_dict[skipped_iters_key])
log_string += ' number of nan iterations: {:3d} |'.format( log_string += ' number of nan iterations: {:3d} |'.format(
...@@ -799,11 +805,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -799,11 +805,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = True report_memory_flag = True
while iteration < args.train_iters: while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples) update_num_microbatches(args.consumed_train_samples)
loss_dict, skipped_iter = train_step(forward_step_func, loss_dict, skipped_iter, grad_norm = train_step(forward_step_func,
train_data_iterator, train_data_iterator,
model, model,
optimizer, optimizer,
lr_scheduler) lr_scheduler)
iteration += 1 iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \ args.micro_batch_size * \
...@@ -814,7 +820,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -814,7 +820,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = training_log(loss_dict, total_loss_dict, report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['lr'],
iteration, loss_scale, iteration, loss_scale,
report_memory_flag, skipped_iter) report_memory_flag, skipped_iter,
grad_norm)
# Autoresume # Autoresume
if args.adlr_autoresume and \ if args.adlr_autoresume and \
......
...@@ -179,8 +179,10 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -179,8 +179,10 @@ def _train(model, optimizer, lr_scheduler, forward_step,
start_iteration = 0 start_iteration = 0
# Train for one step. # Train for one step.
losses_dict, skipped_iter = train_step(forward_step, batch, model, losses_dict, skipped_iter, grad_norm = train_step(forward_step,
optimizer, lr_scheduler) batch, model,
optimizer,
lr_scheduler)
iteration += 1 iteration += 1
# Logging. # Logging.
...@@ -188,7 +190,8 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -188,7 +190,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
optimizer.param_groups[0]['lr'], optimizer.param_groups[0]['lr'],
iteration, iteration,
optimizer.get_loss_scale().item(), optimizer.get_loss_scale().item(),
report_memory_flag, skipped_iter) report_memory_flag, skipped_iter,
grad_norm)
# Autoresume # Autoresume
if args.adlr_autoresume and \ if args.adlr_autoresume and \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment