Commit e724785f authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'vision-merge' into 'main'

second phase of vision code merge

See merge request ADLR/megatron-lm!381
parents e156d2fe 2b628f96
...@@ -246,6 +246,15 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -246,6 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.fp16 or args.bf16, \ assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.' 'residual connection in fp32 only supported when using fp16 or bf16.'
if args.weight_decay_incr_style == 'constant':
assert args.start_weight_decay is None
assert args.end_weight_decay is None
args.start_weight_decay = args.weight_decay
args.end_weight_decay = args.weight_decay
else:
assert args.start_weight_decay is not None
assert args.end_weight_decay is not None
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split('.')[1])
# Persistent fused layer norm. # Persistent fused layer norm.
...@@ -395,6 +404,13 @@ def _add_regularization_args(parser): ...@@ -395,6 +404,13 @@ def _add_regularization_args(parser):
help='Dropout probability for hidden state transformer.') help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01, group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.') help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-weight-decay', type=float,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-weight-decay', type=float,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--weight-decay-incr-style', type=str, default='constant',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0, group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.') help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9, group.add_argument('--adam-beta1', type=float, default=0.9,
...@@ -561,13 +577,13 @@ def _add_learning_rate_args(parser): ...@@ -561,13 +577,13 @@ def _add_learning_rate_args(parser):
group.add_argument('--min-lr', type=float, default=0.0, group.add_argument('--min-lr', type=float, default=0.0,
help='Minumum value for learning rate. The scheduler' help='Minumum value for learning rate. The scheduler'
'clip values below this threshold.') 'clip values below this threshold.')
group.add_argument('--override-lr-scheduler', action='store_true', group.add_argument('--override-opt_param-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate,' help='Reset the values of the scheduler (learning rate,'
'warmup iterations, minimum learning rate, maximum ' 'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style from input ' 'number of iterations, and decay style from input '
'arguments and ignore values from checkpoints. Note' 'arguments and ignore values from checkpoints. Note'
'that all the above values will be reset.') 'that all the above values will be reset.')
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true', group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler ' help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning ' '(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style ' 'rate, maximum number of iterations, and decay style '
......
...@@ -167,7 +167,7 @@ def get_rng_state(): ...@@ -167,7 +167,7 @@ def get_rng_state():
return rng_state_list return rng_state_list
def save_checkpoint(iteration, model, optimizer, lr_scheduler): def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
"""Save a model checkpoint.""" """Save a model checkpoint."""
args = get_args() args = get_args()
...@@ -198,8 +198,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -198,8 +198,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if not args.no_save_optim: if not args.no_save_optim:
if optimizer is not None: if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict() state_dict['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None: if opt_param_scheduler is not None:
state_dict['lr_scheduler'] = lr_scheduler.state_dict() state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
# RNG states. # RNG states.
if not args.no_save_rng: if not args.no_save_rng:
...@@ -295,7 +295,7 @@ def fix_query_key_value_ordering(model, checkpoint_version): ...@@ -295,7 +295,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
print_rank_0(" succesfully fixed query-key-values ordering for" print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version)) " checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True): def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration. """Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of :attr:`state_dict` of the checkpoint match the names of
...@@ -394,8 +394,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True ...@@ -394,8 +394,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
try: try:
if optimizer is not None: if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer']) optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None: if opt_param_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler']) if 'lr_scheduler' in state_dict: # backward compatibility
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
except KeyError: except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. ' print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent ' 'Specify --no-load-optim or --finetune to prevent '
......
...@@ -42,6 +42,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu ...@@ -42,6 +42,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
hyperparameters: transformer hyperparameters hyperparameters: transformer hyperparameters
""" """
class DropPath(MegatronModule):
    """Randomly drop entire samples (Stochastic Depth).

    Intended for the main path of a residual block: during training each
    sample in the batch is zeroed with probability ``drop_prob``, and the
    surviving samples are rescaled by ``1 / (1 - drop_prob)`` so the
    expected activation magnitude is unchanged. In eval mode (or with
    ``drop_prob == 0``) this module is the identity.
    """
    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
    def forward(self, hidden_state):
        # Identity when disabled or when not training.
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample; trailing singleton dims broadcast
        # across the remaining axes, so any tensor rank is supported.
        mask_shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            mask_shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0
        return hidden_state.div(keep_prob) * random_tensor
