Commit a31833ce authored by mshoeybi, committed by Deepak Narayanan

Fix some bugs, add exit-duration capability

parent 51315905
@@ -26,6 +26,7 @@ from .package_info import (
 )
 from .global_vars import get_args
+from .global_vars import get_current_global_batch_size
 from .global_vars import get_num_microbatches
 from .global_vars import update_num_microbatches
 from .global_vars import get_tokenizer
...
@@ -326,6 +326,8 @@ def _add_training_args(parser):
     group.add_argument('--exit-interval', type=int, default=None,
                        help='Exit the program after the iteration is divisible '
                        'by this value.')
+    group.add_argument('--exit-duration-in-mins', type=int, default=None,
+                       help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
     group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
...
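For reference, a minimal argparse sketch of how the new flag behaves alongside --exit-interval; the parser and group here are illustrative stand-ins, not the repo's own setup code:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='training')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit after an iteration divisible by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')

    args = parser.parse_args(['--exit-duration-in-mins', '240'])
    print(args.exit_duration_in_mins)  # 240; stays None when the flag is omitted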
@@ -418,11 +418,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             max_seq_length, masked_lm_prob, short_seq_prob,
             seed, skip_warmup, dataset_type=dataset_type)
+        if train_ds:
+            train_datasets.append(train_ds)
+        if valid_ds:
+            valid_datasets.append(valid_ds)
+        if test_ds:
+            test_datasets.append(test_ds)

     # Blend.
-    blending_train_dataset = BlendableDataset(train_datasets, weights)
-    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
-    blending_test_dataset = BlendableDataset(test_datasets, weights)
+    blending_train_dataset = None
+    if train_datasets:
+        blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = None
+    if valid_datasets:
+        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = None
+    if test_datasets:
+        blending_test_dataset = BlendableDataset(test_datasets, weights)

     return (blending_train_dataset, blending_valid_dataset,
             blending_test_dataset)
...
@@ -55,14 +55,23 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             prefixes[i], data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             seq_length, seed, skip_warmup)
-        train_datasets.append(train_ds)
-        valid_datasets.append(valid_ds)
-        test_datasets.append(test_ds)
+        if train_ds:
+            train_datasets.append(train_ds)
+        if valid_ds:
+            valid_datasets.append(valid_ds)
+        if test_ds:
+            test_datasets.append(test_ds)

     # Blend.
-    blending_train_dataset = BlendableDataset(train_datasets, weights)
-    blending_valid_dataset = BlendableDataset(valid_datasets, weights)
-    blending_test_dataset = BlendableDataset(test_datasets, weights)
+    blending_train_dataset = None
+    if train_datasets:
+        blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = None
+    if valid_datasets:
+        blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = None
+    if test_datasets:
+        blending_test_dataset = BlendableDataset(test_datasets, weights)

     return (blending_train_dataset, blending_valid_dataset,
             blending_test_dataset)
...
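Both dataset builders now tolerate splits that produced nothing, for example a splits string like '100,0,0' that yields no validation or test samples, and only blend the splits that actually exist. A minimal sketch of that guarded-blend pattern; WeightedBlend and blend_or_none are illustrative names standing in for BlendableDataset and the inline checks, not code from the repo:

    # Sketch of the guarded-blend pattern, assuming a BlendableDataset-like class.
    class WeightedBlend:
        def __init__(self, datasets, weights):
            # In the real code, weights are the per-prefix sampling weights.
            assert len(datasets) > 0
            self.datasets = datasets
            self.weights = weights

    def blend_or_none(datasets, weights):
        # Mirrors the `blending_* = None` defaults: no datasets means no blend.
        return WeightedBlend(datasets, weights) if datasets else None

    # Two corpora; the first has no validation split and neither has a test split.
    splits = [('corpus_a', [1, 2, 3], None), ('corpus_b', [4, 5], [6])]
    weights = [0.7, 0.3]
    train, valid = [], []
    for _, train_ds, valid_ds in splits:
        if train_ds:
            train.append(train_ds)
        if valid_ds:
            valid.append(valid_ds)
    print(blend_or_none(train, weights) is not None)  # True
    print(blend_or_none(valid, weights) is not None)  # True
    print(blend_or_none([], weights))                 # None, nothing to blend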
@@ -43,8 +43,13 @@ def get_num_microbatches():
     return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


-def update_num_microbatches(consumed_samples):
-    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)
+def get_current_global_batch_size():
+    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
+
+
+def update_num_microbatches(consumed_samples, consistency_check=True):
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
+                                               consistency_check)


 def get_tokenizer():
...
@@ -56,12 +56,16 @@ class NumMicroBatchesCalculator(ABC):

     def __init__(self):
         self.num_micro_batches = None
+        self.current_global_batch_size = None

     def get(self):
         return self.num_micro_batches

+    def get_current_global_batch_size(self):
+        return self.current_global_batch_size
+
     @abstractmethod
-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):
         pass
@@ -78,8 +82,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
         self.num_micro_batches = global_batch_size // \
             micro_batch_times_data_parallel
         assert self.num_micro_batches >= 1
+        self.current_global_batch_size = global_batch_size

-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):
         pass
@@ -128,24 +133,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
         self.rampup_samples_per_increment = self.ramup_samples / num_increments

         # Initialize number of microbatches.
-        self.update(0)
+        self.update(0, False)

-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):

         if consumed_samples > self.ramup_samples:
-            current_global_batch_size = self.global_batch_size
+            self.current_global_batch_size = self.global_batch_size
         else:
             steps = int(consumed_samples / self.rampup_samples_per_increment)
-            current_global_batch_size = self.start_batch_size + \
-                steps * self.batch_size_increment
-            assert current_global_batch_size <= self.global_batch_size
+            self.current_global_batch_size = self.start_batch_size + \
+                steps * self.batch_size_increment
+            assert self.current_global_batch_size <= self.global_batch_size

-        assert current_global_batch_size % \
-            self.micro_batch_times_data_parallel_size == 0, 'current global ' \
-            'batch size ({}) is not divisible by micro-batch-size ({}) times' \
-            'data parallel size ({})'.format(current_global_batch_size,
-                                             self.micro_batch_size,
-                                             self.data_parallel_size)
-        self.num_micro_batches = current_global_batch_size // \
-            self.micro_batch_times_data_parallel_size
+        if consistency_check:
+            assert self.current_global_batch_size % \
+                self.micro_batch_times_data_parallel_size == 0, 'current global ' \
+                'batch size ({}) is not divisible by micro-batch-size ({}) times' \
+                'data parallel size ({})'.format(self.current_global_batch_size,
+                                                 self.micro_batch_size,
+                                                 self.data_parallel_size)
+        self.num_micro_batches = self.current_global_batch_size // \
+            self.micro_batch_times_data_parallel_size
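As a concrete illustration of the ramp-up arithmetic above, here is a standalone sketch with made-up numbers (start batch size 32, increments of 32 up to a global batch size of 1024, over one million samples, with micro_batch_size * data_parallel_size = 32); none of these values come from the diff:

    # Illustrative sketch of the ramp-up logic; all numbers are assumptions.
    start_batch_size = 32
    batch_size_increment = 32
    global_batch_size = 1024
    rampup_samples = 1_000_000
    micro_batch_times_data_parallel_size = 32

    num_increments = (global_batch_size - start_batch_size) // batch_size_increment
    rampup_samples_per_increment = rampup_samples / num_increments

    def batch_size_and_num_microbatches(consumed_samples):
        if consumed_samples > rampup_samples:
            current = global_batch_size
        else:
            steps = int(consumed_samples / rampup_samples_per_increment)
            current = start_batch_size + steps * batch_size_increment
            assert current <= global_batch_size
        return current, current // micro_batch_times_data_parallel_size

    # After 100k samples: 3 full increments, so batch size 128 and 4 microbatches.
    print(batch_size_and_num_microbatches(100_000))    # (128, 4)
    # Past the ramp-up horizon the full global batch size applies.
    print(batch_size_and_num_microbatches(2_000_000))  # (1024, 32)

Skipping the divisibility assert via consistency_check=False matters during construction and in update_train_iters, where intermediate ramp-up batch sizes are probed without requiring them to divide evenly.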
@@ -18,6 +18,10 @@

 from datetime import datetime
 import math
 import sys
+import time
+# The earliest we can measure the start time.
+_TRAIN_START_TIME = time.time()
+
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam
@@ -25,6 +29,7 @@ from apex.optimizers import FusedAdam as Adam
 from megatron import get_args
 from megatron import get_timers
 from megatron import get_tensorboard_writer
+from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
 from megatron import update_num_microbatches
 from megatron import mpu
@@ -44,6 +49,13 @@ from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory


+def print_datetime(string):
+    """Note that this call will sync across all ranks."""
+    torch.distributed.barrier()
+    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+    print_rank_0('[' + string + '] datetime: {} '.format(time_str))
+
+
 def pretrain(train_valid_test_dataset_provider, model_provider,
              forward_step_func, extra_args_provider=None, args_defaults={}):
     """Main training program.
@@ -74,6 +86,18 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)

+    # Adjust the startup time so it reflects the largest value.
+    # This will be closer to what scheduler will see (outside of
+    # image ... launches.
+    global _TRAIN_START_TIME
+    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
+    torch.distributed.all_reduce(start_time_tensor,
+                                 op=torch.distributed.ReduceOp.MIN)
+    _TRAIN_START_TIME = start_time_tensor.item()
+    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
+        time.time() - _TRAIN_START_TIME))
+    print_datetime('after megatron is initialized')
+
     args = get_args()
     timers = get_timers()
...
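The MIN all-reduce above makes every rank adopt the earliest recorded start time, so the reported initialization time and later duration checks reflect the slowest rank. A small sketch of the same idea without the distributed machinery (the per-rank timestamps are made up):

    import time

    # Pretend start times for three ranks; rank 1 launched last.
    start_times = {0: 1_000.0, 1: 1_004.5, 2: 1_001.2}
    agreed_start = min(start_times.values())   # what ReduceOp.MIN produces

    now = max(start_times.values()) + 30.0     # 30 s after the last launch
    print('time to initialize (seconds): {:.3f}'.format(now - agreed_start))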
@@ -81,6 +105,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
     timers('model and optimizer').stop()
+    print_datetime('after model, optimizer, and learning rate '
+                   'scheduler are built')

     # Data stuff.
     timers('train/valid/test data iterators').start()
@@ -88,6 +114,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
         = build_train_valid_test_data_iterators(
             train_valid_test_dataset_provider)
     timers('train/valid/test data iterators').stop()
+    print_datetime('after dataloaders are built')

     # Print setup timing.
     print_rank_0('done with setups ...')
@@ -99,6 +126,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
                           train_data_iterator, valid_data_iterator)
+    print_datetime('after training is done')

     if args.do_valid:
         prefix = 'the end of training for val data'
@@ -132,13 +160,11 @@ def update_train_iters(args):
     consumed_samples = 0
     # Rampup phase.
     while consumed_samples <= int(args.rampup_batch_size[2]):
-        update_num_microbatches(consumed_samples)
-        consumed_samples += get_num_microbatches() * \
-                            args.micro_batch_size * \
-                            args.data_parallel_size
+        update_num_microbatches(consumed_samples, consistency_check=False)
+        consumed_samples += get_current_global_batch_size()
         iterations += 1
     # Reset
-    update_num_microbatches(0)
+    update_num_microbatches(0, consistency_check=False)
     # Constant phase
     # Note that we throw away any partial last batch.
     iterations += (args.train_samples - consumed_samples) // \
@@ -267,7 +293,15 @@ def setup_model_and_optimizer(model_provider_func):
     lr_scheduler = get_learning_rate_scheduler(optimizer)

     if args.load is not None:
+        timers = get_timers()
+        # Extra barrier is added to make sure all ranks report the
+        # max time.
+        torch.distributed.barrier()
+        timers('load checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        torch.distributed.barrier()
+        timers('load checkpoint').stop()
+        timers.log(['load checkpoint'])
     else:
         args.iteration = 0
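The barrier / start ... stop / barrier pattern above (and the same pattern in save_checkpoint_and_time further down) makes the logged interval cover the slowest rank rather than whichever rank happened to print. A hedged sketch of the same pattern as a reusable helper; barrier_timed is a hypothetical context manager for illustration, not something the repo provides:

    import time
    from contextlib import contextmanager

    import torch

    @contextmanager
    def barrier_timed(name):
        # Synchronize before starting and after stopping so every rank
        # reports the max time, mirroring the pattern in the diff.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        start = time.time()
        try:
            yield
        finally:
            if torch.distributed.is_initialized():
                torch.distributed.barrier()
            print('{} time (seconds): {:.3f}'.format(name, time.time() - start))

    # Usage sketch:
    # with barrier_timed('load checkpoint'):
    #     args.iteration = load_checkpoint(model, optimizer, lr_scheduler)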
@@ -685,11 +719,22 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,

     # Tensorboard values.
     if writer and torch.distributed.get_rank() == 0:
-        writer.add_scalar('learning_rate', learning_rate, iteration)
+        writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
+        writer.add_scalar('learning_rate-samples', learning_rate,
+                          args.consumed_train_samples)
+        batch_size = args.micro_batch_size * args.data_parallel_size * \
+            get_num_microbatches()
+        writer.add_scalar('batch_size-iterations', batch_size, iteration)
+        writer.add_scalar('batch_size-samples', batch_size,
+                          args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key, loss_dict[key], iteration)
+            writer.add_scalar(key + '-iterations', loss_dict[key], iteration)
+            writer.add_scalar(key + '-samples', loss_dict[key],
+                              args.consumed_train_samples)
         if args.fp16:
-            writer.add_scalar('loss_scale', loss_scale, iteration)
+            writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
+            writer.add_scalar('loss_scale-samples', loss_scale,
+                              args.consumed_train_samples)
     normalizer = iteration % args.log_interval
     if normalizer == 0:
         normalizer = args.log_interval
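Each scalar is now written twice, once against the iteration counter and once against consumed training samples, so runs with batch-size ramp-up can be compared on a samples axis where iteration counts are not comparable. A minimal sketch using torch.utils.tensorboard; the log directory and the made-up schedule are assumptions for the example:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/example')        # illustrative log dir
    consumed_train_samples = 0
    for iteration in range(1, 11):
        batch_size = 32 + 32 * (iteration // 3)   # made-up ramp-up schedule
        consumed_train_samples += batch_size
        loss = 10.0 / iteration                   # made-up loss value
        # Same scalar, two x-axes: iterations and consumed samples.
        writer.add_scalar('lm-loss-iterations', loss, iteration)
        writer.add_scalar('lm-loss-samples', loss, consumed_train_samples)
    writer.close()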
@@ -703,6 +748,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                           elapsed_time / args.log_interval, iteration)
         log_string = ' iteration {:8d}/{:8d} |'.format(
             iteration, args.train_iters)
+        log_string += ' consumed samples {:12d} |'.format(
+            args.consumed_train_samples)
         log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
             elapsed_time * 1000.0 / args.log_interval)
         log_string += ' learning rate: {:.3E} |'.format(learning_rate)
@@ -732,6 +779,18 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     return report_memory_flag


+def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
+    timers = get_timers()
+    # Extra barrier is added to make sure
+    # all ranks report the max time.
+    torch.distributed.barrier()
+    timers('save checkpoint').start()
+    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    torch.distributed.barrier()
+    timers('save checkpoint').stop()
+    timers.log(['save checkpoint'])
+
+
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, valid_data_iterator):
     """Train the model function."""
@@ -748,6 +807,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     iteration = args.iteration
     timers('interval time').start()
+    print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
         update_num_microbatches(args.consumed_train_samples)
@@ -777,9 +837,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                        lr_scheduler)

         # Checkpointing
+        saved_checkpoint = False
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint_and_time(iteration, model, optimizer,
+                                     lr_scheduler)
+            saved_checkpoint = True

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -789,14 +853,31 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                        valid_data_iterator, model,
                                        iteration, False)

+        # Exiting based on duration
+        if args.exit_duration_in_mins:
+            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
+            done_cuda = torch.cuda.IntTensor(
+                [train_time > args.exit_duration_in_mins])
+            torch.distributed.all_reduce(
+                done_cuda, op=torch.distributed.ReduceOp.MAX)
+            done = done_cuda.item()
+            if done:
+                if not saved_checkpoint:
+                    save_checkpoint_and_time(iteration, model, optimizer,
+                                             lr_scheduler)
+                print_datetime('exiting program after {} minutes'.format(
+                    train_time))
+                sys.exit()
+
+        # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
+            if not saved_checkpoint:
+                save_checkpoint_and_time(iteration, model, optimizer,
+                                         lr_scheduler)
             torch.distributed.barrier()
-            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-            rank = torch.distributed.get_rank()
-            print_rank_0('rank: {} | time: {} | exiting the program at '
-                         'iteration {}'.format(rank, time_str, iteration))
+            print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()

     return iteration
...
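The duration-based exit turns each rank's local decision into an IntTensor and MAX-all-reduces it, so either every rank saves (if needed) and exits together or none does; clock skew between ranks cannot leave some ranks stuck in later collectives. A minimal single-process sketch of the elapsed-minutes check itself; should_exit is an illustrative helper, not part of the diff, and the all-reduce step is omitted:

    import time

    def should_exit(train_start_time, exit_duration_in_mins):
        # Mirrors the check in train(): minutes elapsed since the recorded
        # start time compared against --exit-duration-in-mins.
        if exit_duration_in_mins is None:
            return False
        train_time = (time.time() - train_start_time) / 60.0
        return train_time > exit_duration_in_mins

    # Example: a run that started 90 minutes ago with a 60-minute budget.
    started = time.time() - 90 * 60
    print(should_exit(started, exit_duration_in_mins=60))    # True
    print(should_exit(started, exit_duration_in_mins=None))  # False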