Commit d8c85650 authored by Vijay Korthikanti's avatar Vijay Korthikanti

second phase of vision code merge

parent 798b6a64
......@@ -246,6 +246,10 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
if args.weight_decay is not None:
args.start_wd = args.weight_decay
args.end_wd = args.weight_decay
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# Persistent fused layer norm.
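Worth noting from the hunk above: when args.weight_decay is not None, parse_args copies it into both start_wd and end_wd, collapsing the schedule to a constant. A minimal sketch of that override, with a toy namespace standing in for the parsed args (values are illustrative only):
class _ToyArgs:                       # hypothetical stand-in for the argparse namespace
    weight_decay = 0.05
    start_wd = 0.01
    end_wd = 0.5
args = _ToyArgs()
if args.weight_decay is not None:     # same check as in parse_args above
    args.start_wd = args.weight_decay
    args.end_wd = args.weight_decay
assert args.start_wd == args.end_wd == 0.05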
......@@ -395,6 +399,13 @@ def _add_regularization_args(parser):
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-wd', type=float, default=0.01,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-wd', type=float, default=0.01,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--wd-incr-style', type=str, default='linear',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9,
......
......@@ -24,6 +24,7 @@ class AnnealingLR(object):
def __init__(self, optimizer, max_lr, min_lr,
warmup_steps, decay_steps, decay_style,
start_wd, end_wd, wd_incr_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
......@@ -43,6 +44,13 @@ class AnnealingLR(object):
self.decay_style = decay_style
self.start_wd = start_wd
self.end_wd = end_wd
assert self.start_wd >= 0.0
assert self.end_wd >= self.start_wd
self.wd_incr_style = wd_incr_style
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
......@@ -51,10 +59,33 @@ class AnnealingLR(object):
# Set the learning rate
self.step(0)
print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_wd(self):
if self.num_steps > self.decay_steps:
return self.end_wd
if self.wd_incr_style == 'constant':
assert self.start_wd == self.end_wd
return self.end_wd
decay_ratio = float(self.num_steps) / float(self.decay_steps)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_wd = self.end_wd - self.start_wd
if self.wd_incr_style == 'linear':
coeff = decay_ratio
elif self.wd_incr_style == 'cosine':
coeff = 0.5 * (math.cos(math.pi * (1 - decay_ratio)) + 1.0)
else:
raise Exception('{} weight decay increment style is not supported.'.format(
self.wd_incr_style))
return self.start_wd + coeff * delta_wd
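A quick numeric check of this increment logic, mirroring get_wd() in a standalone function (the step counts and coefficients below are illustrative):
import math

def scheduled_wd(num_steps, decay_steps, start_wd, end_wd, incr_style):
    # standalone mirror of AnnealingLR.get_wd() above, for a sanity check
    if num_steps > decay_steps:
        return end_wd
    ratio = float(num_steps) / float(decay_steps)
    if incr_style == 'linear':
        coeff = ratio
    elif incr_style == 'cosine':
        coeff = 0.5 * (math.cos(math.pi * (1 - ratio)) + 1.0)
    else:                              # 'constant' requires start_wd == end_wd
        return end_wd
    return start_wd + coeff * (end_wd - start_wd)

assert abs(scheduled_wd(0, 100, 0.01, 0.1, 'cosine') - 0.01) < 1e-12    # cosine starts flat
assert abs(scheduled_wd(50, 100, 0.01, 0.1, 'cosine') - 0.055) < 1e-12  # halfway coefficient is 0.5
assert abs(scheduled_wd(100, 100, 0.01, 0.1, 'linear') - 0.1) < 1e-12   # both styles end at end_wd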
def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
......@@ -95,8 +126,10 @@ class AnnealingLR(object):
"""Set lr for all parameters groups."""
self.num_steps += increment
new_lr = self.get_lr()
new_wd = self.get_wd()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
group['lr'] = new_lr * group['lr_mult']
group['weight_decay'] = new_wd * group['wd_mult']
def state_dict(self):
......
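The step() change above means every optimizer param group must now carry 'lr_mult' and 'wd_mult' keys (get_param_groups later in this commit provides them). A minimal sketch of that contract, using torch.optim.AdamW as a stand-in for the fused optimizer:
import torch

p = torch.nn.Parameter(torch.zeros(2))
optimizer = torch.optim.AdamW(
    [{'params': [p], 'lr_mult': 1.0, 'wd_mult': 0.0}],    # extra keys are kept per group
    lr=1e-3, weight_decay=0.01)
new_lr, new_wd = 1e-3, 0.02
for group in optimizer.param_groups:                       # what AnnealingLR.step() now does
    group['lr'] = new_lr * group['lr_mult']
    group['weight_decay'] = new_wd * group['wd_mult']
assert optimizer.param_groups[0]['weight_decay'] == 0.0    # wd_mult of 0.0 disables decay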
......@@ -43,6 +43,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.drop_prob == 0. or not self.training:
return x
keep_prob = 1 - self.drop_prob
# works with tensors of any dimensionality, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
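A quick sanity check of the per-sample semantics, restated on a plain nn.Module so it runs without Megatron's initialization (the forward logic is copied from DropPath above):
import torch

class _ToyDropPath(torch.nn.Module):
    # same logic as DropPath.forward() above
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor

drop_path = _ToyDropPath(drop_prob=0.2).train()
x = torch.ones(8, 4, 16)                        # dim 0 is treated as the sample dimension
y = drop_path(x)
for sample in y:                                # each sample is kept (and rescaled by
    assert bool(sample.eq(sample.flatten()[0]).all())   # 1/keep_prob) or dropped as a whole
assert torch.equal(drop_path.eval()(x), x)      # identity at inference time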
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -407,12 +430,14 @@ class ParallelTransformerLayer(MegatronModule):
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.drop_path_rate = drop_path_rate
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
......@@ -435,6 +460,7 @@ class ParallelTransformerLayer(MegatronModule):
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate)
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
......@@ -478,25 +504,31 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = hidden_states
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
if self.drop_path_rate == 0.0:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
......@@ -532,13 +564,19 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
if self.drop_path_rate == 0.0:
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
......@@ -549,7 +587,8 @@ class ParallelTransformer(MegatronModule):
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True):
pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelTransformer, self).__init__()
args = get_args()
......@@ -558,6 +597,7 @@ class ParallelTransformer(MegatronModule):
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flag.
self.activations_checkpoint_method = args.activations_checkpoint_method
......@@ -568,6 +608,8 @@ class ParallelTransformer(MegatronModule):
self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.num_layers)]
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
......@@ -575,7 +617,8 @@ class ParallelTransformer(MegatronModule):
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.dpr[layer_number - 1])
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
......
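For reference, the linspace above assigns a linearly increasing stochastic-depth rate per layer, with the first layer never dropping its residual branch. A quick check with illustrative values (drop_path_rate=0.1, num_layers=24):
import torch

dpr = [x.item() for x in torch.linspace(0, 0.1, 24)]   # same expression as above
assert dpr[0] == 0.0                                    # first layer: no drop path
assert abs(dpr[-1] - 0.1) < 1e-6                        # last layer: full configured rate
assert abs(dpr[1] - 0.1 / 23) < 1e-6                    # step is rate / (num_layers - 1)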
......@@ -23,35 +23,67 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(modules):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
def get_param_groups(modules,
no_weight_decay_cond,
scale_lr_cond,
lr_mult):
"""creates param groups based on weight decay condition (regularized vs non regularized)
and learning rate scale condition (args.lr vs lr_mult * args.lr)
scale_lr_cond is used during finetuning where head of the network requires a scaled
version of the base learning rate.
"""
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
wd_no_scale_lr = []
wd_scale_lr = []
no_wd_no_scale_lr = []
no_wd_scale_lr = []
for module in modules:
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None and p.requires_grad])
for name, param in module.named_parameters():
if not param.requires_grad:
continue
if no_weight_decay_cond is not None:
no_wd = no_weight_decay_cond(name, param)
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and p.requires_grad and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and p.requires_grad and n == 'bias'])
no_wd = name.endswith(".bias") or len(param.shape) == 1
return weight_decay_params, no_weight_decay_params
if scale_lr_cond is not None:
scale_lr = scale_lr_cond(name, param)
else:
scale_lr = False
if not no_wd and not scale_lr:
wd_no_scale_lr.append(param)
elif not no_wd and scale_lr:
wd_scale_lr.append(param)
elif no_wd and not scale_lr:
no_wd_no_scale_lr.append(param)
else:
no_wd_scale_lr.append(param)
def get_megatron_optimizer(model):
param_groups = []
if len(wd_no_scale_lr):
param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
if len(wd_scale_lr):
param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
if len(no_wd_no_scale_lr):
param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
if len(no_wd_scale_lr):
param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
return param_groups
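For context, a hedged sketch of the condition callables a vision fine-tuning script might pass (the names and match patterns below are illustrative, not part of this commit):
def no_weight_decay_cond(name, param):
    # no decay for biases, 1-D params (e.g. layernorm weights), and position embeddings
    return name.endswith('.bias') or param.ndim == 1 or 'position_embeddings' in name

def scale_lr_cond(name, param):
    # train a newly attached head with lr_mult * args.lr during finetuning
    return 'head' in name

# these would be forwarded via:
#   setup_model_and_optimizer(model_provider, model_type,
#                             no_wd_decay_cond=no_weight_decay_cond,
#                             scale_lr_cond=scale_lr_cond, lr_mult=10.0)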
def get_megatron_optimizer(model,
no_weight_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
args = get_args()
# Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model)
param_groups = get_param_groups(model,
no_weight_decay_cond,
scale_lr_cond,
lr_mult)
if args.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=args.lr,
......
......@@ -91,7 +91,12 @@ def custom_backward(output, grad_output):
accumulate_grad=True,
)
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
def forward_step(forward_step_func,
data_iterator,
model,
input_tensor,
forward_data_store,
collect_non_loss_data=False):
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
......@@ -113,10 +118,15 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(data_iterator, model)
if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
if not collect_non_loss_data:
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
forward_data_store.append(loss_reduced)
else:
data = loss_func(output_tensor, non_loss_data=True)
forward_data_store.append(data)
timers('forward-compute').stop()
# If T5 model (or other model with encoder and decoder)
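A hedged sketch of the loss_func contract this implies: the closure returned by forward_step_func can branch on non_loss_data to hand back raw outputs for later visualization instead of the usual (loss, reduced-loss dict) pair. The names and the reduced-loss key below are illustrative only:
from functools import partial
import torch.nn.functional as F

def loss_func(labels, output_tensor, non_loss_data=False):
    if non_loss_data:
        return output_tensor.argmax(dim=-1)            # e.g. predictions for logging
    loss = F.cross_entropy(output_tensor, labels)
    return loss, {'classification loss': loss.detach()}

# a forward_step_func would end with:  return output_tensor, partial(loss_func, labels)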
......@@ -203,8 +213,12 @@ def dummy_handler():
pass
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
def forward_backward_no_pipelining(forward_step_func,
data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
......@@ -216,35 +230,41 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
if isinstance(model, torchDDP):
context_handler = model.no_sync
losses_reduced = []
forward_data_store = []
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
output_tensor = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
return losses_reduced
return forward_data_store
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
def forward_backward_pipelining_with_interleaving(forward_step_func,
data_iterator, model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
forward_data_store = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
......@@ -304,7 +324,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor, losses_reduced)
input_tensor,
forward_data_store,
collect_non_loss_data)
output_tensors[model_chunk_id].append(output_tensor)
# if forward-only, no need to save tensors for a backward pass
......@@ -471,7 +493,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
tensor_shape=tensor_shape,
timers=timers))
return losses_reduced
return forward_data_store
def get_tensor_shapes(rank, model_type):
......@@ -568,9 +590,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
return input_tensors
def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
model, optimizer, timers,
forward_only):
def forward_backward_pipelining_without_interleaving(forward_step_func,
data_iterator,
model,
optimizer,
timers,
forward_only,
collect_non_loss_data=False):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
......@@ -605,13 +631,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if not forward_only:
input_tensors = []
output_tensors = []
losses_reduced = []
forward_data_store = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
input_tensor, forward_data_store,
collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
if not forward_only:
......@@ -630,7 +657,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
input_tensor, forward_data_store,
collect_non_loss_data)
if forward_only:
send_forward(output_tensor, send_tensor_shapes, timers=timers)
......@@ -679,4 +707,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
return losses_reduced
return forward_data_store
......@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
model_provider,
model_type,
forward_step_func,
process_non_loss_data_func=None,
extra_args_provider=None,
args_defaults={}):
"""Main training program.
......@@ -145,14 +146,16 @@ def pretrain(train_valid_test_dataset_provider,
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
train_data_iterator, valid_data_iterator,
process_non_loss_data_func)
print_datetime('after training is done')
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, False)
iteration, process_non_loss_data_func,
False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
......@@ -162,7 +165,8 @@ def pretrain(train_valid_test_dataset_provider,
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
0, True)
0, process_non_loss_data_func,
True)
def update_train_iters(args):
......@@ -333,13 +337,20 @@ def get_learning_rate_scheduler(optimizer):
warmup_steps=warmup_steps,
decay_steps=decay_steps,
decay_style=args.lr_decay_style,
start_wd=args.start_wd,
end_wd=args.end_wd,
wd_incr_style=args.wd_incr_style,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
def setup_model_and_optimizer(model_provider_func, model_type):
def setup_model_and_optimizer(model_provider_func,
model_type,
no_wd_decay_cond=None,
scale_lr_cond=None,
lr_mult=1.0):
"""Setup model and optimizer."""
args = get_args()
......@@ -347,7 +358,8 @@ def setup_model_and_optimizer(model_provider_func, model_type):
unwrapped_model = unwrap_model(model,
(torchDDP, LocalDDP, Float16Module))
optimizer = get_megatron_optimizer(unwrapped_model)
optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
lr_scheduler = get_learning_rate_scheduler(optimizer)
......@@ -659,7 +671,8 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator):
train_data_iterator, valid_data_iterator,
process_non_loss_data_func):
"""Train the model function."""
args = get_args()
timers = get_timers()
......@@ -716,7 +729,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, False)
iteration, process_non_loss_data_func,
False)
# Checkpointing
saved_checkpoint = False
......@@ -762,7 +776,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
return iteration
def evaluate(forward_step_func, data_iterator, model, verbose=False):
def evaluate(forward_step_func,
data_iterator,
model,
process_non_loss_data_func,
verbose=False):
"""Evaluation."""
args = get_args()
......@@ -799,6 +817,12 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.micro_batch_size \
* get_num_microbatches()
collected_non_loss_data = None
if process_non_loss_data_func is not None and is_last_rank():
collected_non_loss_data = forward_backward_func(
forward_step_func, data_iterator, model, optimizer=None,
timers=None, forward_only=True, collect_non_loss_data=True)
# Move model back to the train mode.
for model_module in model:
model_module.train()
......@@ -806,16 +830,19 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
return total_loss_dict
return total_loss_dict, collected_non_loss_data
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
iteration, verbose=False):
iteration, process_non_loss_data_func,
verbose=False):
"""Helper function to evaluate and dump results on screen."""
args = get_args()
writer = get_tensorboard_writer()
total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
total_loss_dict, collected_non_loss_data = evaluate(
forward_step_func, data_iterator, model,
process_non_loss_data_func, verbose)
string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
......@@ -834,6 +861,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
writer.add_scalar('{} validation ppl vs samples'.format(key),
ppl, args.consumed_train_samples)
if process_non_loss_data_func is not None and writer and is_last_rank():
process_non_loss_data_func(collected_non_loss_data, iteration, writer)
length = len(string) + 1
print_rank_last('-' * length)
print_rank_last(string)
......
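Finally, a hedged sketch of what a process_non_loss_data_func could look like, consuming the list collected above and writing to the TensorBoard writer (it assumes each entry matches the illustrative loss_func sketched earlier):
import torch

def process_non_loss_data_func(collected_data, iteration, writer):
    # collected_data is the forward_data_store built during the extra
    # forward-only evaluation pass on the last rank
    predictions = torch.cat([d.cpu() for d in collected_data])
    writer.add_histogram('validation predicted classes', predictions, iteration)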