Commit 57d1356e authored by mohammad

added params norms

parent d31d7e49
@@ -27,6 +27,12 @@ _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 _HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)


+def param_is_not_shared(param):
+    return not hasattr(param, 'shared') or not param.shared
+
+
 class MegatronModule(torch.nn.Module):
     """Megatron specific extensions of torch Module with support
     for pipelining."""
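Note: the diff does not show where the `shared` attribute gets set; presumably tied weights (for example, embeddings reused across pipeline stages) are tagged elsewhere. A minimal sketch of the helper's contract, with the flag set by hand:

import torch

def param_is_not_shared(param):
    return not hasattr(param, 'shared') or not param.shared

w = torch.nn.Parameter(torch.zeros(4))
print(param_is_not_shared(w))   # True: no 'shared' attribute, so counted
w.shared = True                 # hypothetical: how a tied weight would be tagged
print(param_is_not_shared(w))   # False: excluded from the parameter norm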
@@ -43,6 +43,12 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_stride': 1}


+def param_is_not_tensor_parallel_duplicate(param):
+    return (hasattr(param, 'tensor_model_parallel') and
+            param.tensor_model_parallel) or (
+                get_tensor_model_parallel_rank() == 0)
+
+
 def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
     # Make sure the attributes are not set.
     for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
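The predicate keeps a parameter if it is marked tensor-model-parallel (each rank owns a distinct shard) or if the caller is on tensor-parallel rank 0, so replicated parameters are counted exactly once. A sketch of the truth table, with get_tensor_model_parallel_rank() replaced by an explicit argument so it runs without a distributed job:

import torch

def param_is_not_tp_duplicate(param, tp_rank):
    # tp_rank stands in for get_tensor_model_parallel_rank()
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or tp_rank == 0

sharded = torch.nn.Parameter(torch.zeros(4))
sharded.tensor_model_parallel = True        # each TP rank holds its own shard
replicated = torch.nn.Parameter(torch.zeros(4))

for rank in (0, 1):
    print(rank, param_is_not_tp_duplicate(sharded, rank),
          param_is_not_tp_duplicate(replicated, rank))
# 0 True True   -> rank 0 counts both kinds
# 1 True False  -> other ranks count only their own shards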
@@ -22,6 +22,8 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

 from megatron import mpu
+from megatron.model.module import param_is_not_shared
+from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


 def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):

@@ -54,9 +56,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     grads_for_norm = []
     for param in parameters:
         grad_not_none = param.grad is not None
-        is_not_shared = not hasattr(param, 'shared') or not param.shared
-        is_not_tp_duplicate = param.tensor_model_parallel or \
-            (mpu.get_tensor_model_parallel_rank() == 0)
+        is_not_shared = param_is_not_shared(param)
+        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
         grad = param.grad.detach()
         if grad_not_none:
             # Make sure the grads are in fp32
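Beyond deduplicating the logic, the refactor changes one edge case: the old inline check read param.tensor_model_parallel directly, which raises AttributeError on a parameter that was never tagged, whereas the new helper guards with hasattr first. A small illustration, with the rank-0 test replaced by a literal True:

import torch

p = torch.nn.Parameter(torch.zeros(2))    # never tagged with the attribute

try:
    _ = p.tensor_model_parallel or True   # old inline check
except AttributeError as err:
    print(err)                            # 'Parameter' object has no attribute ...

# The new helper falls back to the rank-0 test instead of raising:
print((hasattr(p, 'tensor_model_parallel') and p.tensor_model_parallel) or True)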
@@ -47,6 +47,7 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
+from megatron.utils import calc_params_l2_norm
 from megatron.utils import report_memory
@@ -641,7 +642,8 @@ def train_step(forward_step_func, data_iterator,

 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag, skipped_iter, grad_norm):
+                 loss_scale, report_memory_flag, skipped_iter,
+                 grad_norm, params_norm):
     """Log training information such as losses, timing, ...."""
     args = get_args()
     timers = get_timers()
@@ -725,6 +727,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             writer.add_scalar('grad-norm', grad_norm, iteration)
             writer.add_scalar('grad-norm vs samples', grad_norm,
                               args.consumed_train_samples)
+        if params_norm is not None:
+            writer.add_scalar('params-norm', params_norm, iteration)
+            writer.add_scalar('params-norm vs samples', params_norm,
+                              args.consumed_train_samples)
         timers.write(timers_to_log, writer, iteration,
                      normalizer=total_iterations)
@@ -753,6 +759,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         log_string += ' loss scale: {:.1f} |'.format(loss_scale)
         if grad_norm is not None:
             log_string += ' grad norm: {:.3f} |'.format(grad_norm)
+        if params_norm is not None:
+            log_string += ' params norm: {:.3f} |'.format(params_norm)
         log_string += ' number of skipped iterations: {:3d} |'.format(
             total_loss_dict[skipped_iters_key])
         log_string += ' number of nan iterations: {:3d} |'.format(
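Pieced together from the format strings above, the resulting console line gains a params norm field; an illustrative fragment (values invented):

 loss scale: 4096.0 | grad norm: 1.234 | params norm: 2048.567 | number of skipped iterations:   0 | number of nan iterations:   0 |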
@@ -817,11 +825,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,

         # Logging.
         loss_scale = optimizer.get_loss_scale().item()
+        params_norm = calc_params_l2_norm(model)
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
                                           report_memory_flag, skipped_iter,
-                                          grad_norm)
+                                          grad_norm, params_norm)

         # Autoresume
         if args.adlr_autoresume and \
@@ -19,11 +19,41 @@ import sys

 import torch
+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C

 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.checkpointing import save_checkpoint
+from megatron.model.module import param_is_not_shared
+from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+
+
+def calc_params_l2_norm(model):
+    """Calculate the l2 norm of model parameters."""
+    # Skip shared and tensor-model-parallel duplicate parameters so each
+    # unique parameter is counted exactly once across ranks.
+    params_data = []
+    for param in model.parameters():
+        is_not_shared = param_is_not_shared(param)
+        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+        if is_not_shared and is_not_tp_duplicate:
+            params_data.append(param.data)
+    # Calculate the local norm with apex's fused kernel.
+    dummy_overflow_buf = torch.cuda.IntTensor([0])
+    norm, _ = multi_tensor_applier(
+        amp_C.multi_tensor_l2norm,
+        dummy_overflow_buf,
+        [params_data],
+        False)  # no per-parameter norm
+    norm_2 = norm * norm
+    # Sum the squared norms across all model-parallel GPUs.
+    torch.distributed.all_reduce(norm_2,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=mpu.get_model_parallel_group())
+    return norm_2.item() ** 0.5
+
+
 def average_losses_across_data_parallel_group(losses):
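The function squares the local norm, sums the squares across the model-parallel group, and takes the square root, so every unique shard contributes exactly once to the global norm; this is why shared and tensor-parallel-duplicate parameters must be filtered first. For readers without apex, a single-process reference sketch of the same arithmetic (the fused amp_C.multi_tensor_l2norm kernel swapped for plain torch operations, the model-parallel all-reduce omitted):

import torch

def calc_params_l2_norm_reference(model):
    # Sum of squared per-parameter norms equals the squared norm of the
    # concatenation of all parameters.
    norm_2 = sum(p.data.float().norm() ** 2 for p in model.parameters())
    # In the distributed version, torch.distributed.all_reduce over
    # mpu.get_model_parallel_group() would sum the shard contributions here.
    return float(norm_2) ** 0.5

model = torch.nn.Linear(10, 10)
print(calc_params_l2_norm_reference(model))
# Same value, computed directly on the flattened parameters:
print(torch.cat([p.data.reshape(-1) for p in model.parameters()]).norm().item())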
@@ -27,8 +27,9 @@ from megatron.training import evaluate_and_print_results
 from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
 from megatron.training import training_log
-from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import average_losses_across_data_parallel_group
+from megatron.utils import calc_params_l2_norm
+from megatron.utils import check_adlr_autoresume_termination


 def process_batch(batch):

@@ -186,12 +187,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
         iteration += 1

         # Logging.
+        params_norm = calc_params_l2_norm(model)
         report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                           optimizer.param_groups[0]['lr'],
                                           iteration,
                                           optimizer.get_loss_scale().item(),
                                           report_memory_flag, skipped_iter,
-                                          grad_norm)
+                                          grad_norm, params_norm)

         # Autoresume
         if args.adlr_autoresume and \