Commit a31833ce authored by mshoeybi, committed by Deepak Narayanan

Fix some bugs, add exit-duration capability

parent 51315905
......@@ -26,6 +26,7 @@ from .package_info import (
)
from .global_vars import get_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
......
......@@ -326,6 +326,8 @@ def _add_training_args(parser):
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after the iteration is divisible '
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
......
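The new --exit-duration-in-mins flag sits alongside the existing --exit-interval flag: one stops training after a wall-clock budget, the other after a fixed number of iterations. A minimal sketch of how such a time-based limit can be parsed and checked, using a throwaway argparse parser rather than Megatron's own (past_time_budget is a hypothetical helper):

```python
# Minimal sketch (hypothetical parser, not Megatron's) of a wall-clock exit
# limit in the spirit of --exit-duration-in-mins.
import argparse
import time

parser = argparse.ArgumentParser()
parser.add_argument('--exit-duration-in-mins', type=int, default=None,
                    help='Exit the program after this many minutes.')
args = parser.parse_args(['--exit-duration-in-mins', '90'])

train_start_time = time.time()

def past_time_budget(start, limit_mins):
    """True once elapsed wall-clock time exceeds the limit (never if no limit)."""
    if limit_mins is None:
        return False
    return (time.time() - start) / 60.0 > limit_mins

print(past_time_budget(train_start_time, args.exit_duration_in_mins))  # False
```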
......@@ -418,11 +418,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, dataset_type=dataset_type)
# Blend.
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = BlendableDataset(test_datasets, weights)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
......
......@@ -55,14 +55,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
seq_length, seed, skip_warmup)
train_datasets.append(train_ds)
valid_datasets.append(valid_ds)
test_datasets.append(test_ds)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
valid_datasets.append(valid_ds)
if test_ds:
test_datasets.append(test_ds)
# Blend.
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = BlendableDataset(test_datasets, weights)
blending_train_dataset = None
if train_datasets:
blending_train_dataset = BlendableDataset(train_datasets, weights)
blending_valid_dataset = None
if valid_datasets:
blending_valid_dataset = BlendableDataset(valid_datasets, weights)
blending_test_dataset = None
if test_datasets:
blending_test_dataset = BlendableDataset(test_datasets, weights)
return (blending_train_dataset, blending_valid_dataset,
blending_test_dataset)
......
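Both dataset builders (the BERT-style and GPT-style build_train_valid_test_datasets shown above) now append only the splits that were actually built and construct a BlendableDataset only when the corresponding list is non-empty, returning None otherwise. A standalone sketch of that guard pattern, with a stand-in Blendable class since BlendableDataset itself is not reproduced in this diff:

```python
# Standalone sketch of the guard pattern above; Blendable is a stand-in for
# megatron's BlendableDataset, which is not reproduced here.
class Blendable:
    def __init__(self, datasets, weights):
        assert len(datasets) > 0
        self.datasets, self.weights = datasets, weights

def blend_or_none(datasets, weights):
    # Only blend when at least one split actually produced a dataset;
    # an empty list would otherwise trip the assertion (or index errors).
    return Blendable(datasets, weights) if datasets else None

weights = [1.0]
print(blend_or_none([], weights))          # None: no split was built
print(blend_or_none([object()], weights))  # a Blendable instance
```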
......@@ -43,8 +43,13 @@ def get_num_microbatches():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def update_num_microbatches(consumed_samples):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)
def get_current_global_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def update_num_microbatches(consumed_samples, consistency_check=True):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
consistency_check)
def get_tokenizer():
......
......@@ -56,12 +56,16 @@ class NumMicroBatchesCalculator(ABC):
def __init__(self):
self.num_micro_batches = None
self.current_global_batch_size = None
def get(self):
return self.num_micro_batches
def get_current_global_batch_size(self):
return self.current_global_batch_size
@abstractmethod
def update(self, consumed_samples):
def update(self, consumed_samples, consistency_check):
pass
......@@ -78,8 +82,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
self.num_micro_batches = global_batch_size // \
micro_batch_times_data_parallel
assert self.num_micro_batches >= 1
self.current_global_batch_size = global_batch_size
def update(self, consumed_samples):
def update(self, consumed_samples, consistency_check):
pass
......@@ -128,24 +133,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
self.rampup_samples_per_increment = self.ramup_samples / num_increments
# Initialize number of microbatches.
self.update(0)
self.update(0, False)
def update(self, consumed_samples):
def update(self, consumed_samples, consistency_check):
if consumed_samples > self.ramup_samples:
current_global_batch_size = self.global_batch_size
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
current_global_batch_size = self.start_batch_size + \
steps * self.batch_size_increment
assert current_global_batch_size <= self.global_batch_size
assert current_global_batch_size % \
self.micro_batch_times_data_parallel_size == 0, 'current global ' \
'batch size ({}) is not divisible by micro-batch-size ({}) times' \
'data parallel size ({})'.format(current_global_batch_size,
self.micro_batch_size,
self.data_parallel_size)
self.num_micro_batches = current_global_batch_size // \
self.current_global_batch_size = self.start_batch_size + \
steps * self.batch_size_increment
assert self.current_global_batch_size <= self.global_batch_size
if consistency_check:
assert self.current_global_batch_size % \
self.micro_batch_times_data_parallel_size == 0, 'current global ' \
'batch size ({}) is not divisible by micro-batch-size ({}) times' \
'data parallel size ({})'.format(self.current_global_batch_size,
self.micro_batch_size,
self.data_parallel_size)
self.num_micro_batches = self.current_global_batch_size // \
self.micro_batch_times_data_parallel_size
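RampupBatchsizeNumMicroBatches now stores current_global_batch_size and only enforces the divisibility assertion when consistency_check is set; the constructor's update(0, False) call skips it. A small worked example of the same ramp-up arithmetic, with illustrative numbers that do not come from the diff:

```python
# Worked example of the ramp-up batch-size arithmetic used above.
# The numbers (start 32, increment 16, target 128, ramp over 1000 samples,
# micro-batch 4, data-parallel 2) are illustrative, not from the diff.
start_batch_size, batch_size_increment, global_batch_size = 32, 16, 128
rampup_samples = 1000
micro_batch_times_data_parallel = 4 * 2

num_increments = (global_batch_size - start_batch_size) // batch_size_increment
rampup_samples_per_increment = rampup_samples / num_increments

def current_batch_and_microbatches(consumed_samples):
    if consumed_samples > rampup_samples:
        current = global_batch_size
    else:
        steps = int(consumed_samples / rampup_samples_per_increment)
        current = start_batch_size + steps * batch_size_increment
    # The divisibility check corresponds to consistency_check=True above.
    assert current % micro_batch_times_data_parallel == 0
    return current, current // micro_batch_times_data_parallel

for samples in (0, 200, 600, 2000):
    print(samples, current_batch_and_microbatches(samples))
# (0, (32, 4)), (200, (48, 6)), (600, (80, 10)), (2000, (128, 16))
```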
......@@ -18,6 +18,10 @@
from datetime import datetime
import math
import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
......@@ -25,6 +29,7 @@ from apex.optimizers import FusedAdam as Adam
from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import update_num_microbatches
from megatron import mpu
......@@ -44,6 +49,13 @@ from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory
def print_datetime(string):
"""Note that this call will sync across all ranks."""
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print_rank_0('[' + string + '] datetime: {} '.format(time_str))
def pretrain(train_valid_test_dataset_provider, model_provider,
forward_step_func, extra_args_provider=None, args_defaults={}):
"""Main training program.
......@@ -74,6 +86,18 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
# image ... launches.
global _TRAIN_START_TIME
start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
torch.distributed.all_reduce(start_time_tensor,
op=torch.distributed.ReduceOp.MIN)
_TRAIN_START_TIME = start_time_tensor.item()
print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
time.time() - _TRAIN_START_TIME))
print_datetime('after megatron is initialized')
args = get_args()
timers = get_timers()
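_TRAIN_START_TIME is captured at import time on every rank, and the all_reduce with ReduceOp.MIN then makes every rank adopt the earliest of those timestamps, so the duration-based exit later in train() is measured from a common start. A single-process sketch of that agreement step, assuming a size-1 gloo group so it runs without GPUs; the diff uses torch.cuda.FloatTensor over the default (NCCL) process group:

```python
# Single-process sketch of agreeing on the earliest start time across ranks.
import os
import time
import torch
import torch.distributed as dist

_TRAIN_START_TIME = time.time()

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

start = torch.tensor([_TRAIN_START_TIME], dtype=torch.float64)
# MIN across ranks: every rank ends up with the earliest recorded start time.
dist.all_reduce(start, op=dist.ReduceOp.MIN)
_TRAIN_START_TIME = start.item()

print('time to initialize (seconds): {:.3f}'.format(
    time.time() - _TRAIN_START_TIME))
dist.destroy_process_group()
```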
......@@ -81,6 +105,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers('model and optimizer').stop()
print_datetime('after model, optimizer, and learning rate '
'scheduler are built')
# Data stuff.
timers('train/valid/test data iterators').start()
......@@ -88,6 +114,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
= build_train_valid_test_data_iterators(
train_valid_test_dataset_provider)
timers('train/valid/test data iterators').stop()
print_datetime('after dataloaders are built')
# Print setup timing.
print_rank_0('done with setups ...')
......@@ -99,6 +126,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
iteration = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
print_datetime('after training is done')
if args.do_valid:
prefix = 'the end of training for val data'
......@@ -132,13 +160,11 @@ def update_train_iters(args):
consumed_samples = 0
# Rampup phase.
while consumed_samples <= int(args.rampup_batch_size[2]):
update_num_microbatches(consumed_samples)
consumed_samples += get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
update_num_microbatches(consumed_samples, consistency_check=False)
consumed_samples += get_current_global_batch_size()
iterations += 1
# Reset
update_num_microbatches(0)
update_num_microbatches(0, consistency_check=False)
# Constant phase
# Note that we throw away any partial last batch.
iterations += (args.train_samples - consumed_samples) // \
......@@ -267,7 +293,15 @@ def setup_model_and_optimizer(model_provider_func):
lr_scheduler = get_learning_rate_scheduler(optimizer)
if args.load is not None:
timers = get_timers()
# Extra barrier is added to make sure all ranks report the
# max time.
torch.distributed.barrier()
timers('load checkpoint').start()
args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('load checkpoint').stop()
timers.log(['load checkpoint'])
else:
args.iteration = 0
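Checkpoint loading (and, further down, saving) is now bracketed by torch.distributed.barrier() calls on both sides of the timer, so the logged time reflects the slowest rank rather than whichever rank finished first. A hedged sketch of that pattern as a reusable context manager; barrier_timer is a hypothetical helper, not part of Megatron:

```python
# Hypothetical helper capturing the barrier-then-time pattern used around
# load_checkpoint/save_checkpoint above.
import time
from contextlib import contextmanager

import torch

@contextmanager
def barrier_timer(name):
    # Barrier before starting so no rank starts its clock early ...
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    start = time.time()
    try:
        yield
    finally:
        # ... and barrier again so every rank reports the max (slowest) time.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        print('{}: {:.3f} s'.format(name, time.time() - start))

# Usage sketch (names as in the diff):
# with barrier_timer('load checkpoint'):
#     args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
```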
......@@ -685,11 +719,22 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
# Tensorboard values.
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
writer.add_scalar('learning_rate-samples', learning_rate,
args.consumed_train_samples)
batch_size = args.micro_batch_size * args.data_parallel_size * \
get_num_microbatches()
writer.add_scalar('batch_size-iterations', batch_size, iteration)
writer.add_scalar('batch_size-samples', batch_size,
args.consumed_train_samples)
for key in loss_dict:
writer.add_scalar(key, loss_dict[key], iteration)
writer.add_scalar(key + '-iterations', loss_dict[key], iteration)
writer.add_scalar(key + '-samples', loss_dict[key],
args.consumed_train_samples)
if args.fp16:
writer.add_scalar('loss_scale', loss_scale, iteration)
writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
writer.add_scalar('loss_scale-samples', loss_scale,
args.consumed_train_samples)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
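Learning rate, batch size, loss scale, and the per-key losses are now logged against two x-axes: the iteration number and args.consumed_train_samples, which keeps curves comparable when the global batch size ramps up. A minimal sketch using torch.utils.tensorboard directly, with made-up values and a hypothetical log directory (the diff goes through Megatron's get_tensorboard_writer):

```python
# Minimal sketch of logging one scalar against both axes.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='/tmp/tb-sketch')  # hypothetical directory
iteration, consumed_train_samples = 100, 51200    # made-up values
lm_loss = 3.21

writer.add_scalar('lm-loss-iterations', lm_loss, iteration)
writer.add_scalar('lm-loss-samples', lm_loss, consumed_train_samples)
writer.close()
```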
......@@ -703,6 +748,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' consumed samples {:12d} |'.format(
args.consumed_train_samples)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
......@@ -732,6 +779,18 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
return report_memory_flag
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
timers = get_timers()
# Extra barrier is added to make sure
# all ranks report the max time.
torch.distributed.barrier()
timers('save checkpoint').start()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
torch.distributed.barrier()
timers('save checkpoint').stop()
timers.log(['save checkpoint'])
def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator):
"""Train the model function."""
......@@ -748,6 +807,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
iteration = args.iteration
timers('interval time').start()
print_datetime('before the start of training step')
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
......@@ -777,9 +837,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
lr_scheduler)
# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
saved_checkpoint = True
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
......@@ -789,14 +853,31 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
valid_data_iterator, model,
iteration, False)
# Exiting based on duration
if args.exit_duration_in_mins:
train_time = (time.time() - _TRAIN_START_TIME) / 60.0
done_cuda = torch.cuda.IntTensor(
[train_time > args.exit_duration_in_mins])
torch.distributed.all_reduce(
done_cuda, op=torch.distributed.ReduceOp.MAX)
done = done_cuda.item()
if done:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
print_datetime('exiting program after {} minutes'.format(train_time))
sys.exit()
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print_rank_0('rank: {} | time: {} | exiting the program at '
'iteration {}'.format(rank, time_str, iteration))
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()
return iteration
......
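The duration-based exit reduces a per-rank "done" flag with ReduceOp.MAX, so as soon as any rank crosses the --exit-duration-in-mins budget every rank agrees to stop, saving a checkpoint first unless one was just written. A single-process sketch of that consensus step, assuming a size-1 gloo group so it runs anywhere; the diff itself uses torch.cuda.IntTensor over NCCL and calls save_checkpoint_and_time before exiting:

```python
# Single-process sketch of the duration-exit consensus used in train() above.
import os
import sys
import time
import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29501')
dist.init_process_group('gloo', rank=0, world_size=1)

train_start_time = time.time() - 120.0  # pretend training started 2 minutes ago
exit_duration_in_mins = 1               # illustrative budget

train_time = (time.time() - train_start_time) / 60.0
done = torch.tensor([int(train_time > exit_duration_in_mins)], dtype=torch.int32)
# MAX across ranks: if any rank has hit the limit, every rank sees done == 1.
dist.all_reduce(done, op=dist.ReduceOp.MAX)
if done.item():
    # In training.py a checkpoint would be saved here before exiting.
    print('exiting program after {:.2f} minutes'.format(train_time))
    dist.destroy_process_group()
    sys.exit()
```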