Commit d8c85650 authored by Vijay Korthikanti

second phase of vision code merge

parent 798b6a64
@@ -246,6 +246,10 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'

+    if args.weight_decay is not None:
+        args.start_wd = args.weight_decay
+        args.end_wd = args.weight_decay
+
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
     # Persistent fused layer norm.
@@ -395,6 +399,13 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
+    group.add_argument('--start-wd', type=float, default=0.01,
+                       help='Initial weight decay coefficient for L2 regularization.')
+    group.add_argument('--end-wd', type=float, default=0.01,
+                       help='End of run weight decay coefficient for L2 regularization.')
+    group.add_argument('--wd-incr-style', type=str, default='linear',
+                       choices=['constant', 'linear', 'cosine'],
+                       help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
                        help='Gradient clipping based on global L2 norm.')
     group.add_argument('--adam-beta1', type=float, default=0.9,
......
@@ -24,6 +24,7 @@ class AnnealingLR(object):

     def __init__(self, optimizer, max_lr, min_lr,
                  warmup_steps, decay_steps, decay_style,
+                 start_wd, end_wd, wd_incr_style,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):
@@ -43,6 +44,13 @@ class AnnealingLR(object):
         self.decay_style = decay_style

+        self.start_wd = start_wd
+        self.end_wd = end_wd
+        assert self.start_wd >= 0.0
+        assert self.end_wd >= self.start_wd
+        self.wd_incr_style = wd_incr_style
+
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:
@@ -51,10 +59,33 @@ class AnnealingLR(object):
         # Set the learning rate
         self.step(0)
         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

+    def get_wd(self):
+        if self.num_steps > self.decay_steps:
+            return self.end_wd
+
+        if self.wd_incr_style == 'constant':
+            assert self.start_wd == self.end_wd
+            return self.end_wd
+
+        decay_ratio = float(self.num_steps) / float(self.decay_steps)
+        assert decay_ratio >= 0.0
+        assert decay_ratio <= 1.0
+        delta_wd = self.end_wd - self.start_wd
+
+        if self.wd_incr_style == 'linear':
+            coeff = decay_ratio
+        elif self.wd_incr_style == 'cosine':
+            coeff = 0.5 * (math.cos(math.pi * (1 - decay_ratio)) + 1.0)
+        else:
+            raise Exception('{} weight decay increment style is not supported.'.format(
+                self.wd_incr_style))
+
+        return self.start_wd + coeff * delta_wd
+
     def get_lr(self):
         """Learning rate decay functions from:
           https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
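For reference, the increment styles map the fraction num_steps / decay_steps to a coefficient between 0 and 1. A minimal standalone sketch of the same arithmetic (illustration only, not part of the diff):

import math

def wd_increment(num_steps, decay_steps, start_wd, end_wd, style):
    # Mirrors AnnealingLR.get_wd() above, stripped of the scheduler state.
    if num_steps > decay_steps:
        return end_wd
    ratio = float(num_steps) / float(decay_steps)
    if style == 'constant':
        return end_wd
    if style == 'linear':
        coeff = ratio
    else:  # 'cosine'
        coeff = 0.5 * (math.cos(math.pi * (1 - ratio)) + 1.0)
    return start_wd + coeff * (end_wd - start_wd)

# With start_wd=0.01, end_wd=0.1, a quarter of the way through decay_steps:
#   linear -> 0.01 + 0.250 * 0.09 ~= 0.0325
#   cosine -> 0.01 + 0.146 * 0.09 ~= 0.0232  (cosine ramps slowly at first, fast at the end)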
@@ -95,8 +126,10 @@ class AnnealingLR(object):
         """Set lr for all parameters groups."""
         self.num_steps += increment
         new_lr = self.get_lr()
+        new_wd = self.get_wd()
         for group in self.optimizer.param_groups:
-            group['lr'] = new_lr
+            group['lr'] = new_lr * group['lr_mult']
+            group['weight_decay'] = new_wd * group['wd_mult']

     def state_dict(self):
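Note that step() now indexes group['lr_mult'] and group['wd_mult'], so every optimizer param group must carry both keys; the new get_param_groups below populates them. A hedged sketch of the expected group layout (the parameter tensors are placeholders):

import torch

# Placeholder parameters, for illustration only.
decay_params = [torch.nn.Parameter(torch.randn(8, 8))]
no_decay_params = [torch.nn.Parameter(torch.zeros(8))]

optimizer = torch.optim.Adam([
    {'params': decay_params,    'wd_mult': 1.0, 'lr_mult': 1.0},
    {'params': no_decay_params, 'wd_mult': 0.0, 'lr_mult': 1.0},
], lr=1e-4, weight_decay=0.01)
# AnnealingLR.step() then scales each group's 'lr' and 'weight_decay' by these multipliers.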
......
@@ -43,6 +43,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
 """

+class DropPath(MegatronModule):
+    """Drop paths (Stochastic Depth) per sample
+    (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        if self.drop_prob == 0. or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        # work with diff dim tensors, not just 2D ConvNets
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = keep_prob + \
+            torch.rand(shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # binarize
+        output = x.div(keep_prob) * random_tensor
+        return output
+
+
 class ParallelMLP(MegatronModule):
     """MLP.
@@ -407,12 +430,14 @@ class ParallelTransformerLayer(MegatronModule):

     def __init__(self, init_method, output_layer_init_method,
                  layer_number, layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 drop_path_rate=0.):
         args = get_args()

         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
         self.layer_type = layer_type
+        self.drop_path_rate = drop_path_rate

         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm
@@ -435,6 +460,7 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.drop_path = DropPath(drop_path_rate)

         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
@@ -478,6 +504,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states

+        if self.drop_path_rate == 0.0:
             # jit scripting for a nn.module (with dropout) is not
             # trigerring the fusion kernel. For now, we use two
             # different nn.functional routines to account for varying
@@ -497,6 +524,11 @@ class ParallelTransformerLayer(MegatronModule):
                     attention_bias.expand_as(residual),
                     residual,
                     self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(attention_output + attention_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            layernorm_input = residual + self.drop_path(out)

         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
@@ -532,6 +564,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input

+        if self.drop_path_rate == 0.0:
             # re-enable torch grad to enable fused optimization.
             with torch.enable_grad():
                 output = bias_dropout_add_func(
@@ -539,6 +572,11 @@ class ParallelTransformerLayer(MegatronModule):
                     mlp_bias.expand_as(residual),
                     residual,
                     self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            output = residual + self.drop_path(out)

         return output
@@ -549,7 +587,8 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
-                 pre_process=True, post_process=True):
+                 pre_process=True, post_process=True,
+                 drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
         args = get_args()
@@ -558,6 +597,7 @@ class ParallelTransformer(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
+        self.drop_path_rate = drop_path_rate

         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method
@@ -568,6 +608,8 @@ class ParallelTransformer(MegatronModule):
             self.num_layers = mpu.get_num_layers(
                 args, args.model_type == ModelType.encoder_and_decoder)

+        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.num_layers)]
+
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
@@ -575,7 +617,8 @@ class ParallelTransformer(MegatronModule):
                 output_layer_init_method,
                 layer_number,
                 layer_type=layer_type,
-                self_attn_mask_type=self_attn_mask_type)
+                self_attn_mask_type=self_attn_mask_type,
+                drop_path_rate=self.dpr[layer_number - 1])

         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
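self.dpr scales the drop-path rate linearly with depth, so the first layer never drops its residual branch and the last layer drops at the full drop_path_rate. A standalone check of the same torch.linspace arithmetic (illustration only):

import torch

drop_path_rate, num_layers = 0.1, 5
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
print(dpr)   # approximately [0.0, 0.025, 0.05, 0.075, 0.1]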
......
@@ -23,35 +23,67 @@ from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


-def _get_params_for_weight_decay_optimization(modules):
-    """Divide params into with-weight-decay and without-weight-decay groups.
-    Layernorms and baises will have no weight decay but the rest will.
-    """
-
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module in modules:
-        for module_ in module.modules():
-            if isinstance(module_, LayerNorm):
-                no_weight_decay_params['params'].extend(
-                    [p for p in list(module_._parameters.values())
-                     if p is not None and p.requires_grad])
-            else:
-                weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n != 'bias'])
-                no_weight_decay_params['params'].extend(
-                    [p for n, p in list(module_._parameters.items())
-                     if p is not None and p.requires_grad and n == 'bias'])
-
-    return weight_decay_params, no_weight_decay_params
+def get_param_groups(modules,
+                     no_weight_decay_cond,
+                     scale_lr_cond,
+                     lr_mult):
+    """creates param groups based on weight decay condition (regularized vs non regularized)
+    and learning rate scale condition (args.lr vs lr_mult * args.lr)
+    scale_lr_cond is used during finetuning where head of the network requires a scaled
+    version of the base learning rate.
+    """
+    wd_no_scale_lr = []
+    wd_scale_lr = []
+    no_wd_no_scale_lr = []
+    no_wd_scale_lr = []
+    for module in modules:
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                continue
+
+            if no_weight_decay_cond is not None:
+                no_wd = no_weight_decay_cond(name, param)
+            else:
+                no_wd = name.endswith(".bias") or len(param.shape) == 1
+
+            if scale_lr_cond is not None:
+                scale_lr = scale_lr_cond(name, param)
+            else:
+                scale_lr = False
+
+            if not no_wd and not scale_lr:
+                wd_no_scale_lr.append(param)
+            elif not no_wd and scale_lr:
+                wd_scale_lr.append(param)
+            elif no_wd and not scale_lr:
+                no_wd_no_scale_lr.append(param)
+            else:
+                no_wd_scale_lr.append(param)
+
+    param_groups = []
+    if len(wd_no_scale_lr):
+        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
+    if len(wd_scale_lr):
+        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
+    if len(no_wd_no_scale_lr):
+        param_groups.append({'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
+    if len(no_wd_scale_lr):
+        param_groups.append({'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})
+
+    return param_groups


-def get_megatron_optimizer(model):
+def get_megatron_optimizer(model,
+                           no_weight_decay_cond=None,
+                           scale_lr_cond=None,
+                           lr_mult=1.0):
     args = get_args()

     # Base optimizer.
-    param_groups = _get_params_for_weight_decay_optimization(model)
+    param_groups = get_param_groups(model,
+                                    no_weight_decay_cond,
+                                    scale_lr_cond,
+                                    lr_mult)

     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
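The two new callbacks let task code override the default grouping without touching the optimizer module. A hedged sketch of what a finetuning script might pass (the predicates, the 'head.' prefix and the multiplier value are illustrative, not part of this commit):

def no_wd_cond(name, param):
    # Also exempt position embeddings, besides biases and 1-D (norm) parameters.
    return name.endswith('.bias') or len(param.shape) == 1 \
        or 'position_embeddings' in name

def scale_lr_cond(name, param):
    # Give a (hypothetical) task head a scaled learning rate.
    return name.startswith('head.')

optimizer = get_megatron_optimizer(model,
                                   no_weight_decay_cond=no_wd_cond,
                                   scale_lr_cond=scale_lr_cond,
                                   lr_mult=10.0)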
......
@@ -91,7 +91,12 @@ def custom_backward(output, grad_output):
         accumulate_grad=True,
     )


-def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
+def forward_step(forward_step_func,
+                 data_iterator,
+                 model,
+                 input_tensor,
+                 forward_data_store,
+                 collect_non_loss_data=False):
     """Forward step for passed-in model.

     If first stage, input tensor is obtained from data_iterator, otherwise
@@ -113,10 +118,15 @@ def forward_step(forward_step_func,
     unwrapped_model.set_input_tensor(input_tensor)
     output_tensor, loss_func = forward_step_func(data_iterator, model)
     if mpu.is_pipeline_last_stage():
-        output_tensor = loss_func(output_tensor)
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
+        if not collect_non_loss_data:
+            output_tensor = loss_func(output_tensor)
+            loss, loss_reduced = output_tensor
+            output_tensor = loss / get_num_microbatches()
+            forward_data_store.append(loss_reduced)
+        else:
+            data = loss_func(output_tensor, non_loss_data=True)
+            forward_data_store.append(data)
     timers('forward-compute').stop()

     # If T5 model (or other model with encoder and decoder)
@@ -203,8 +213,12 @@ def dummy_handler():
         pass


-def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
-                                   optimizer, timers, forward_only):
+def forward_backward_no_pipelining(forward_step_func,
+                                   data_iterator, model,
+                                   optimizer,
+                                   timers,
+                                   forward_only,
+                                   collect_non_loss_data=False):
     """Run forward and backward passes with no pipeline parallelism
     (no inter-stage communication).
@@ -216,35 +230,41 @@ def forward_backward_no_pipelining(forward_step_func,
     if isinstance(model, torchDDP):
         context_handler = model.no_sync

-    losses_reduced = []
+    forward_data_store = []
     input_tensor, output_tensor_grad = None, None
     with context_handler():
         for i in range(get_num_microbatches() - 1):
-            output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                         input_tensor, losses_reduced)
+            output_tensor = forward_step(forward_step_func, data_iterator,
+                                         model, input_tensor, forward_data_store,
+                                         collect_non_loss_data)
             if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)

     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
-    output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                 input_tensor, losses_reduced)
+    output_tensor = forward_step(forward_step_func, data_iterator,
+                                 model, input_tensor, forward_data_store,
+                                 collect_non_loss_data)
     if not forward_only:
         backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)

-    return losses_reduced
+    return forward_data_store


-def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
-                                                  optimizer, timers, forward_only):
+def forward_backward_pipelining_with_interleaving(forward_step_func,
+                                                  data_iterator, model,
+                                                  optimizer,
+                                                  timers,
+                                                  forward_only,
+                                                  collect_non_loss_data=False):
     """Run interleaved 1F1B schedule (model split into model chunks), with
     communication between pipeline stages as needed.

     Returns dictionary with losses if the last stage, empty dict otherwise."""
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
-    losses_reduced = []
+    forward_data_store = []
     if not forward_only:
         output_tensor_grads = [[] for _ in range(len(model))]
@@ -304,7 +324,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
         output_tensor = forward_step(forward_step_func,
                                      data_iterator[model_chunk_id],
                                      model[model_chunk_id],
-                                     input_tensor, losses_reduced)
+                                     input_tensor,
+                                     forward_data_store,
+                                     collect_non_loss_data)
         output_tensors[model_chunk_id].append(output_tensor)

         # if forward-only, no need to save tensors for a backward pass
@@ -471,7 +493,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
                                  tensor_shape=tensor_shape,
                                  timers=timers))

-    return losses_reduced
+    return forward_data_store


 def get_tensor_shapes(rank, model_type):
@@ -568,9 +590,13 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, timers):
     return input_tensors


-def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
-                                                     model, optimizer, timers,
-                                                     forward_only):
+def forward_backward_pipelining_without_interleaving(forward_step_func,
+                                                     data_iterator,
+                                                     model,
+                                                     optimizer,
+                                                     timers,
+                                                     forward_only,
+                                                     collect_non_loss_data=False):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
     stages.
@@ -605,13 +631,14 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
     if not forward_only:
         input_tensors = []
         output_tensors = []
-    losses_reduced = []
+    forward_data_store = []

     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         send_forward(output_tensor, send_tensor_shapes, timers=timers)

         if not forward_only:
@@ -630,7 +657,8 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
         last_iteration = (i == (num_microbatches_remaining - 1))

         output_tensor = forward_step(forward_step_func, data_iterator, model,
-                                     input_tensor, losses_reduced)
+                                     input_tensor, forward_data_store,
+                                     collect_non_loss_data)
         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, timers=timers)
@@ -679,4 +707,4 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
             send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)

-    return losses_reduced
+    return forward_data_store
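When collect_non_loss_data=True, forward_step calls the task's loss_func with non_loss_data=True and appends whatever it returns to forward_data_store instead of a reduced loss. A hedged sketch of a forward_step_func / loss_func pair supporting both paths (the classification task and the names are illustrative):

from functools import partial
import torch

def loss_func(labels, output_tensor, non_loss_data=False):
    if non_loss_data:
        # Returned object is appended to forward_data_store as-is,
        # e.g. predictions to visualize during evaluation.
        return torch.argmax(output_tensor, dim=-1)
    loss = torch.nn.functional.cross_entropy(output_tensor, labels)
    return loss, {'loss': loss.detach()}

def forward_step_func(data_iterator, model):
    images, labels = next(data_iterator)     # illustrative batch layout
    output_tensor = model(images)
    return output_tensor, partial(loss_func, labels)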
@@ -65,6 +65,7 @@ def pretrain(train_valid_test_dataset_provider,
              model_provider,
              model_type,
              forward_step_func,
+             process_non_loss_data_func=None,
              extra_args_provider=None,
              args_defaults={}):
     """Main training program.
@@ -145,14 +146,16 @@ def pretrain(train_valid_test_dataset_provider,
     if args.do_train and args.train_iters > 0:
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
-                          train_data_iterator, valid_data_iterator)
+                          train_data_iterator, valid_data_iterator,
+                          process_non_loss_data_func)
     print_datetime('after training is done')

     if args.do_valid:
         prefix = 'the end of training for val data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    valid_data_iterator, model,
-                                   iteration, False)
+                                   iteration, process_non_loss_data_func,
+                                   False)

     if args.save and iteration != 0:
         save_checkpoint(iteration, model, optimizer, lr_scheduler)
@@ -162,7 +165,8 @@ def pretrain(train_valid_test_dataset_provider,
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, forward_step_func,
                                    test_data_iterator, model,
-                                   0, True)
+                                   0, process_non_loss_data_func,
+                                   True)


 def update_train_iters(args):
@@ -333,13 +337,20 @@ def get_learning_rate_scheduler(optimizer):
         warmup_steps=warmup_steps,
         decay_steps=decay_steps,
         decay_style=args.lr_decay_style,
+        start_wd=args.start_wd,
+        end_wd=args.end_wd,
+        wd_incr_style=args.wd_incr_style,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)

     return lr_scheduler


-def setup_model_and_optimizer(model_provider_func, model_type):
+def setup_model_and_optimizer(model_provider_func,
+                              model_type,
+                              no_wd_decay_cond=None,
+                              scale_lr_cond=None,
+                              lr_mult=1.0):
     """Setup model and optimizer."""
     args = get_args()
@@ -347,7 +358,8 @@ def setup_model_and_optimizer(model_provider_func,
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))

-    optimizer = get_megatron_optimizer(unwrapped_model)
+    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
+                                       scale_lr_cond, lr_mult)
     lr_scheduler = get_learning_rate_scheduler(optimizer)
@@ -659,7 +671,8 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):

 def train(forward_step_func, model, optimizer, lr_scheduler,
-          train_data_iterator, valid_data_iterator):
+          train_data_iterator, valid_data_iterator,
+          process_non_loss_data_func):
     """Train the model function."""
     args = get_args()
     timers = get_timers()
@@ -716,7 +729,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step_func,
                                        valid_data_iterator, model,
-                                       iteration, False)
+                                       iteration, process_non_loss_data_func,
+                                       False)

         # Checkpointing
         saved_checkpoint = False
@@ -762,7 +776,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     return iteration


-def evaluate(forward_step_func, data_iterator, model, verbose=False):
+def evaluate(forward_step_func,
+             data_iterator,
+             model,
+             process_non_loss_data_func,
+             verbose=False):
     """Evaluation."""
     args = get_args()
@@ -799,6 +817,12 @@ def evaluate(forward_step_func,
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                            * args.micro_batch_size \
                                            * get_num_microbatches()
+
+    collected_non_loss_data = None
+    if process_non_loss_data_func is not None and is_last_rank():
+        collected_non_loss_data = forward_backward_func(
+            forward_step_func, data_iterator, model, optimizer=None,
+            timers=None, forward_only=True, collect_non_loss_data=True)
+
     # Move model back to the train mode.
     for model_module in model:
         model_module.train()
@@ -806,16 +830,19 @@ def evaluate(forward_step_func,
     for key in total_loss_dict:
         total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

-    return total_loss_dict
+    return total_loss_dict, collected_non_loss_data


 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
-                               iteration, verbose=False):
+                               iteration, process_non_loss_data_func,
+                               verbose=False):
     """Helper function to evaluate and dump results on screen."""
     args = get_args()
     writer = get_tensorboard_writer()
-    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
+    total_loss_dict, collected_non_loss_data = evaluate(
+        forward_step_func, data_iterator, model,
+        process_non_loss_data_func, verbose)
     string = ' validation loss at {} | '.format(prefix)
     for key in total_loss_dict:
         string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
@@ -834,6 +861,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
                 writer.add_scalar('{} validation ppl vs samples'.format(key),
                                   ppl, args.consumed_train_samples)

+    if process_non_loss_data_func is not None and writer and is_last_rank():
+        process_non_loss_data_func(collected_non_loss_data, iteration, writer)
+
     length = len(string) + 1
     print_rank_last('-' * length)
     print_rank_last(string)
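process_non_loss_data_func therefore receives, on the last rank, the list collected during evaluation, the current iteration, and the TensorBoard writer. A hedged sketch of such a hook for a vision task (the tag and the histogram choice are illustrative):

import torch

def process_non_loss_data_func(collected_data, iteration, writer):
    # `collected_data` holds one entry per evaluation microbatch, i.e. whatever
    # loss_func returned under non_loss_data=True.
    preds = torch.cat([d.float().cpu() for d in collected_data])
    writer.add_histogram('validation predictions', preds, iteration)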
......