Commit 27e14f82 authored by Mohammad

refactored training

parent 3f58649b
......@@ -234,6 +234,8 @@ def _add_mixed_precision_args(parser):
'attention-softmax-in-fp32 to true')
group.add_argument('--attention-softmax-in-fp32', action='store_true',
help='Run attention masking and softmax in fp32.')
group.add_argument('--fp32-allreduce', action='store_true',
help='All-reduce in fp32')
group.add_argument('--hysteresis', type=int, default=2,
help='hysteresis for dynamic loss scaling')
group.add_argument('--loss-scale', type=float, default=None,
......
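Note: the following is a minimal sketch, not part of this commit, of how a flag like `--fp32-allreduce` is typically consumed when averaging gradients across data-parallel ranks; the helper name and the exact reduction strategy here are assumptions for illustration.

```python
import torch

def allreduce_gradients_sketch(model, args, group=None):
    """Illustrative only: average gradients across data-parallel ranks,
    optionally upcasting fp16 grads to fp32 when args.fp32_allreduce is set."""
    world_size = torch.distributed.get_world_size(group=group)
    for param in model.parameters():
        if param.grad is None:
            continue
        grad = param.grad.data
        # Upcast for a more numerically stable reduction if requested.
        buf = grad.float() if (args.fp32_allreduce and grad.dtype == torch.half) else grad
        torch.distributed.all_reduce(buf, group=group)
        buf.div_(world_size)
        if buf is not grad:
            grad.copy_(buf)
```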
......@@ -13,62 +13,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities"""
"""Pretrain utilities."""
from datetime import datetime
import math
import sys
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
from megatron.global_vars import get_args
from megatron.global_vars import get_timers
from megatron.global_vars import get_tensorboard_writer
from megatron.global_vars import get_adlr_autoresume
from megatron.initialize import initialize_megatron
from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.utils import check_adlr_autoresume_termination
from megatron.checkpointing import load_checkpoint
from megatron import print_rank_0
from megatron.utils import report_memory
from megatron.checkpointing import save_checkpoint
def run(top_level_message, train_val_test_data_provider,
model_provider, forward_step_func, extra_args_provider=None,
args_defaults={}):
def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
extra_args_provider=None, args_defaults={}):
"""Main training program.
This function will run the following in the order provided:
1) get input arguments.
2) initialize distributed and seeds.
1) initialize Megatron.
2) setup model, optimizer and lr schedule using the model_provider.
3) call train_val_test_data_provider to get train/val/test datasets.
4) setup model, optimizer and lr schedule using the model_provider.
5) train the model using the forward_step_func.
4) train the model using the forward_step_func.
Arguments:
top_level_message: a message to print at the top of the run.
train_val_test_data_provider: a function that takes `args` as input
and returns `train, val, test` dataloaders. Note that args are
passed and can be modified in case we need to use some parameters
later. For example, we can set vocab size using
args.vocab_size = ...
and later use this value in `model_provider`.
model_provider: a function that takes `args` and returns a vanilla
version of the model. By vanilla we mean a simple model on cpu
with no fp16 or ddp.
forward_step_func: a function that takes a `data iterator`, `model`,
`args`, and `timers` and returns a `loss` scalar with a dictionary
with key:values being the info we would like to monitor during
training, for example `lm-loss: value`. We also require that this
function add `batch generator` to the timers class.
train_val_test_data_provider: a function that builds datasets
and returns `train, val, test` dataloaders.
model_provider: a function that returns a vanilla version of the
model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
forward_step_func: a function that takes a `data iterator` and `model`,
and returns a `loss` scalar along with a dictionary whose key:value
pairs are the info we would like to monitor during training, for
example `lm-loss: value`. We also require that this function add
`batch generator` to the timers class.
extra_args_provider: a function that takes a parser and adds arguments
to it. It is used for programs to add their own arguments.
args_defaults: a dictionary from argument-name to argument-value. It
is used to set already-parsed arguments.
"""
# Initialize and get arguments, timers, and Tensorboard writer.
......@@ -76,36 +71,44 @@ def run(top_level_message, train_val_test_data_provider,
args_defaults=args_defaults)
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
# Data stuff.
train_data, val_data, test_data = train_val_test_data_provider(args)
# Model, optimizer, and learning rate.
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
args)
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers('model and optimizer').stop()
# Data stuff.
timers('train/valid/test dataset').start()
train_data, val_data, test_data = train_val_test_data_provider()
timers('train/valid/test dataset').stop()
# Train, validation, and test data.
timers('train/valid/test dataloader').start()
train_data_iterator, val_data_iterator, \
test_data_iterator = get_train_val_test_data_iterators(train_data,
val_data,
test_data,
args)
test_data)
timers('train/valid/test dataloader').stop()
# Print setup timing.
print_rank_0('done with setups ...')
timers.log(['model and optimizer', 'train/valid/test dataset',
'train/valid/test dataloader'])
print_rank_0('training ...')
iteration = 0
if args.train_iters > 0:
if args.do_train:
iteration, _ = train(forward_step_func, model,
optimizer, lr_scheduler,
train_data_iterator, val_data_iterator,
timers, args, writer)
iteration, _ = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator)
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
val_data_iterator, model,
args, writer, iteration,
timers, False)
iteration, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
......@@ -115,14 +118,15 @@ def run(top_level_message, train_val_test_data_provider,
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
args, None, 0, timers, True)
0, True)
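For orientation, a skeleton (placeholders only, not part of this commit) of how a training script wires into the refactored `pretrain()`; compare with the updated `pretrain_bert.py` and `pretrain_gpt2.py` hunks further down.

```python
from megatron.training import pretrain

def train_val_test_data_provider():
    """Build datasets and return (train, val, test) dataloaders."""
    ...

def model_provider():
    """Return a vanilla CPU model; pretrain() adds fp16 and DDP wrapping."""
    ...

def forward_step(data_iterator, model):
    """Return (loss, loss_dict) and time the 'batch generator' section."""
    ...

if __name__ == "__main__":
    pretrain(train_val_test_data_provider, model_provider, forward_step)
```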
def get_model(model_provider_func, args):
def get_model(model_provider_func):
"""Build the model."""
args = get_args()
# Build model on cpu.
model = model_provider_func(args)
model = model_provider_func()
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
......@@ -140,26 +144,24 @@ def get_model(model_provider_func, args):
# Wrap model for distributed training.
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
args.DDP_type = torchDDP
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
model = torchDDP(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
return model
if args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
model = LocalDDP(model)
return model
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
return model
sys.exit()
def get_optimizer(model, args):
def get_optimizer(model):
"""Set up the optimizer."""
args = get_args()
# Build parameter groups (weight decay and non-decay).
while isinstance(model, (args.DDP_type, FP16_Module)):
while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
model = model.module
param_groups = get_params_for_weight_decay_optimization(model)
......@@ -170,8 +172,7 @@ def get_optimizer(model, args):
param.model_parallel = False
# Use Adam.
optimizer = Adam(param_groups,
lr=args.lr, weight_decay=args.weight_decay)
optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer.
if args.fp16:
......@@ -186,8 +187,9 @@ def get_optimizer(model, args):
return optimizer
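The parameter groups passed to Adam come from `get_params_for_weight_decay_optimization`, which separates parameters that should and should not receive weight decay. A rough sketch of that idea (an illustration, not the function's actual body):

```python
def split_weight_decay_groups(module):
    """Sketch: biases and 1-D (e.g. LayerNorm) parameters get
    weight_decay=0.0; everything else uses the optimizer's default."""
    decay, no_decay = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        if param.dim() == 1 or name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{'params': decay},
            {'params': no_decay, 'weight_decay': 0.0}]

# Usage in spirit:
#   optimizer = Adam(split_weight_decay_groups(model),
#                    lr=args.lr, weight_decay=args.weight_decay)
```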
def get_learning_rate_scheduler(optimizer, args):
def get_learning_rate_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
# Add linear learning rate scheduler.
if args.lr_decay_iters is not None:
......@@ -211,12 +213,13 @@ def get_learning_rate_scheduler(optimizer, args):
return lr_scheduler
def setup_model_and_optimizer(model_provider_func, args):
def setup_model_and_optimizer(model_provider_func):
"""Setup model and optimizer."""
args = get_args()
model = get_model(model_provider_func, args)
optimizer = get_optimizer(model, args)
lr_scheduler = get_learning_rate_scheduler(optimizer, args)
model = get_model(model_provider_func)
optimizer = get_optimizer(model)
lr_scheduler = get_learning_rate_scheduler(optimizer)
if args.load is not None:
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
......@@ -226,8 +229,10 @@ def setup_model_and_optimizer(model_provider_func, args):
return model, optimizer, lr_scheduler
def backward_step(optimizer, model, loss, args, timers):
def backward_step(optimizer, model, loss):
"""Backward step."""
args = get_args()
timers = get_timers()
# Backward pass.
optimizer.zero_grad()
......@@ -255,18 +260,20 @@ def backward_step(optimizer, model, loss, args, timers):
optimizer.clip_master_grads(args.clip_grad)
def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
args, timers):
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
# Forward model for one step.
timers('forward').start()
loss, loss_reduced = forward_step_func(data_iterator, model, args, timers)
loss, loss_reduced = forward_step_func(data_iterator, model)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
backward_step(optimizer, model, loss, args, timers)
backward_step(optimizer, model, loss)
timers('backward').stop()
# Update parameters.
......@@ -285,7 +292,11 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
loss_scale, report_memory_flag, writer, args, timers):
loss_scale, report_memory_flag):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
writer = get_tensorboard_writer()
# Update losses.
for key in loss_dict:
......@@ -341,8 +352,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args, writer):
train_data_iterator, val_data_iterator):
"""Train the model function."""
args = get_args()
timers = get_timers()
# Turn on training mode which enables dropout.
model.train()
......@@ -361,8 +374,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator,
model,
optimizer,
lr_scheduler,
args, timers)
lr_scheduler)
skipped_iters += skipped_iter
iteration += 1
......@@ -370,8 +382,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, optimizer.loss_scale,
report_memory_flag, writer, args,
timers)
report_memory_flag)
# Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and \
......@@ -389,23 +400,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func,
val_data_iterator, model, args,
writer, iteration, timers, False)
val_data_iterator, model,
iteration, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print('rank: {} | time: {} | exiting the program at iteration {}'.
format(rank, time_str, iteration), flush=True)
exit()
print_rank_0('rank: {} | time: {} | exiting the program at '
'iteration {}'.format(rank, time_str, iteration))
sys.exit()
return iteration, skipped_iters
def evaluate(forward_step_func, data_iterator, model,
args, timers, verbose=False):
def evaluate(forward_step_func, data_iterator, model, verbose=False):
"""Evaluation."""
args = get_args()
# Turn on evaluation mode which disables dropout.
model.eval()
......@@ -420,8 +431,7 @@ def evaluate(forward_step_func, data_iterator, model,
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
# Forward evaluation.
_, loss_dict = forward_step_func(data_iterator, model,
args, timers)
_, loss_dict = forward_step_func(data_iterator, model)
# Reduce across processes.
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
......@@ -437,11 +447,11 @@ def evaluate(forward_step_func, data_iterator, model,
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
args, writer, iteration,
timers, verbose=False):
iteration, verbose=False):
"""Helper function to evaluate and dump results on screen."""
total_loss_dict = evaluate(forward_step_func, data_iterator, model,
args, timers, verbose)
writer = get_tensorboard_writer()
total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
......@@ -459,8 +469,9 @@ def evaluate_and_print_results(prefix, forward_step_func,
print_rank_0('-' * length)
def get_train_val_test_data_iterators(train_data, val_data, test_data, args):
def get_train_val_test_data_iterators(train_data, val_data, test_data):
"""Build train/validation/test iterators"""
args = get_args()
# Shift the start iterations.
if train_data is not None:
......
......@@ -18,24 +18,28 @@
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron.model import BertModel
from megatron import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
from megatron.training import pretrain
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data_utils.samplers import DistributedBatchSampler
def model_provider(args):
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building BERT model ...')
model = BertModel(
num_layers=args.num_layers,
vocab_size=args.vocab_size,
vocab_size=args.padded_vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
......@@ -46,7 +50,7 @@ def model_provider(args):
checkpoint_num_layers=args.checkpoint_num_layers,
add_binary_head=True,
layernorm_epsilon=args.layernorm_epsilon,
num_tokentypes=args.tokentype_size,
num_tokentypes=2,
parallel_output=True,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32)
......@@ -54,19 +58,17 @@ def model_provider(args):
return model
def get_batch(data_iterator, timers):
def get_batch(data_iterator):
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask']
datatype = torch.int64
# Broadcast data.
timers('data loader').start()
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
......@@ -80,13 +82,14 @@ def get_batch(data_iterator, timers):
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
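For reference, a mock of the per-batch dictionary the BERT data iterator is expected to yield on rank 0 of each model-parallel group (other ranks pass `None` and receive the broadcast); the shapes, vocab size, and fill values below are made up for illustration.

```python
import torch

batch_size, seq_length, vocab_size = 4, 128, 30522  # example values only
example_batch = {
    'text':         torch.randint(0, vocab_size, (batch_size, seq_length)),
    'types':        torch.zeros(batch_size, seq_length, dtype=torch.int64),
    'labels':       torch.full((batch_size, seq_length), -1, dtype=torch.int64),
    'is_random':    torch.zeros(batch_size, dtype=torch.int64),
    'loss_mask':    torch.zeros(batch_size, seq_length, dtype=torch.int64),
    'padding_mask': torch.ones(batch_size, seq_length, dtype=torch.int64),
}
# mpu.broadcast_data(keys, example_batch, torch.int64) then hands every rank
# in the model-parallel group the same int64 tensors, which get_batch unpacks.
```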
def forward_step(data_iterator, model, args, timers):
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers('batch generator').start()
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
= get_batch(data_iterator, timers)
= get_batch(data_iterator)
timers('batch generator').stop()
# Forward model.
......@@ -108,9 +111,10 @@ def forward_step(data_iterator, model, args, timers):
return loss, {'lm loss': reduced_losses[0], 'sop loss': reduced_losses[1]}
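The returned values come from `reduce_losses`, which (as I understand it) averages the scalar losses over the data-parallel group so every rank logs the same numbers; a sketch under that assumption:

```python
import torch
from megatron import mpu

def reduce_losses_sketch(losses):
    """Average a list of scalar loss tensors across data-parallel ranks."""
    reduced = torch.cat([loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(reduced, group=mpu.get_data_parallel_group())
    reduced /= torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    return reduced
```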
def get_train_val_test_data(args):
def get_train_val_test_data():
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
args = get_args()
(train_data, valid_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
......@@ -202,6 +206,6 @@ if __name__ == "__main__":
'tokenizer_type': 'BertWordPieceLowerCase'})
exit()
'''
run('Pretrain BERT model', get_train_val_test_data,
model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
pretrain(get_train_val_test_data,
model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
......@@ -17,6 +17,10 @@
import torch
from megatron import get_args
from megatron import get_timers
from configure_data import configure_data
from gpt2_data_loader import make_gpt2_dataloaders
from megatron import mpu
......@@ -25,15 +29,16 @@ from megatron.utils import get_ltor_masks_and_position_ids
from megatron import print_rank_0
from megatron.utils import reduce_losses
from megatron.utils import vocab_size_with_padding
from megatron.training import run
from megatron.training import pretrain
def model_provider(args):
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building GPT2 model ...')
model = GPT2Model(num_layers=args.num_layers,
vocab_size=args.vocab_size,
vocab_size=args.padded_vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
......@@ -50,20 +55,19 @@ def model_provider(args):
return model
def get_batch(data_iterator, args, timers):
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
# Items and their type.
keys = ['text']
datatype = torch.int64
# Broadcast data.
timers('data loader').start()
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
timers('data loader').stop()
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
......@@ -85,8 +89,10 @@ def get_batch(data_iterator, args, timers):
return tokens, labels, loss_mask, attention_mask, position_ids
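The attention mask and position ids above come from `get_ltor_masks_and_position_ids`; conceptually they are a lower-triangular (causal) mask plus 0..seq_length-1 positions. An illustrative stand-in, not the utility's actual implementation (which also handles EOD resets and loss masking):

```python
import torch

def causal_mask_and_positions(tokens):
    """Build a causal attention mask and position ids for a [batch, seq] batch."""
    batch_size, seq_length = tokens.size()
    # Position i may attend only to positions <= i.
    attention_mask = torch.tril(
        torch.ones(seq_length, seq_length, device=tokens.device)
    ).view(1, 1, seq_length, seq_length)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)
    return attention_mask, position_ids
```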
def forward_step(data_iterator, model, args, timers):
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch generator').start()
......@@ -107,9 +113,10 @@ def forward_step(data_iterator, model, args, timers):
return loss, {'lm loss': reduced_loss[0]}
def get_train_val_test_data(args):
def get_train_val_test_data():
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
args = get_args()
(train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
......@@ -162,5 +169,5 @@ def get_train_val_test_data(args):
if __name__ == "__main__":
run('Pretrain GPT-2 model', get_train_val_test_data,
model_provider, forward_step)
pretrain(get_train_val_test_data,
model_provider, forward_step)