Commit e724785f authored by Jared Casper

Merge branch 'vision-merge' into 'main'

second phase of vision code merge

See merge request ADLR/megatron-lm!381
parents e156d2fe 2b628f96
......@@ -246,6 +246,15 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
if args.weight_decay_incr_style == 'constant':
assert args.start_weight_decay is None
assert args.end_weight_decay is None
args.start_weight_decay = args.weight_decay
args.end_weight_decay = args.weight_decay
else:
assert args.start_weight_decay is not None
assert args.end_weight_decay is not None
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# Persistent fused layer norm.
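For reference, here is a minimal standalone sketch (plain argparse, not Megatron's actual parser) of how the new weight decay arguments are expected to resolve per the asserts above; the flag values are illustrative only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--weight-decay', type=float, default=0.01)
parser.add_argument('--start-weight-decay', type=float)
parser.add_argument('--end-weight-decay', type=float)
parser.add_argument('--weight-decay-incr-style', type=str, default='constant',
                    choices=['constant', 'linear', 'cosine'])

# Non-constant styles require explicit start/end values.
args = parser.parse_args(['--weight-decay-incr-style', 'cosine',
                          '--start-weight-decay', '0.01',
                          '--end-weight-decay', '0.1'])
if args.weight_decay_incr_style == 'constant':
    # Default style: start/end must be left unset; both fall back to --weight-decay.
    assert args.start_weight_decay is None and args.end_weight_decay is None
    args.start_weight_decay = args.end_weight_decay = args.weight_decay
else:
    assert args.start_weight_decay is not None and args.end_weight_decay is not None
print(args.start_weight_decay, args.end_weight_decay)  # 0.01 0.1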
......@@ -395,6 +404,13 @@ def _add_regularization_args(parser):
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-weight-decay', type=float,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-weight-decay', type=float,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--weight-decay-incr-style', type=str, default='constant',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9,
......@@ -561,13 +577,13 @@ def _add_learning_rate_args(parser):
group.add_argument('--min-lr', type=float, default=0.0,
help='Minimum value for learning rate. The scheduler '
'clips values below this threshold.')
group.add_argument('--override-lr-scheduler', action='store_true',
group.add_argument('--override-opt_param-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate, '
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style) from input '
'arguments and ignore values from checkpoints. Note '
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style '
......
......@@ -167,7 +167,7 @@ def get_rng_state():
return rng_state_list
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
"""Save a model checkpoint."""
args = get_args()
......@@ -198,8 +198,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
state_dict['lr_scheduler'] = lr_scheduler.state_dict()
if opt_param_scheduler is not None:
state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
# RNG states.
if not args.no_save_rng:
......@@ -295,7 +295,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
......@@ -394,8 +394,11 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
if opt_param_scheduler is not None:
if 'lr_scheduler' in state_dict: # backward compatibility
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
......
......@@ -42,6 +42,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=0.):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, hidden_state):
if self.drop_prob == 0. or not self.training:
return hidden_state
keep_prob = 1 - self.drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
random_tensor.floor_() # binarize
output = hidden_state.div(keep_prob) * random_tensor
return output
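A small usage sketch of the DropPath module above; the tensor shape is illustrative, and the binary mask is drawn along the leading dimension of whatever tensor it receives:

import torch

drop_path = DropPath(drop_prob=0.1)
drop_path.train()                    # stochastic only in training mode
x = torch.randn(8, 4, 16)            # leading dim indexes the slices that may be dropped
y = drop_path(x)
# Roughly 10% of the leading-dim slices are zeroed; the survivors are scaled by
# 1 / 0.9 so the expected value of the output matches the input.

drop_path.eval()
assert torch.equal(drop_path(x), x)  # identity at inference time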
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -406,7 +429,8 @@ class ParallelTransformerLayer(MegatronModule):
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
......@@ -434,6 +458,7 @@ class ParallelTransformerLayer(MegatronModule):
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
......@@ -477,25 +502,31 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = hidden_states
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
......@@ -531,13 +562,19 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
if self.drop_path is None:
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
......@@ -548,7 +585,8 @@ class ParallelTransformer(MegatronModule):
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True):
pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelTransformer, self).__init__()
args = get_args()
......@@ -557,6 +595,7 @@ class ParallelTransformer(MegatronModule):
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flag.
self.activations_checkpoint_method = args.activations_checkpoint_method
......@@ -567,6 +606,8 @@ class ParallelTransformer(MegatronModule):
self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
......@@ -574,7 +615,8 @@ class ParallelTransformer(MegatronModule):
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
......
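Note that ParallelTransformer spreads the drop-path probability linearly over the full stack via torch.linspace(0, drop_path_rate, args.num_layers), indexed by the global layer_number - 1. An illustrative sketch with hypothetical values (4 layers, drop_path_rate=0.1):

import torch

num_layers, drop_path_rate = 4, 0.1
rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, num_layers)]
# rates ~ [0.0, 0.0333, 0.0667, 0.1]; layer i uses rates[i - 1], so the first
# layer never drops its residual branches and the last drops them with
# probability drop_path_rate.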
......@@ -23,35 +23,68 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
def get_param_groups(modules,
no_weight_decay_cond,
scale_lr_cond,
lr_mult):
"""Create param groups based on the weight decay condition (regularized vs. non-regularized)
and the learning rate scale condition (args.lr vs. lr_mult * args.lr).
scale_lr_cond is used during finetuning, where the head of the network requires a scaled
version of the base learning rate.
"""
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
wd_no_scale_lr = []
wd_scale_lr = []
no_wd_no_scale_lr = []
no_wd_scale_lr = []
for module in modules:
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None and p.requires_grad])
for name, param in module.named_parameters():
if not param.requires_grad:
continue
if no_weight_decay_cond is not None:
no_wd = no_weight_decay_cond(name, param)
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and p.requires_grad and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and p.requires_grad and n == 'bias'])
# do not regularize biases or norm parameters
no_wd = name.endswith(".bias") or len(param.shape) == 1
return weight_decay_params, no_weight_decay_params
if scale_lr_cond is not None:
scale_lr = scale_lr_cond(name, param)
else:
scale_lr = False
if not no_wd and not scale_lr:
wd_no_scale_lr.append(param)
elif not no_wd and scale_lr:
wd_scale_lr.append(param)
elif no_wd and not scale_lr:
no_wd_no_scale_lr.append(param)
else:
no_wd_scale_lr.append(param)
def get_megatron_optimizer(model):
param_groups = []
if len(wd_no_scale_lr):
param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
if len(wd_scale_lr):
param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
if len(no_wd_no_scale_lr):
param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
if len(no_wd_scale_lr):
param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
return param_groups
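A hedged sketch of how the two condition callbacks might be supplied by a finetuning script. Only the (name, param) signature and the per-group 'wd_mult'/'lr_mult' keys come from the code above; the predicate names, the 'head.' prefix, and the lr_mult value are hypothetical, and `model` is assumed to be the usual list of model chunks:

def no_wd(name, param):
    # mirror the default rule: no weight decay for biases and 1-D (norm) params
    return name.endswith(".bias") or len(param.shape) == 1

def scale_head_lr(name, param):
    # hypothetical: give a freshly initialized classification head its own lr scale
    return name.startswith("head.")

param_groups = get_param_groups(model, no_wd, scale_head_lr, lr_mult=10.0)
# Up to four groups come back, each tagged with 'wd_mult' and 'lr_mult';
# OptimizerParamScheduler.step() later multiplies the base lr and weight decay
# by these per-group factors. The same callbacks can be passed straight through
# get_megatron_optimizer(model, no_wd, scale_head_lr, lr_mult=10.0).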
def get_megatron_optimizer(model,
no_weight_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
args = get_args()
# Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model)
param_groups = get_param_groups(model,
no_weight_decay_cond,
scale_lr_cond,
lr_mult)
if args.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=args.lr,
......
......@@ -13,19 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Learning rate decay functions."""
"""Learning rate decay and weight decay incr functions."""
import math
from megatron import print_rank_0
class AnnealingLR(object):
"""Anneals the learning rate."""
class OptimizerParamScheduler(object):
"""Anneals learning rate and weight decay"""
def __init__(self, optimizer, max_lr, min_lr,
warmup_steps, decay_steps, decay_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
lr_warmup_steps, lr_decay_steps, lr_decay_style,
start_wd, end_wd, wd_incr_steps, wd_incr_style,
use_checkpoint_opt_param_scheduler=True,
override_opt_param_scheduler=False):
# Class values.
self.optimizer = optimizer
......@@ -35,24 +36,55 @@ class AnnealingLR(object):
assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr
self.warmup_steps = warmup_steps
self.lr_warmup_steps = lr_warmup_steps
self.num_steps = 0
self.decay_steps = decay_steps
assert self.decay_steps > 0
assert self.warmup_steps < self.decay_steps
self.decay_style = decay_style
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
self.lr_decay_steps = lr_decay_steps
assert self.lr_decay_steps > 0
assert self.lr_warmup_steps < self.lr_decay_steps
self.lr_decay_style = lr_decay_style
self.start_wd = start_wd
self.end_wd = end_wd
assert self.start_wd >= 0.0
assert self.end_wd >= self.start_wd
self.wd_incr_steps = wd_incr_steps
self.wd_incr_style = wd_incr_style
self.override_opt_param_scheduler = override_opt_param_scheduler
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
if self.override_opt_param_scheduler:
assert not self.use_checkpoint_opt_param_scheduler, 'both override and '\
'use-checkpoint are set.'
# Set the learning rate
self.step(0)
print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
def get_wd(self):
""" Weight decay incr functions"""
if self.num_steps > self.wd_incr_steps:
return self.end_wd
if self.wd_incr_style == 'constant':
assert self.start_wd == self.end_wd
return self.end_wd
print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
assert incr_ratio >= 0.0
assert incr_ratio <= 1.0
delta_wd = self.end_wd - self.start_wd
if self.wd_incr_style == 'linear':
coeff = incr_ratio
elif self.wd_incr_style == 'cosine':
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
else:
raise Exception('{} weight decay increment style is not supported.'.format(
self.wd_incr_style))
return self.start_wd + coeff * delta_wd
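Worked numbers for the increment above, with hypothetical values start_wd=0.01, end_wd=0.1, wd_incr_steps=1000: the 'linear' style gives wd = 0.01 + 0.09 * step / 1000, while 'cosine' follows a half-cosine ramp:

import math

start_wd, end_wd, wd_incr_steps = 0.01, 0.1, 1000
for step in (0, 250, 500, 750, 1000):
    incr_ratio = step / wd_incr_steps
    coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)   # 'cosine' style
    print(step, round(start_wd + coeff * (end_wd - start_wd), 4))
# 0 -> 0.01, 250 -> 0.0232, 500 -> 0.055, 750 -> 0.0868, 1000 -> 0.1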
def get_lr(self):
......@@ -60,33 +92,33 @@ class AnnealingLR(object):
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
# Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
return self.max_lr * float(self.num_steps) / \
float(self.warmup_steps)
float(self.lr_warmup_steps)
# If the learning rate is constant, just return the initial value.
if self.decay_style == 'constant':
if self.lr_decay_style == 'constant':
return self.max_lr
# For any steps larger than `self.decay_steps`, use `self.min_lr`.
if self.num_steps > self.decay_steps:
# For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
if self.num_steps > self.lr_decay_steps:
return self.min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = self.num_steps - self.warmup_steps
decay_steps_ = self.decay_steps - self.warmup_steps
num_steps_ = self.num_steps - self.lr_warmup_steps
decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = self.max_lr - self.min_lr
if self.decay_style == 'linear':
if self.lr_decay_style == 'linear':
coeff = (1.0 - decay_ratio)
elif self.decay_style == 'cosine':
elif self.lr_decay_style == 'cosine':
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
else:
raise Exception('{} decay style is not supported.'.format(
self.decay_style))
self.lr_decay_style))
return self.min_lr + coeff * delta_lr
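For comparison, the learning rate side (unchanged by this MR apart from the renames) is the mirror image: after linear warmup, lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * decay_ratio)), clamped to min_lr once num_steps exceeds lr_decay_steps. A quick check with hypothetical values:

import math

max_lr, min_lr, lr_warmup_steps, lr_decay_steps = 1.0e-4, 1.0e-5, 100, 1000
for step in (50, 100, 550, 1000, 2000):
    if step <= lr_warmup_steps:
        lr = max_lr * step / lr_warmup_steps              # linear warmup
    elif step > lr_decay_steps:
        lr = min_lr                                       # floor after the decay window
    else:
        ratio = (step - lr_warmup_steps) / (lr_decay_steps - lr_warmup_steps)
        lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * ratio))
    print(step, lr)
# 50 -> 5.0e-05, 100 -> 1.0e-04, 550 -> 5.5e-05, 1000 and beyond -> 1.0e-05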
......@@ -95,18 +127,24 @@ class AnnealingLR(object):
"""Set lr for all parameters groups."""
self.num_steps += increment
new_lr = self.get_lr()
new_wd = self.get_wd()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
group['lr'] = new_lr * group.get('lr_mult', 1.0)
group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
def state_dict(self):
state_dict = {
'max_lr': self.max_lr,
'warmup_steps': self.warmup_steps,
'lr_warmup_steps': self.lr_warmup_steps,
'num_steps': self.num_steps,
'decay_style': self.decay_style,
'decay_steps': self.decay_steps,
'min_lr': self.min_lr
'lr_decay_style': self.lr_decay_style,
'lr_decay_steps': self.lr_decay_steps,
'min_lr': self.min_lr,
'start_wd': self.start_wd,
'end_wd': self.end_wd,
'wd_incr_style': self.wd_incr_style,
'wd_incr_steps': self.wd_incr_steps
}
return state_dict
......@@ -114,13 +152,13 @@ class AnnealingLR(object):
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
if self.override_lr_scheduler:
if self.override_opt_param_scheduler:
print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
return cls_value
if not self.use_checkpoint_lr_scheduler:
if not self.use_checkpoint_opt_param_scheduler:
assert cls_value == sd_value, \
f'AnnealingLR: class input value {cls_value} and checkpoint' \
f'OptimizerParamScheduler: class input value {cls_value} and checkpoint ' \
f'value {sd_value} for {name} do not match'
print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
name))
......@@ -140,25 +178,57 @@ class AnnealingLR(object):
'minimum learning rate')
if 'warmup_iter' in sd:
warmup_steps_ = sd['warmup_iter']
lr_warmup_steps_ = sd['warmup_iter']
elif 'warmup_steps' in sd:
lr_warmup_steps_ = sd['warmup_steps']
else:
warmup_steps_ = sd['warmup_steps']
self.warmup_steps = self._check_and_set(self.warmup_steps,
warmup_steps_,
lr_warmup_steps_ = sd['lr_warmup_steps']
self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
lr_warmup_steps_,
'warmup iterations')
if 'end_iter' in sd:
decay_steps_ = sd['end_iter']
lr_decay_steps_ = sd['end_iter']
elif 'decay_steps' in sd:
lr_decay_steps_ = sd['decay_steps']
else:
decay_steps_ = sd['decay_steps']
self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
lr_decay_steps_ = sd['lr_decay_steps']
self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_,
'total number of iterations')
self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'],
'decay style')
if 'decay_style' in sd:
lr_decay_style_ = sd['decay_style']
else:
lr_decay_style_ = sd['lr_decay_style']
self.lr_decay_style = self._check_and_set(self.lr_decay_style,
lr_decay_style_,
'learning rate decay style')
if 'num_iters' in sd:
num_steps = sd['num_iters']
else:
num_steps = sd['num_steps']
self.step(increment=num_steps)
if 'start_wd' in sd:
self.start_wd = self._check_and_set(self.start_wd,
sd['start_wd'],
"start weight decay")
self.end_wd = self._check_and_set(self.end_wd,
sd['end_wd'],
"end weight decay")
self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
sd['wd_incr_steps'],
"total number of weight decay iterations")
self.wd_incr_style = self._check_and_set(self.wd_incr_style,
sd['wd_incr_style'],
"weight decay incr style")
......@@ -98,7 +98,12 @@ def custom_backward(output, grad_output):
)
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
def forward_step(forward_step_func,
data_iterator,
model,
input_tensor,
forward_data_store,
collect_non_loss_data=False):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
......@@ -120,10 +125,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
if not collect_non_loss_data:
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
timers('forward-compute').stop()
# If T5 model (or other model with encoder and decoder)
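The two branches above imply that the loss_func returned by forward_step_func needs an optional non_loss_data keyword. A hedged sketch (not taken from the repo; the loss computation is a placeholder):

def loss_func(output_tensor, non_loss_data=False):
    if non_loss_data:
        # Whatever is returned here is appended to forward_data_store as-is,
        # e.g. predictions or images to visualize later.
        return output_tensor.detach()
    loss = output_tensor.float().mean()          # placeholder loss computation
    return loss, {'loss': loss.detach()}

evaluate() later in this MR drives the collect_non_loss_data=True path with an extra forward-only pass when a process_non_loss_data_func is supplied.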
......@@ -206,8 +216,12 @@ def dummy_handler():
pass
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
def forward_backward_no_pipelining(forward_step_func,
data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
......@@ -219,35 +233,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
if isinstance(model, torchDDP):
context_handler = model.no_sync
losses_reduced = []
forward_data_store = []
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
return losses_reduced
return forward_data_store
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
def forward_backward_pipelining_with_interleaving(forward_step_func,
data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns a list of loss dicts on the last stage, an empty list otherwise."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
forward_data_store = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
......@@ -307,7 +327,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor, losses_reduced)
input_tensor,
forward_data_store,
collect_non_loss_data)
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
......@@ -474,7 +496,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
tensor_shape=tensor_shape,
timers=timers))
return losses_reduced
return forward_data_store
def get_tensor_shapes(rank, model_type):
......@@ -571,9 +593,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers,
forward_only):
def forward_backward_pipelining_without_interleaving(forward_step_func,
data_iterator,
model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
......@@ -608,13 +634,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if not forward_only:
input_tensors = []
output_tensors = []
losses_reduced = []
forward_data_store = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
input_tensor, forward_data_store,
collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only:
......@@ -633,7 +660,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
input_tensor, forward_data_store,
collect_non_loss_data)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers)
......@@ -682,4 +710,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
return losses_reduced
return forward_data_store
......@@ -43,7 +43,7 @@ from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
......@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None,
args_defaults={}):
"""Main training program.
......@@ -86,6 +87,10 @@ def pretrain(train_valid_test_dataset_provider,
the info we would like to monitor during training, for example
`lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
process_non_loss_data_func: a function to post-process outputs of the
network. It can be used for dumping output tensors (e.g. images) to
tensorboard. It takes `collected data` (list of tensors),
`current iteration index` and `tensorboard writer` as arguments (see the sketch below).
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
......@@ -113,7 +118,7 @@ def pretrain(train_valid_test_dataset_provider,
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
model_type)
timers('model-and-optimizer-setup').stop()
print_datetime('after model, optimizer, and learning rate '
......@@ -144,25 +149,28 @@ def pretrain(train_valid_test_dataset_provider,
iteration = 0
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func)
print_datetime('after training is done')
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, False)
iteration, process_non_loss_data_func,
False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
0, True)
0, process_non_loss_data_func,
True)
def update_train_iters(args):
......@@ -296,7 +304,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
return model
def get_learning_rate_scheduler(optimizer):
def get_optimizer_param_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
......@@ -304,11 +312,12 @@ def get_learning_rate_scheduler(optimizer):
if args.train_iters:
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
decay_steps = args.lr_decay_iters * args.global_batch_size
lr_decay_steps = args.lr_decay_iters * args.global_batch_size
wd_incr_steps = args.train_iters * args.global_batch_size
if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
warmup_steps = args.lr_warmup_iters * args.global_batch_size
lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
......@@ -317,29 +326,38 @@ def get_learning_rate_scheduler(optimizer):
update_train_iters(args)
if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples
decay_steps = args.lr_decay_samples
lr_decay_steps = args.lr_decay_samples
wd_incr_steps = args.train_samples
if args.lr_warmup_fraction is not None:
warmup_steps = args.lr_warmup_fraction * decay_steps
lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
else:
warmup_steps = args.lr_warmup_samples
lr_warmup_steps = args.lr_warmup_samples
else:
raise Exception(
'either train-iters or train-samples should be provided.')
lr_scheduler = AnnealingLR(
opt_param_scheduler = OptimizerParamScheduler(
optimizer,
max_lr=args.lr,
min_lr=args.min_lr,
warmup_steps=warmup_steps,
decay_steps=decay_steps,
decay_style=args.lr_decay_style,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
def setup_model_and_optimizer(model_provider_func, model_type):
lr_warmup_steps=lr_warmup_steps,
lr_decay_steps=lr_decay_steps,
lr_decay_style=args.lr_decay_style,
start_wd=args.start_weight_decay,
end_wd=args.end_weight_decay,
wd_incr_steps=wd_incr_steps,
wd_incr_style=args.weight_decay_incr_style,
use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
override_opt_param_scheduler=args.override_opt_param_scheduler)
return opt_param_scheduler
def setup_model_and_optimizer(model_provider_func,
model_type,
no_wd_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
"""Setup model and optimizer."""
args = get_args()
......@@ -347,9 +365,10 @@ def setup_model_and_optimizer(model_provider_func, model_type):
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model)
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
lr_scheduler = get_learning_rate_scheduler(optimizer)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.load is not None:
timers = get_timers()
......@@ -357,7 +376,7 @@ def setup_model_and_optimizer(model_provider_func, model_type):
# max time.
torch.distributed.barrier()
timers('load-checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])
......@@ -376,11 +395,11 @@ def setup_model_and_optimizer(model_provider_func, model_type):
if args.fp16:
optimizer.reload_model_params()
return model, optimizer, lr_scheduler
return model, optimizer, opt_param_scheduler
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
model, optimizer, opt_param_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
......@@ -456,7 +475,7 @@ def train_step(forward_step_func, data_iterator,
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
lr_scheduler.step(increment=increment)
opt_param_scheduler.step(increment=increment)
skipped_iter = 0
else:
skipped_iter = 1
......@@ -646,20 +665,21 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save-checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
timers('save-checkpoint').stop()
timers.log(['save-checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator):
def train(forward_step_func, model, optimizer, opt_param_scheduler,
train_data_iterator, valid_data_iterator,
process_non_loss_data_func):
"""Train the model function."""
args = get_args()
timers = get_timers()
......@@ -687,7 +707,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator,
model,
optimizer,
lr_scheduler)
opt_param_scheduler)
iteration += 1
args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
......@@ -708,7 +728,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
......@@ -716,7 +736,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, False)
iteration, process_non_loss_data_func,
False)
# Checkpointing
saved_checkpoint = False
......@@ -724,14 +745,14 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
signal_handler = get_signal_handler()
if any(signal_handler.signals_received()):
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
print_datetime('exiting program after receiving SIGTERM.')
sys.exit()
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
saved_checkpoint = True
# Exiting based on duration
......@@ -745,7 +766,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()
......@@ -753,7 +774,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
opt_param_scheduler)
torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()
......@@ -762,7 +783,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
return iteration
def evaluate(forward_step_func, data_iterator, model, verbose=False):
def evaluate(forward_step_func,
data_iterator,
model,
process_non_loss_data_func,
verbose=False):
"""Evaluation."""
args = get_args()
......@@ -799,6 +824,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \
* get_num_microbatches()
collected_non_loss_data = None
if process_non_loss_data_func is not None and is_last_rank():
collected_non_loss_data = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True, collect_non_loss_data=True)
# Move model back to the train mode.
for model_module in model:
model_module.train()
......@@ -806,16 +837,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
return total_loss_dict
return total_loss_dict, collected_non_loss_data
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
iteration, verbose=False):
iteration, process_non_loss_data_func,
verbose=False):
"""Helper function to evaluate and dump results on screen."""
args = get_args()
writer = get_tensorboard_writer()
total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
total_loss_dict, collected_non_loss_data = evaluate(
forward_step_func, data_iterator, model,
process_non_loss_data_func, verbose)
string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
......@@ -834,6 +868,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
writer.add_scalar('{} validation ppl vs samples'.format(key),
ppl, args.consumed_train_samples)
if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, iteration, writer)
length = len(string) + 1
print_rank_last('-' * length)
print_rank_last(string)
......
......@@ -126,7 +126,7 @@ def print_params_min_max_norm(optimizer, iteration):
def check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler):
optimizer, opt_param_scheduler):
"""Check for autoresume signal and exit if it is received."""
from megatron.checkpointing import save_checkpoint
......@@ -136,7 +136,7 @@ def check_adlr_autoresume_termination(iteration, model,
torch.distributed.barrier()
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
print_rank_0(">>> autoresume termination request found!")
if torch.distributed.get_rank() == 0:
autoresume.request_resume()
......
......@@ -154,7 +154,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset,
return train_dataloader, valid_dataloader
def _train(model, optimizer, lr_scheduler, forward_step,
def _train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback):
"""Train the model."""
args = get_args()
......@@ -195,7 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
start_iteration = 0
# Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1
......@@ -215,13 +215,13 @@ def _train(model, optimizer, lr_scheduler, forward_step,
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler)
optimizer, opt_param_scheduler)
# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
saved_checkpoint = True
# Evaluation
......@@ -234,14 +234,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
print_rank_0('exiting program at iteration {}'.format(iteration))
sys.exit()
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
......@@ -279,7 +279,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for
......@@ -307,7 +307,7 @@ def finetune(train_valid_datasets_provider, model_provider,
# Finetune the model.
if args.epochs > 0:
_train(model, optimizer, lr_scheduler, forward_step,
_train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback)
# Or just evaluate.
else:
......
......@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
def _train(
model,
optimizer,
lr_scheduler,
opt_param_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
......@@ -179,7 +179,7 @@ def _train(
# Train for one step.
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
forward_step, batch, model, optimizer, lr_scheduler
forward_step, batch, model, optimizer, opt_param_scheduler
)
iteration += 1
......@@ -206,7 +206,7 @@ def _train(
iteration % args.adlr_autoresume_interval == 0
):
check_adlr_autoresume_termination(
iteration, model, optimizer, lr_scheduler
iteration, model, optimizer, opt_param_scheduler
)
# Checkpointing
......@@ -215,7 +215,7 @@ def _train(
and args.save_interval
and iteration % args.save_interval == 0
):
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
......@@ -231,7 +231,7 @@ def _train(
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
......@@ -266,7 +266,7 @@ def finetune(
# Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
timers("model and optimizer").stop()
# If pretrained checkpoint is provided and we have not trained for
......@@ -300,7 +300,7 @@ def finetune(
_train(
model,
optimizer,
lr_scheduler,
opt_param_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
......