Commit 60704e72 authored by Jared Casper

Merge branch 'log_grad_norm' into 'main'

added grad and params norm to logging and tensorboard

See merge request ADLR/megatron-lm!214
parents 577ad7d3 3dcbaec9
@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_autoresume_args(parser)
     parser = _add_realm_args(parser)
     parser = _add_vit_args(parser)
+    parser = _add_logging_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
@@ -273,6 +274,15 @@ def _add_network_size_args(parser):
     return parser
 
 
+def _add_logging_args(parser):
+    group = parser.add_argument_group(title='logging')
+
+    group.add_argument('--log-params-norm', action='store_true',
+                       help='If set, calculate and log parameters norm.')
+
+    return parser
+
+
 def _add_regularization_args(parser):
     group = parser.add_argument_group(title='regularization')
...
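For reference, a minimal standalone sketch (plain argparse, not Megatron's full `parse_args`) of how the new logging group behaves once wired in; the flag name and help text match the diff above:

```python
# Standalone sketch: '--log-params-norm' defaults to False and flips to True
# only when passed explicitly.
import argparse

def _add_logging_args(parser):
    group = parser.add_argument_group(title='logging')
    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    return parser

parser = _add_logging_args(argparse.ArgumentParser())
assert parser.parse_args([]).log_params_norm is False
assert parser.parse_args(['--log-params-norm']).log_params_norm is True
```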
@@ -27,6 +27,12 @@ _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
 
 
+def param_is_not_shared(param):
+    return not hasattr(param, 'shared') or not param.shared
+
+
 class MegatronModule(torch.nn.Module):
     """Megatron specific extensions of torch Module with support
     for pipelining."""
...
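A small illustration of this helper's contract (not part of the MR): parameters without a `shared` attribute count toward the norm, while parameters tagged with `shared = True` (as Megatron does for tied embedding copies across pipeline stages) are skipped:

```python
import torch

def param_is_not_shared(param):
    return not hasattr(param, 'shared') or not param.shared

weight = torch.nn.Parameter(torch.ones(4))
assert param_is_not_shared(weight)        # untagged parameter -> counted
weight.shared = True                      # tied copy, e.g. a shared embedding
assert not param_is_not_shared(weight)    # skipped when computing norms
```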
@@ -43,6 +43,12 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_stride': 1}
 
 
+def param_is_not_tensor_parallel_duplicate(param):
+    return (hasattr(param, 'tensor_model_parallel') and
+            param.tensor_model_parallel) or (
+                get_tensor_model_parallel_rank() == 0)
+
+
 def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
     # Make sure the attributes are not set.
     for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
...
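Roughly, this helper makes each parameter count exactly once across the tensor-parallel group: partitioned parameters count on every rank (each rank holds a distinct shard), replicated parameters only on rank 0. A stubbed sketch, with the rank passed in explicitly so it runs without initializing torch.distributed:

```python
import torch

def param_is_not_tensor_parallel_duplicate(param, tp_rank):
    # The real helper calls get_tensor_model_parallel_rank(); tp_rank is a
    # stand-in here so the sketch runs without a process group.
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or (tp_rank == 0)

replicated = torch.nn.Parameter(torch.ones(2))     # e.g. a layernorm weight
partitioned = torch.nn.Parameter(torch.ones(2))    # e.g. a column-parallel weight
partitioned.tensor_model_parallel = True

assert param_is_not_tensor_parallel_duplicate(partitioned, tp_rank=1)   # every rank
assert param_is_not_tensor_parallel_duplicate(replicated, tp_rank=0)    # rank 0 only
assert not param_is_not_tensor_parallel_duplicate(replicated, tp_rank=1)
```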
@@ -22,6 +22,8 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
 
 from megatron import mpu
+from megatron.model.module import param_is_not_shared
+from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 
 
 def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
@@ -54,9 +56,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     grads_for_norm = []
     for param in parameters:
         grad_not_none = param.grad is not None
-        is_not_shared = not hasattr(param, 'shared') or not param.shared
-        is_not_tp_duplicate = param.tensor_model_parallel or \
-            (mpu.get_tensor_model_parallel_rank() == 0)
+        is_not_shared = param_is_not_shared(param)
+        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
         grad = param.grad.detach()
         if grad_not_none:
             # Make sure the grads are in fp32
...
@@ -70,7 +70,7 @@ class MegatronOptimizer(ABC):
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
                 params.append(param)
-        clip_grad_norm_fp32(params, clip_grad)
+        return clip_grad_norm_fp32(params, clip_grad)
 
     @abstractmethod
     def zero_grad(self, set_to_none=True):
@@ -311,11 +311,13 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         # If we found inf/nan, skip the update.
         if found_inf_flag:
-            return False
+            return False, None
 
         # Clip the main gradients.
         timers('optimizer-clip-main-grad').start()
-        self.clip_grad_norm(self.clip_grad)
+        grad_norm = None
+        if self.clip_grad > 0.0:
+            grad_norm = self.clip_grad_norm(self.clip_grad)
         timers('optimizer-clip-main-grad').stop()
 
         # Step the optimizer.
@@ -327,7 +329,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         timers('optimizer-copy-main-to-model-params').stop()
 
         # Successful update.
-        return True
+        return True, grad_norm
 
     def state_dict(self):
@@ -392,14 +394,15 @@ class FP32Optimizer(MegatronOptimizer):
         Always return successful since there is no overflow."""
 
         # Clip gradients.
+        grad_norm = None
         if self.clip_grad > 0.0:
-            self.clip_grad_norm(self.clip_grad)
+            grad_norm = self.clip_grad_norm(self.clip_grad)
 
         # Update parameters.
         self.optimizer.step()
 
         # No overflow for FP32 optimizer.
-        return True
+        return True, grad_norm
 
     def reload_model_params(self):
...
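The behavioral change here is that `optimizer.step()` now returns a `(update_successful, grad_norm)` pair, with `grad_norm` left as `None` when clipping is disabled or the update is skipped for inf/nan. A self-contained sketch of the FP32 case, using `torch.nn.utils.clip_grad_norm_` as a stand-in for Megatron's `clip_grad_norm_fp32` (the built-in also returns the pre-clip total norm):

```python
import torch

def fp32_step(optimizer, params, clip_grad):
    # Mirrors the new contract: (update_successful, grad_norm), where
    # grad_norm is None when clipping is turned off.
    grad_norm = None
    if clip_grad > 0.0:
        grad_norm = torch.nn.utils.clip_grad_norm_(params, clip_grad).item()
    optimizer.step()
    return True, grad_norm

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(3, 4)).sum().backward()
ok, grad_norm = fp32_step(opt, list(model.parameters()), clip_grad=1.0)
print(ok, grad_norm)
```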
@@ -47,6 +47,7 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_samplers import build_pretraining_data_loader
+from megatron.utils import calc_params_l2_norm
 from megatron.utils import report_memory
@@ -620,7 +621,7 @@ def train_step(forward_step_func, data_iterator,
     # Update parameters.
     timers('optimizer').start()
-    update_successfull = optimizer.step()
+    update_successfull, grad_norm = optimizer.step()
     timers('optimizer').stop()
 
     # Update learning rate.
@@ -639,12 +640,13 @@ def train_step(forward_step_func, data_iterator,
         for key in losses_reduced[0]:
             losses_reduced_for_key = [x[key] for x in losses_reduced]
             loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
-        return loss_reduced, skipped_iter
-    return {}, skipped_iter
+        return loss_reduced, skipped_iter, grad_norm
+    return {}, skipped_iter, grad_norm
 
 
 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag, skipped_iter):
+                 loss_scale, report_memory_flag, skipped_iter,
+                 grad_norm, params_norm):
     """Log training information such as losses, timing, ...."""
     args = get_args()
     timers = get_timers()
@@ -724,6 +726,14 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         writer.add_scalar('loss-scale', loss_scale, iteration)
         writer.add_scalar('loss-scale vs samples', loss_scale,
                           args.consumed_train_samples)
+        if grad_norm is not None:
+            writer.add_scalar('grad-norm', grad_norm, iteration)
+            writer.add_scalar('grad-norm vs samples', grad_norm,
+                              args.consumed_train_samples)
+        if params_norm is not None:
+            writer.add_scalar('params-norm', params_norm, iteration)
+            writer.add_scalar('params-norm vs samples', params_norm,
+                              args.consumed_train_samples)
         timers.write(timers_to_log, writer, iteration,
                      normalizer=total_iterations)
@@ -750,6 +760,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 log_string += ' {}: {:.6E} |'.format(key, avg)
                 total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
         log_string += ' loss scale: {:.1f} |'.format(loss_scale)
+        if grad_norm is not None:
+            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
+        if params_norm is not None:
+            log_string += ' params norm: {:.3f} |'.format(params_norm)
         log_string += ' number of skipped iterations: {:3d} |'.format(
             total_loss_dict[skipped_iters_key])
         log_string += ' number of nan iterations: {:3d} |'.format(
@@ -802,11 +816,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     report_memory_flag = True
     while iteration < args.train_iters:
         update_num_microbatches(args.consumed_train_samples)
-        loss_dict, skipped_iter = train_step(forward_step_func,
-                                             train_data_iterator,
-                                             model,
-                                             optimizer,
-                                             lr_scheduler)
+        loss_dict, skipped_iter, grad_norm = train_step(forward_step_func,
+                                                        train_data_iterator,
+                                                        model,
+                                                        optimizer,
+                                                        lr_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
@@ -814,10 +828,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Logging.
         loss_scale = optimizer.get_loss_scale().item()
+        params_norm = None
+        if args.log_params_norm:
+            params_norm = calc_params_l2_norm(model)
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
-                                          report_memory_flag, skipped_iter)
+                                          report_memory_flag, skipped_iter,
+                                          grad_norm, params_norm)
 
         # Autoresume
         if args.adlr_autoresume and \
...
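The new tensorboard writes mirror the existing loss-scale logging. A hedged sketch outside Megatron, using `torch.utils.tensorboard` directly with placeholder values and a hypothetical log directory:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./tb-logs')        # hypothetical location
iteration = 100
consumed_train_samples = 100 * 512                 # placeholder sample count
grad_norm, params_norm = 1.73, 512.4               # placeholder values

if grad_norm is not None:
    writer.add_scalar('grad-norm', grad_norm, iteration)
    writer.add_scalar('grad-norm vs samples', grad_norm, consumed_train_samples)
if params_norm is not None:
    writer.add_scalar('params-norm', params_norm, iteration)
    writer.add_scalar('params-norm vs samples', params_norm, consumed_train_samples)
writer.close()
```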
@@ -19,11 +19,41 @@ import sys
 
 import torch
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C
 
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.checkpointing import save_checkpoint
+from megatron.model.module import param_is_not_shared
+from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+
+
+def calc_params_l2_norm(model):
+    """Calculate l2 norm of parameters """
+    # Remove duplicate params.
+    params_data = []
+    for param in model.parameters():
+        is_not_shared = param_is_not_shared(param)
+        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+        if is_not_shared and is_not_tp_duplicate:
+            params_data.append(param.data)
+    # Calculate norm
+    dummy_overflow_buf = torch.cuda.IntTensor([0])
+    norm, _ = multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        dummy_overflow_buf,
+        [params_data],
+        False  # no per-parameter norm
+    )
+    norm_2 = norm * norm
+    # Sum across all model-parallel GPUs.
+    torch.distributed.all_reduce(norm_2,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=mpu.get_model_parallel_group())
+    return norm_2.item() ** 0.5
 
 
 def average_losses_across_data_parallel_group(losses):
...
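For intuition, the same quantity can be computed without apex on a single GPU or CPU (no model-parallel reduction needed); `calc_params_l2_norm` above uses `amp_C.multi_tensor_l2norm` for speed and then all-reduces the squared norm across the model-parallel group:

```python
import torch

def params_l2_norm_plain(model):
    # Single-process equivalent: l2 norm over all parameters, with no
    # de-duplication of shared or tensor-parallel-replicated weights needed.
    per_param = [p.data.float().norm(2) for p in model.parameters()]
    return torch.norm(torch.stack(per_param), 2).item()

model = torch.nn.Linear(8, 8)
print(params_l2_norm_plain(model))
```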
@@ -27,8 +27,9 @@ from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
 from megatron.training import training_log
-from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import average_losses_across_data_parallel_group
+from megatron.utils import calc_params_l2_norm
+from megatron.utils import check_adlr_autoresume_termination
 
 
 def process_batch(batch):
@@ -179,16 +180,22 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             start_iteration = 0
 
             # Train for one step.
-            losses_dict, skipped_iter = train_step(forward_step, batch, model,
-                                                   optimizer, lr_scheduler)
+            losses_dict, skipped_iter, grad_norm = train_step(forward_step,
+                                                              batch, model,
+                                                              optimizer,
+                                                              lr_scheduler)
             iteration += 1
 
             # Logging.
+            params_norm = None
+            if args.log_params_norm:
+                params_norm = calc_params_l2_norm(model)
             report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                               optimizer.param_groups[0]['lr'],
                                               iteration,
                                               optimizer.get_loss_scale().item(),
-                                              report_memory_flag, skipped_iter)
+                                              report_memory_flag, skipped_iter,
+                                              grad_norm, params_norm)
 
             # Autoresume
             if args.adlr_autoresume and \
...