class ParallelMLP(MegatronModule): class ParallelMLP(MegatronModule):
"""MLP. """MLP.
...@@ -406,7 +429,8 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -406,7 +429,8 @@ class ParallelTransformerLayer(MegatronModule):
def __init__(self, init_method, output_layer_init_method, def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder, layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding): self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args() args = get_args()
super(ParallelTransformerLayer, self).__init__() super(ParallelTransformerLayer, self).__init__()
...@@ -434,6 +458,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -434,6 +458,7 @@ class ParallelTransformerLayer(MegatronModule):
attn_mask_type=self_attn_mask_type) attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
# Layernorm on the attention output # Layernorm on the attention output
self.post_attention_layernorm = LayerNorm( self.post_attention_layernorm = LayerNorm(
...@@ -477,25 +502,31 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -477,25 +502,31 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
residual = hidden_states residual = hidden_states
# jit scripting for a nn.module (with dropout) is not if self.drop_path is None:
# trigerring the fusion kernel. For now, we use two # jit scripting for a nn.module (with dropout) is not
# different nn.functional routines to account for varying # trigerring the fusion kernel. For now, we use two
# dropout semantics during training and inference phases. # different nn.functional routines to account for varying
if self.bias_dropout_fusion: # dropout semantics during training and inference phases.
if self.training: if self.bias_dropout_fusion:
bias_dropout_add_func = bias_dropout_add_fused_train if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else: else:
bias_dropout_add_func = bias_dropout_add_fused_inference bias_dropout_add_func = get_bias_dropout_add(self.training)
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization. # re-enable torch grad to enable fused optimization.
with torch.enable_grad(): with torch.enable_grad():
layernorm_input = bias_dropout_add_func( layernorm_input = bias_dropout_add_func(
attention_output, attention_output,
attention_bias.expand_as(residual), attention_bias.expand_as(residual),
residual, residual,
self.hidden_dropout) self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention. # Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input) layernorm_output = self.post_attention_layernorm(layernorm_input)
...@@ -531,13 +562,19 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -531,13 +562,19 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
residual = layernorm_input residual = layernorm_input
# re-enable torch grad to enable fused optimization. if self.drop_path is None:
with torch.enable_grad(): # re-enable torch grad to enable fused optimization.
output = bias_dropout_add_func( with torch.enable_grad():
mlp_output, output = bias_dropout_add_func(
mlp_bias.expand_as(residual), mlp_output,
residual, mlp_bias.expand_as(residual),
self.hidden_dropout) residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output return output
...@@ -548,7 +585,8 @@ class ParallelTransformer(MegatronModule): ...@@ -548,7 +585,8 @@ class ParallelTransformer(MegatronModule):
def __init__(self, init_method, output_layer_init_method, def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding, self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True): pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
args = get_args() args = get_args()
...@@ -557,6 +595,7 @@ class ParallelTransformer(MegatronModule): ...@@ -557,6 +595,7 @@ class ParallelTransformer(MegatronModule):
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.input_tensor = None self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpoiting flag. # Store activation checkpoiting flag.
self.activations_checkpoint_method = args.activations_checkpoint_method self.activations_checkpoint_method = args.activations_checkpoint_method
...@@ -567,6 +606,8 @@ class ParallelTransformer(MegatronModule): ...@@ -567,6 +606,8 @@ class ParallelTransformer(MegatronModule):
self.num_layers = mpu.get_num_layers( self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder) args, args.model_type == ModelType.encoder_and_decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
# Transformer layers. # Transformer layers.
def build_layer(layer_number): def build_layer(layer_number):
return ParallelTransformerLayer( return ParallelTransformerLayer(
...@@ -574,7 +615,8 @@ class ParallelTransformer(MegatronModule): ...@@ -574,7 +615,8 @@ class ParallelTransformer(MegatronModule):
output_layer_init_method, output_layer_init_method,
layer_number, layer_number,
layer_type=layer_type, layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type) self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
if args.virtual_pipeline_model_parallel_size is not None: if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \ assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \ 'num_layers_per_stage must be divisible by ' \
......
...@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler ...@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules): def get_param_groups(modules,
"""Divide params into with-weight-decay and without-weight-decay groups. no_weight_decay_cond,
Layernorms and baises will have no weight decay but the rest will. scale_lr_cond,
lr_mult):
"""creates param groups based on weight decay condition (regularized vs non regularized)
and learning rate scale condition (args.lr vs lr_mult * args.lr)
scale_lr_cond is used during finetuning where head of the network requires a scaled
version of the base learning rate.
""" """
wd_no_scale_lr = []
weight_decay_params = {'params': []} wd_scale_lr = []
no_weight_decay_params = {'params': [], 'weight_decay': 0.0} no_wd_no_scale_lr = []
no_wd_scale_lr = []
for module in modules: for module in modules:
for module_ in module.modules(): for name, param in module.named_parameters():
if isinstance(module_, LayerNorm): if not param.requires_grad:
no_weight_decay_params['params'].extend( continue
[p for p in list(module_._parameters.values())
if p is not None and p.requires_grad]) if no_weight_decay_cond is not None:
no_wd = no_weight_decay_cond(name, param)
else: else:
weight_decay_params['params'].extend( # do not regularize biases nor Norm parameters
[p for n, p in list(module_._parameters.items()) no_wd = name.endswith(".bias") or len(param.shape) == 1
if p is not None and p.requires_grad and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and p.requires_grad and n == 'bias'])
return weight_decay_params, no_weight_decay_params if scale_lr_cond is not None:
scale_lr = scale_lr_cond(name, param)
else:
scale_lr = False
if not no_wd and not scale_lr:
wd_no_scale_lr.append(param)
elif not no_wd and scale_lr:
wd_scale_lr.append(param)
elif no_wd and not scale_lr:
no_wd_no_scale_lr.append(param)
else:
no_wd_scale_lr.append(param)
def get_megatron_optimizer(model): param_groups = []
if len(wd_no_scale_lr):
param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
if len(wd_scale_lr):
param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
if len(no_wd_no_scale_lr):
param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
if len(no_wd_scale_lr):
param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
return param_groups
def get_megatron_optimizer(model,
no_weight_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
args = get_args() args = get_args()
# Base optimizer. # Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model) param_groups = get_param_groups(model,
no_weight_decay_cond,
scale_lr_cond,
lr_mult)
if args.optimizer == 'adam': if args.optimizer == 'adam':
optimizer = Adam(param_groups, optimizer = Adam(param_groups,
lr=args.lr, lr=args.lr,
......
...@@ -13,19 +13,20 @@ ...@@ -13,19 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Learning rate decay functions.""" """Learning rate decay and weight decay incr functions."""
import math import math
from megatron import print_rank_0 from megatron import print_rank_0
class AnnealingLR(object): class OptimizerParamScheduler(object):
"""Anneals the learning rate.""" """Anneals learning rate and weight decay"""
def __init__(self, optimizer, max_lr, min_lr, def __init__(self, optimizer, max_lr, min_lr,
warmup_steps, decay_steps, decay_style, lr_warmup_steps, lr_decay_steps, lr_decay_style,
use_checkpoint_lr_scheduler=True, start_wd, end_wd, wd_incr_steps, wd_incr_style,
override_lr_scheduler=False): use_checkpoint_opt_param_scheduler=True,
override_opt_param_scheduler=False):
# Class values. # Class values.
self.optimizer = optimizer self.optimizer = optimizer
...@@ -35,24 +36,55 @@ class AnnealingLR(object): ...@@ -35,24 +36,55 @@ class AnnealingLR(object):
assert self.min_lr >= 0.0 assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr assert self.max_lr >= self.min_lr
self.warmup_steps = warmup_steps self.lr_warmup_steps = lr_warmup_steps
self.num_steps = 0 self.num_steps = 0
self.decay_steps = decay_steps self.lr_decay_steps = lr_decay_steps
assert self.decay_steps > 0 assert self.lr_decay_steps > 0
assert self.warmup_steps < self.decay_steps assert self.lr_warmup_steps < self.lr_decay_steps
self.decay_style = decay_style self.lr_decay_style = lr_decay_style
self.override_lr_scheduler = override_lr_scheduler self.start_wd = start_wd
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler self.end_wd = end_wd
if self.override_lr_scheduler: assert self.start_wd >= 0.0
assert not self.use_checkpoint_lr_scheduler, 'both override and '\ assert self.end_wd >= self.start_wd
self.wd_incr_steps = wd_incr_steps
self.wd_incr_style = wd_incr_style
self.override_opt_param_scheduler = override_opt_param_scheduler
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
if self.override_opt_param_scheduler:
assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
'use-checkpoint are set.' 'use-checkpoint are set.'
# Set the learning rate # Set the learning rate
self.step(0) self.step(0)
print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
def get_wd(self):
""" Weight decay incr functions"""
if self.num_steps > self.wd_incr_steps:
return self.end_wd
if self.wd_incr_style == 'constant':
assert self.start_wd == self.end_wd
return self.end_wd
print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
assert incr_ratio >= 0.0
assert incr_ratio <= 1.0
delta_wd = self.end_wd - self.start_wd
if self.wd_incr_style == 'linear':
coeff = incr_ratio
elif self.wd_incr_style == 'cosine':
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
else:
raise Exception('{} weight decay increment style is not supported.'.format(
self.wd_incr_style))
return self.start_wd + coeff * delta_wd
def get_lr(self): def get_lr(self):
...@@ -60,33 +92,33 @@ class AnnealingLR(object): ...@@ -60,33 +92,33 @@ class AnnealingLR(object):
https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
# Use linear warmup for the initial part. # Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
return self.max_lr * float(self.num_steps) / \ return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps) float(self.lr_warmup_steps)
# If the learning rate is constant, just return the initial value. # If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant': if self.lr_decay_style == 'constant':
return self.max_lr return self.max_lr
# For any steps larger than `self.decay_steps`, use `self.min_lr`. # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
if self.num_steps > self.decay_steps: if self.num_steps > self.lr_decay_steps:
return self.min_lr return self.min_lr
# If we are done with the warmup period, use the decay style. # If we are done with the warmup period, use the decay style.
num_steps_ = self.num_steps - self.warmup_steps num_steps_ = self.num_steps - self.lr_warmup_steps
decay_steps_ = self.decay_steps - self.warmup_steps decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_) decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0 assert decay_ratio >= 0.0
assert decay_ratio <= 1.0 assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr delta_lr = self.max_lr - self.min_lr
if self.decay_style == 'linear': if self.lr_decay_style == 'linear':
coeff = (1.0 - decay_ratio) coeff = (1.0 - decay_ratio)
elif self.decay_style == 'cosine': elif self.lr_decay_style == 'cosine':
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else: else:
raise Exception('{} decay style is not supported.'.format( raise Exception('{} decay style is not supported.'.format(
self.decay_style)) self.lr_decay_style))
return self.min_lr + coeff * delta_lr return self.min_lr + coeff * delta_lr
...@@ -95,18 +127,24 @@ class AnnealingLR(object): ...@@ -95,18 +127,24 @@ class AnnealingLR(object):
"""Set lr for all parameters groups.""" """Set lr for all parameters groups."""
self.num_steps += increment self.num_steps += increment
new_lr = self.get_lr() new_lr = self.get_lr()
new_wd = self.get_wd()
for group in self.optimizer.param_groups: for group in self.optimizer.param_groups:
group['lr'] = new_lr group['lr'] = new_lr * group.get('lr_mult', 1.0)
group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
def state_dict(self): def state_dict(self):
state_dict = { state_dict = {
'max_lr': self.max_lr, 'max_lr': self.max_lr,
'warmup_steps': self.warmup_steps, 'lr_warmup_steps': self.lr_warmup_steps,
'num_steps': self.num_steps, 'num_steps': self.num_steps,
'decay_style': self.decay_style, 'lr_decay_style': self.lr_decay_style,
'decay_steps': self.decay_steps, 'lr_decay_steps': self.lr_decay_steps,
'min_lr': self.min_lr 'min_lr': self.min_lr,
'start_wd': self.start_wd,
'end_wd': self.end_wd,
'wd_incr_style': self.wd_incr_style,
'wd_incr_steps': self.wd_incr_steps
} }
return state_dict return state_dict
...@@ -114,13 +152,13 @@ class AnnealingLR(object): ...@@ -114,13 +152,13 @@ class AnnealingLR(object):
def _check_and_set(self, cls_value, sd_value, name): def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and """Auxiliary function for checking the values in the checkpoint and
setting them.""" setting them."""
if self.override_lr_scheduler: if self.override_opt_param_scheduler:
print_rank_0(' > overriding {} value to {}'.format(name, cls_value)) print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
return cls_value return cls_value
if not self.use_checkpoint_lr_scheduler: if not self.use_checkpoint_opt_param_scheduler:
assert cls_value == sd_value, \ assert cls_value == sd_value, \
f'AnnealingLR: class input value {cls_value} and checkpoint' \ f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
f'value {sd_value} for {name} do not match' f'value {sd_value} for {name} do not match'
print_rank_0(' > using checkpoint value {} for {}'.format(sd_value, print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
name)) name))
...@@ -140,25 +178,57 @@ class AnnealingLR(object): ...@@ -140,25 +178,57 @@ class AnnealingLR(object):
'minimum learning rate') 'minimum learning rate')
if 'warmup_iter' in sd: if 'warmup_iter' in sd:
warmup_steps_ = sd['warmup_iter'] lr_warmup_steps_ = sd['warmup_iter']
elif 'warmup_steps' in sd:
lr_warmup_steps_ = sd['warmup_steps']
else: else:
warmup_steps_ = sd['warmup_steps'] lr_warmup_steps_ = sd['lr_warmup_steps']
self.warmup_steps = self._check_and_set(self.warmup_steps, self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
warmup_steps_, lr_warmup_steps_,
'warmup iterations') 'warmup iterations')
if 'end_iter' in sd: if 'end_iter' in sd:
decay_steps_ = sd['end_iter'] lr_decay_steps_ = sd['end_iter']
elif 'decay_steps' in sd:
lr_decay_steps_ = sd['decay_steps']
else: else:
decay_steps_ = sd['decay_steps'] lr_decay_steps_ = sd['lr_decay_steps']
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
'total number of iterations') 'total number of iterations')
self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'], if 'decay_style' in sd:
'decay style') lr_decay_style_ = sd['decay_style']
else:
lr_decay_style_ = sd['lr_decay_style']
self.lr_decay_style = self._check_and_set(self.lr_decay_style,
lr_decay_style_,
'learning rate decay style')
if 'num_iters' in sd: if 'num_iters' in sd:
num_steps = sd['num_iters'] num_steps = sd['num_iters']
else: else:
num_steps = sd['num_steps'] num_steps = sd['num_steps']
self.step(increment=num_steps) self.step(increment=num_steps)
if 'start_wd' in sd:
self.start_wd = self._check_and_set(self.start_wd,
sd['start_wd'],
"start weight decay")
self.end_wd = self._check_and_set(self.end_wd,
sd['end_wd'],
"end weight decay")
self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
sd['wd_incr_steps'],
"total number of weight decay iterations")
self.wd_incr_style = self._check_and_set(self.wd_incr_style,
sd['wd_incr_style'],
"weight decay incr style")
...@@ -98,7 +98,12 @@ def custom_backward(output, grad_output): ...@@ -98,7 +98,12 @@ def custom_backward(output, grad_output):
) )
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced): def forward_step(forward_step_func,
data_iterator,
model,
input_tensor,
forward_data_store,
collect_non_loss_data=False):
"""Forward step for passed-in model. """Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise If first stage, input tensor is obtained from data_iterator, otherwise
...@@ -120,10 +125,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r ...@@ -120,10 +125,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
unwrapped_model.set_input_tensor(input_tensor) unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model) output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor) if not collect_non_loss_data:
loss, loss_reduced = output_tensor output_tensor = loss_func(output_tensor)
output_tensor = loss / get_num_microbatches() loss, loss_reduced = output_tensor
losses_reduced.append(loss_reduced) output_tensor = loss / get_num_microbatches()
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
timers('forward-compute').stop() timers('forward-compute').stop()
# If T5 model (or other model with encoder and decoder) # If T5 model (or other model with encoder and decoder)
...@@ -206,8 +216,12 @@ def dummy_handler(): ...@@ -206,8 +216,12 @@ def dummy_handler():
pass pass
def forward_backward_no_pipelining(forward_step_func, data_iterator, model, def forward_backward_no_pipelining(forward_step_func,
optimizer, timers, forward_only): data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run forward and backward passes with no pipeline parallelism """Run forward and backward passes with no pipeline parallelism
(no inter-stage communication). (no inter-stage communication).
...@@ -219,35 +233,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model, ...@@ -219,35 +233,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
if isinstance(model, torchDDP): if isinstance(model, torchDDP):
context_handler = model.no_sync context_handler = model.no_sync
losses_reduced = [] forward_data_store = []
input_tensor, output_tensor_grad = None, None input_tensor, output_tensor_grad = None, None
with context_handler(): with context_handler():
for i in range(get_num_microbatches() - 1): for i in range(get_num_microbatches() - 1):
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator,
input_tensor, losses_reduced) model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only: if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad) output_tensor_grad)
# Run computation for last microbatch out of context handler (want to # Run computation for last microbatch out of context handler (want to
# synchronize gradients). # synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator,
input_tensor, losses_reduced) model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only: if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad) backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
return losses_reduced return forward_data_store
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model, def forward_backward_pipelining_with_interleaving(forward_step_func,
optimizer, timers, forward_only): data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run interleaved 1F1B schedule (model split into model chunks), with """Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed. communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise.""" Returns dictionary with losses if the last stage, empty dict otherwise."""
input_tensors = [[] for _ in range(len(model))] input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))] output_tensors = [[] for _ in range(len(model))]
losses_reduced = [] forward_data_store = []
if not forward_only: if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))] output_tensor_grads = [[] for _ in range(len(model))]
...@@ -307,7 +327,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -307,7 +327,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor = forward_step(forward_step_func, output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id], data_iterator[model_chunk_id],
model[model_chunk_id], model[model_chunk_id],
input_tensor, losses_reduced) input_tensor,
forward_data_store,
collect_non_loss_data)
output_tensors[model_chunk_id].append(output_tensor) output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass # if forward-only, no need to save tensors for a backward pass
...@@ -474,7 +496,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat ...@@ -474,7 +496,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
tensor_shape=tensor_shape, tensor_shape=tensor_shape,
timers=timers)) timers=timers))
return losses_reduced return forward_data_store
def get_tensor_shapes(rank, model_type): def get_tensor_shapes(rank, model_type):
...@@ -571,9 +593,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers): ...@@ -571,9 +593,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
return input_tensors return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator, def forward_backward_pipelining_without_interleaving(forward_step_func,
model, optimizer, timers, data_iterator,
forward_only): model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run non-interleaved 1F1B schedule, with communication between pipeline """Run non-interleaved 1F1B schedule, with communication between pipeline
stages. stages.
...@@ -608,13 +634,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -608,13 +634,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if not forward_only: if not forward_only:
input_tensors = [] input_tensors = []
output_tensors = [] output_tensors = []
losses_reduced = [] forward_data_store = []
# Run warmup forward passes. # Run warmup forward passes.
for i in range(num_warmup_microbatches): for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers) input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced) input_tensor, forward_data_store,
collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only: if not forward_only:
...@@ -633,7 +660,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -633,7 +660,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
last_iteration = (i == (num_microbatches_remaining - 1)) last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model, output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced) input_tensor, forward_data_store,
collect_non_loss_data)
if forward_only: if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers) send_forward(output_tensor, send_tensor_shapes, timers=timers)
...@@ -682,4 +710,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite ...@@ -682,4 +710,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers) send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
return losses_reduced return forward_data_store
...@@ -43,7 +43,7 @@ from megatron.model import ModelType ...@@ -43,7 +43,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model from megatron.utils import unwrap_model
...@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
model_provider, model_provider,
model_type, model_type,
forward_step_func, forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None, extra_args_provider=None,
args_defaults={}): args_defaults={}):
"""Main training program. """Main training program.
...@@ -86,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -86,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider,
the info we would like to monitor during training, for example the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add `lm-loss: value`. We also require that this function add
`batch generator` to the timers class. `batch generator` to the timers class.
process_non_loss_data_func: a function to post process outputs of the
network. It can be used for dumping output tensors (e.g images) to
tensorboard. It takes `collected data`(list of tensors),
`current iteration index` and `tensorboard writer` as arguments.
extra_args_provider: a function that takes a parser and adds arguments extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments. to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It args_defaults: a dictionary from argument-name to argument-value. It
...@@ -113,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -113,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider,
# Model, optimizer, and learning rate. # Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start() timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
model_type) model_type)
timers('model-and-optimizer-setup').stop() timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate ' print_datetime('after model, optimizer, and learning rate '
...@@ -144,25 +149,28 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -144,25 +149,28 @@ def pretrain(train_valid_test_dataset_provider,
iteration = 0 iteration = 0
if args.do_train and args.train_iters > 0: if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func, iteration = train(forward_step_func,
model, optimizer, lr_scheduler, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator) train_data_iterator, valid_data_iterator,
process_non_loss_data_func)
print_datetime('after training is done') print_datetime('after training is done')
if args.do_valid: if args.do_valid:
prefix = 'the end of training for val data' prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model, valid_data_iterator, model,
iteration, False) iteration, process_non_loss_data_func,
False)
if args.save and iteration != 0: if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
if args.do_test: if args.do_test:
# Run on test data. # Run on test data.
prefix = 'the end of training for test data' prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model, test_data_iterator, model,
0, True) 0, process_non_loss_data_func,
True)
def update_train_iters(args): def update_train_iters(args):
...@@ -296,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ...@@ -296,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
return model return model
def get_learning_rate_scheduler(optimizer): def get_optimizer_param_scheduler(optimizer):
"""Build the learning rate scheduler.""" """Build the learning rate scheduler."""
args = get_args() args = get_args()
...@@ -304,11 +312,12 @@ def get_learning_rate_scheduler(optimizer): ...@@ -304,11 +312,12 @@ def get_learning_rate_scheduler(optimizer):
if args.train_iters: if args.train_iters:
if args.lr_decay_iters is None: if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters args.lr_decay_iters = args.train_iters
decay_steps = args.lr_decay_iters * args.global_batch_size lr_decay_steps = args.lr_decay_iters * args.global_batch_size
wd_incr_steps = args.train_iters * args.global_batch_size
if args.lr_warmup_fraction is not None: if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else: else:
warmup_steps = args.lr_warmup_iters * args.global_batch_size lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training. # Sample-based training.
elif args.train_samples: elif args.train_samples:
# We need to set training iters for later use. Technically # We need to set training iters for later use. Technically
...@@ -317,29 +326,38 @@ def get_learning_rate_scheduler(optimizer): ...@@ -317,29 +326,38 @@ def get_learning_rate_scheduler(optimizer):
update_train_iters(args) update_train_iters(args)
if args.lr_decay_samples is None: if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples args.lr_decay_samples = args.train_samples
decay_steps = args.lr_decay_samples lr_decay_steps = args.lr_decay_samples
wd_incr_steps = args.train_samples
if args.lr_warmup_fraction is not None: if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else: else:
warmup_steps = args.lr_warmup_samples lr_warmup_steps = args.lr_warmup_samples
else: else:
raise Exception( raise Exception(
'either train-iters or train-samples should be provided.') 'either train-iters or train-samples should be provided.')
lr_scheduler = AnnealingLR( opt_param_scheduler = OptimizerParamScheduler(
optimizer, optimizer,
max_lr=args.lr, max_lr=args.lr,
min_lr=args.min_lr, min_lr=args.min_lr,
warmup_steps=warmup_steps, lr_warmup_steps=lr_warmup_steps,
decay_steps=decay_steps, lr_decay_steps=lr_decay_steps,
decay_style=args.lr_decay_style, lr_decay_style=args.lr_decay_style,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler, start_wd=args.start_weight_decay,
override_lr_scheduler=args.override_lr_scheduler) end_wd=args.end_weight_decay,
wd_incr_steps=wd_incr_steps,
return lr_scheduler wd_incr_style=args.weight_decay_incr_style,
use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
override_opt_param_scheduler=args.override_opt_param_scheduler)
def setup_model_and_optimizer(model_provider_func, model_type):
return opt_param_scheduler
def setup_model_and_optimizer(model_provider_func,
model_type,
no_wd_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
"""Setup model and optimizer.""" """Setup model and optimizer."""
args = get_args() args = get_args()
...@@ -347,9 +365,10 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -347,9 +365,10 @@ def setup_model_and_optimizer(model_provider_func, model_type):
unwrapped_model = unwrap_model(model, unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module)) (torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model) optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
lr_scheduler = get_learning_rate_scheduler(optimizer) opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None: if args.load is not None:
timers = get_timers() timers = get_timers()
...@@ -357,7 +376,7 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -357,7 +376,7 @@ def setup_model_and_optimizer(model_provider_func, model_type):
# max time. # max time.
torch.distributed.barrier() torch.distributed.barrier()
timers('load-checkpoint').start() timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler) args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
timers('load-checkpoint').stop() timers('load-checkpoint').stop()
timers.log(['load-checkpoint']) timers.log(['load-checkpoint'])
...@@ -376,11 +395,11 @@ def setup_model_and_optimizer(model_provider_func, model_type): ...@@ -376,11 +395,11 @@ def setup_model_and_optimizer(model_provider_func, model_type):
if args.fp16: if args.fp16:
optimizer.reload_model_params() optimizer.reload_model_params()
return model, optimizer, lr_scheduler return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator, def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler): model, optimizer, opt_param_scheduler):
"""Single training step.""" """Single training step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -456,7 +475,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -456,7 +475,7 @@ def train_step(forward_step_func, data_iterator,
increment = get_num_microbatches() * \ increment = get_num_microbatches() * \
args.micro_batch_size * \ args.micro_batch_size * \
args.data_parallel_size args.data_parallel_size
lr_scheduler.step(increment=increment) opt_param_scheduler.step(increment=increment)
skipped_iter = 0 skipped_iter = 0
else: else:
skipped_iter = 1 skipped_iter = 1
...@@ -646,20 +665,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, ...@@ -646,20 +665,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
return report_memory_flag return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler): def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
timers = get_timers() timers = get_timers()
# Extra barrier is added to make sure # Extra barrier is added to make sure
# all ranks report the max time. # all ranks report the max time.
torch.distributed.barrier() torch.distributed.barrier()
timers('save-checkpoint').start() timers('save-checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
timers('save-checkpoint').stop() timers('save-checkpoint').stop()
timers.log(['save-checkpoint']) timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler, def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator): train_data_iterator, valid_data_iterator,
process_non_loss_data_func):
"""Train the model function.""" """Train the model function."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -687,7 +707,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -687,7 +707,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, train_data_iterator,
model, model,
optimizer, optimizer,
lr_scheduler) opt_param_scheduler)
iteration += 1 iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \ args.micro_batch_size * \
...@@ -708,7 +728,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -708,7 +728,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if args.adlr_autoresume and \ if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0): (iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer, check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
# Evaluation # Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \ if args.eval_interval and iteration % args.eval_interval == 0 and \
...@@ -716,7 +736,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -716,7 +736,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
prefix = 'iteration {}'.format(iteration) prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model, valid_data_iterator, model,
iteration, False) iteration, process_non_loss_data_func,
False)
# Checkpointing # Checkpointing
saved_checkpoint = False saved_checkpoint = False
...@@ -724,14 +745,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -724,14 +745,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
signal_handler = get_signal_handler() signal_handler = get_signal_handler()
if any(signal_handler.signals_received()): if any(signal_handler.signals_received()):
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
print_datetime('exiting program after receiving SIGTERM.') print_datetime('exiting program after receiving SIGTERM.')
sys.exit() sys.exit()
if args.save and args.save_interval and \ if args.save and args.save_interval and \
iteration % args.save_interval == 0: iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
saved_checkpoint = True saved_checkpoint = True
# Exiting based on duration # Exiting based on duration
...@@ -745,7 +766,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -745,7 +766,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if done: if done:
if not saved_checkpoint: if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time)) print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit() sys.exit()
...@@ -753,7 +774,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -753,7 +774,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint: if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler) opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration)) print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit() sys.exit()
...@@ -762,7 +783,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -762,7 +783,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
return iteration return iteration
def evaluate(forward_step_func, data_iterator, model, verbose=False): def evaluate(forward_step_func,
data_iterator,
model,
process_non_loss_data_func,
verbose=False):
"""Evaluation.""" """Evaluation."""
args = get_args() args = get_args()
...@@ -799,6 +824,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -799,6 +824,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \ args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \ * args.micro_batch_size \
* get_num_microbatches() * get_num_microbatches()
collected_non_loss_data = None
if process_non_loss_data_func is not None and is_last_rank():
collected_non_loss_data = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True, collect_non_loss_data=True)
# Move model back to the train mode. # Move model back to the train mode.
for model_module in model: for model_module in model:
model_module.train() model_module.train()
...@@ -806,16 +837,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -806,16 +837,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
for key in total_loss_dict: for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches() total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
return total_loss_dict return total_loss_dict, collected_non_loss_data
def evaluate_and_print_results(prefix, forward_step_func, def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model, data_iterator, model,
iteration, verbose=False): iteration, process_non_loss_data_func,
verbose=False):
"""Helper function to evaluate and dump results on screen.""" """Helper function to evaluate and dump results on screen."""
args = get_args() args = get_args()
writer = get_tensorboard_writer() writer = get_tensorboard_writer()
total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose) total_loss_dict, collected_non_loss_data = evaluate(
forward_step_func, data_iterator, model,
process_non_loss_data_func, verbose)
string = ' validation loss at {} | '.format(prefix) string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict: for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
...@@ -834,6 +868,9 @@ def evaluate_and_print_results(prefix, forward_step_func, ...@@ -834,6 +868,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
writer.add_scalar('{} validation ppl vs samples'.format(key), writer.add_scalar('{} validation ppl vs samples'.format(key),
ppl, args.consumed_train_samples) ppl, args.consumed_train_samples)
if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, iteration, writer)
length = len(string) + 1 length = len(string) + 1
print_rank_last('-' * length) print_rank_last('-' * length)
print_rank_last(string) print_rank_last(string)
......
...@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration): ...@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model, def check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler): optimizer, opt_param_scheduler):
"""Check for autoresume signal and exit if it is received.""" """Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
...@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model, ...@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
torch.distributed.barrier() torch.distributed.barrier()
if autoresume.termination_requested(): if autoresume.termination_requested():
if args.save: if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
print_rank_0(">>> autoresume termination request found!") print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
autoresume.request_resume() autoresume.request_resume()
......
...@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset, ...@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
return train_dataloader, valid_dataloader return train_dataloader, valid_dataloader
def _train(model, optimizer, lr_scheduler, forward_step, def _train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback): train_dataloader, valid_dataloader, end_of_epoch_callback):
"""Train the model.""" """Train the model."""
args = get_args() args = get_args()
...@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
start_iteration = 0 start_iteration = 0
# Train for one step. # Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler) out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1 iteration += 1
...@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
if args.adlr_autoresume and \ if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0): (iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler) optimizer, opt_param_scheduler)
# Checkpointing # Checkpointing
saved_checkpoint = False saved_checkpoint = False
if args.save and args.save_interval and \ if args.save and args.save_interval and \
iteration % args.save_interval == 0: iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
saved_checkpoint = True saved_checkpoint = True
# Evaluation # Evaluation
...@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# Exiting based on iterations # Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0: if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint: if not saved_checkpoint:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier() torch.distributed.barrier()
print_rank_0('exiting program at iteration {}'.format(iteration)) print_rank_0('exiting program at iteration {}'.format(iteration))
sys.exit() sys.exit()
# Checkpointing at the end of each epoch. # Checkpointing at the end of each epoch.
if args.save: if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch. # Callback at the end of each epoch.
if end_of_epoch_callback is not None: if end_of_epoch_callback is not None:
...@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Build model, optimizer and learning rate scheduler. # Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start() timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type) model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
timers('model and optimizer').stop() timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for # If pretrained checkpoint is provided and we have not trained for
...@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Finetune the model. # Finetune the model.
if args.epochs > 0: if args.epochs > 0:
_train(model, optimizer, lr_scheduler, forward_step, _train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback) train_dataloader, valid_dataloader, end_of_epoch_callback)
# Or just evaluate. # Or just evaluate.
else: else:
......
...@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset): ...@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
def _train( def _train(
model, model,
optimizer, optimizer,
lr_scheduler, opt_param_scheduler,
forward_step, forward_step,
train_dataloader, train_dataloader,
valid_dataloader, valid_dataloader,
...@@ -179,7 +179,7 @@ def _train( ...@@ -179,7 +179,7 @@ def _train(
# Train for one step. # Train for one step.
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step( losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
forward_step, batch, model, optimizer, lr_scheduler forward_step, batch, model, optimizer, opt_param_scheduler
) )
iteration += 1 iteration += 1
...@@ -206,7 +206,7 @@ def _train( ...@@ -206,7 +206,7 @@ def _train(
iteration % args.adlr_autoresume_interval == 0 iteration % args.adlr_autoresume_interval == 0
): ):
check_adlr_autoresume_termination( check_adlr_autoresume_termination(
iteration, model, optimizer, lr_scheduler iteration, model, optimizer, opt_param_scheduler
) )
# Checkpointing # Checkpointing
...@@ -215,7 +215,7 @@ def _train( ...@@ -215,7 +215,7 @@ def _train(
and args.save_interval and args.save_interval
and iteration % args.save_interval == 0 and iteration % args.save_interval == 0
): ):
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Evaluation # Evaluation
if args.eval_interval and iteration % args.eval_interval == 0: if args.eval_interval and iteration % args.eval_interval == 0:
...@@ -231,7 +231,7 @@ def _train( ...@@ -231,7 +231,7 @@ def _train(
# Checkpointing at the end of each epoch. # Checkpointing at the end of each epoch.
if args.save: if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch. # Callback at the end of each epoch.
if end_of_epoch_callback is not None: if end_of_epoch_callback is not None:
...@@ -266,7 +266,7 @@ def finetune( ...@@ -266,7 +266,7 @@ def finetune(
# Build model, optimizer and learning rate scheduler. # Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start() timers("model and optimizer").start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
timers("model and optimizer").stop() timers("model and optimizer").stop()
# If pretrained checkpoint is provided and we have not trained for # If pretrained checkpoint is provided and we have not trained for
...@@ -300,7 +300,7 @@ def finetune( ...@@ -300,7 +300,7 @@ def finetune(
_train( _train(
model, model,
optimizer, optimizer,
lr_scheduler, opt_param_scheduler,
forward_step, forward_step,
train_dataloader, train_dataloader,
valid_dataloader, valid_dataloader,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment