Commit 22ab91bb authored by mohammad, committed by Deepak Narayanan

Sample based learning rate computation

parent 6a68502d
......@@ -125,6 +125,30 @@ def parse_args(extra_args_provider=None, defaults={},
else:
setattr(args, key, defaults[key])
# Iteration-based training.
if args.train_iters:
# If we use iteration-based training, make sure the
# sample-based options are off.
assert args.train_samples is None, \
'expected iteration-based training'
assert args.lr_decay_samples is None, \
'expected iteration-based learning rate decay'
assert args.lr_warmup_samples == 0, \
'expected iteration-based learning rate warmup'
assert args.rampup_batch_size is None, \
'expected no batch-size rampup for iteration-based training'
# Sample-based training.
if args.train_samples:
# If we use sample-based training, make sure the
# iteration-based options are off.
assert args.train_iters is None, \
'expected sample-based training'
assert args.lr_decay_iters is None, \
'expected sample-based learning rate decay'
assert args.lr_warmup_iters == 0, \
'expected sample-based learning rate warmup'
# Check required arguments.
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_position_embeddings']
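To make the split between the two modes concrete, here is a minimal sketch (not part of the diff) of how the new flags are meant to be combined; the helper name check_training_mode and all values below are illustrative assumptions.

# Illustrative helper (not in the codebase): mirrors the new argument checks.
from argparse import Namespace

def check_training_mode(args):
    if args.train_iters is not None:
        # Iteration-based run: sample-based knobs must stay at their defaults.
        assert args.train_samples is None and args.lr_decay_samples is None
        assert args.lr_warmup_samples == 0 and args.rampup_batch_size is None
    elif args.train_samples is not None:
        # Sample-based run: iteration-based knobs must stay at their defaults.
        assert args.lr_decay_iters is None and args.lr_warmup_iters == 0

# A sample-based configuration (all values made up):
check_training_mode(Namespace(
    train_iters=None, lr_decay_iters=None, lr_warmup_iters=0,
    train_samples=300_000_000, lr_decay_samples=260_000_000,
    lr_warmup_samples=3_000_000, rampup_batch_size=['32', '32', '2000000']))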
......@@ -269,7 +293,12 @@ def _add_training_args(parser):
help='chunk size (number of layers) for checkpointing.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs.')
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--train-samples', type=int, default=None,
help='Total number of samples to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--exit-interval', type=int, default=None,
......@@ -319,12 +348,18 @@ def _add_learning_rate_args(parser):
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay learning rate over.'
' If None, defaults to `--train-iters`.')
group.add_argument('--lr-decay-samples', type=int, default=None,
help='number of samples to decay learning rate over.'
' If None, defaults to `--train-samples`.')
group.add_argument('--lr-warmup-iters', type=int, default=0,
help='number of iterations to linearly warmup '
'learning rate over.')
group.add_argument('--lr-warmup-samples', type=int, default=0,
help='number of samples to linearly warmup '
'learning rate over.')
group.add_argument('--min-lr', type=float, default=0.0,
help='Minimum value for learning rate. The scheduler '
'clips values below this threshold.')
group.add_argument('--warmup', type=float, default=0.01,
help='Percentage of total iterations to warmup on '
'(.01 = 1 percent of all training iters).')
group.add_argument('--override-lr-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate, '
'warmup iterations, minimum learning rate, maximum '
......
......@@ -106,20 +106,12 @@ def _build_num_microbatches_calculator(args):
# Constant num micro-batches.
if args.rampup_batch_size is None:
micro_batch_times_data_parallel = args.micro_batch_size * \
args.data_parallel_size
assert args.global_batch_size % micro_batch_times_data_parallel == 0, \
'global batch size ({}) is not divisible by micro batch size ({})' \
' times data parallel size ({})'.format(args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size)
num_micro_batches = args.global_batch_size // \
micro_batch_times_data_parallel
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
args.global_batch_size, args.micro_batch_size,
args.data_parallel_size)
if args.rank == 0:
print('setting number of micro-batches to constant {}'.format(
num_micro_batches), flush=True)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
num_micro_batches)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()), flush=True)
else:
assert len(args.rampup_batch_size) == 3, 'expected the following ' \
......@@ -143,10 +135,8 @@ def _build_num_microbatches_calculator(args):
class NumMicroBatchesCalculator(ABC):
def __init__(self, name):
self.name = name
def __init__(self):
self.num_micro_batches = None
super(NumMicroBatchesCalculator, self).__init__()
def get(self):
return self.num_micro_batches
......@@ -158,11 +148,17 @@ class NumMicroBatchesCalculator(ABC):
class ConstantNumMicroBatches(NumMicroBatchesCalculator):
def __init__(self, num_micro_batches=1):
super(ConstantNumMicroBatches, self).__init__(
'constant: {}'.format(num_micro_batches))
assert num_micro_batches >= 1
self.num_micro_batches = num_micro_batches
def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
micro_batch_times_data_parallel = micro_batch_size * \
data_parallel_size
assert global_batch_size % micro_batch_times_data_parallel == 0, \
'global batch size ({}) is not divisible by micro batch size ({})' \
' times data parallel size ({})'.format(global_batch_size,
micro_batch_size,
data_parallel_size)
self.num_micro_batches = global_batch_size // \
micro_batch_times_data_parallel
assert self.num_micro_batches >= 1
def update(self, consumed_samples):
pass
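A quick worked example of the divisibility check and arithmetic that now live inside ConstantNumMicroBatches (all numbers are assumed, not taken from the commit):

# Assumed configuration, for illustration only.
global_batch_size, micro_batch_size, data_parallel_size = 1024, 4, 32
micro_batch_times_data_parallel = micro_batch_size * data_parallel_size   # 128
assert global_batch_size % micro_batch_times_data_parallel == 0
num_micro_batches = global_batch_size // micro_batch_times_data_parallel  # 8
print(num_micro_batches)  # 8 micro-batches per iteration on each data-parallel rank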
......@@ -188,10 +184,6 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
data_parallel_size: data parallel size.
"""
super(RampupBatchsizeNumMicroBatches, self).__init__(
'batch size ramup: {}, {}, {}'.format(
start_batch_size, batch_size_increment, ramup_samples))
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
......@@ -212,8 +204,9 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
'size increment ({})'.format(diff_batch_size, batch_size_increment)
num_increments = diff_batch_size // self.batch_size_increment
assert ramup_samples >= 0
self.rampup_samples_per_increment = ramup_samples / num_increments
self.ramup_samples = ramup_samples
assert self.ramup_samples >= 0
self.rampup_samples_per_increment = self.ramup_samples / num_increments
# Initialize number of microbatches.
self.update(0)
......@@ -221,11 +214,13 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
def update(self, consumed_samples):
steps = int(consumed_samples / self.rampup_samples_per_increment)
current_global_batch_size = self.start_batch_size + \
steps * self.batch_size_increment
current_global_batch_size = min(current_global_batch_size,
self.global_batch_size)
if consumed_samples > self.ramup_samples:
current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
current_global_batch_size = self.start_batch_size + \
steps * self.batch_size_increment
assert current_global_batch_size <= self.global_batch_size
assert current_global_batch_size % \
self.micro_batch_times_data_parallel_size == 0, 'current global ' \
......
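A standalone sketch of the updated ramp-up logic, assuming a start batch size of 32, increments of 32, one million ramp-up samples, and a target global batch size of 1024 (all numbers hypothetical):

# Hypothetical rampup configuration.
start_batch_size, batch_size_increment = 32, 32
global_batch_size, ramup_samples = 1024, 1_000_000
num_increments = (global_batch_size - start_batch_size) // batch_size_increment  # 31
rampup_samples_per_increment = ramup_samples / num_increments

def current_batch_size(consumed_samples):
    # Mirrors the new update(): cap at the target once the ramp is exhausted,
    # otherwise grow in fixed increments.
    if consumed_samples > ramup_samples:
        return global_batch_size
    steps = int(consumed_samples / rampup_samples_per_increment)
    return start_batch_size + steps * batch_size_increment

print(current_batch_size(0))          # 32
print(current_batch_size(500_000))    # 512, part-way through the ramp
print(current_batch_size(2_000_000))  # 1024, past the ramp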
......@@ -23,8 +23,7 @@ class AnnealingLR(object):
"""Anneals the learning rate."""
def __init__(self, optimizer, max_lr, min_lr,
warmup_steps, decay_steps,
decay_style, num_steps,
warmup_steps, decay_steps, decay_style,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
......@@ -37,7 +36,7 @@ class AnnealingLR(object):
assert self.max_lr >= self.min_lr
self.warmup_steps = warmup_steps
self.num_steps = num_steps
self.num_steps = 0
self.decay_steps = decay_steps
assert self.decay_steps > 0
assert self.warmup_steps < self.decay_steps
......@@ -51,7 +50,7 @@ class AnnealingLR(object):
'use-checkpoint are set.'
# Set the learning rate
self.step(step_num=self.num_steps)
self.step(0)
print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
......@@ -92,11 +91,9 @@ class AnnealingLR(object):
return self.min_lr + coeff * delta_lr
def step(self, increment=1, step_num=None):
def step(self, increment):
"""Set lr for all parameters groups."""
if step_num is None:
step_num = self.num_steps + increment
self.num_steps = step_num
self.num_steps += increment
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
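The practical effect of the new step() signature is that the scheduler accumulates whatever increment it is handed (sample counts in this commit) rather than being told an absolute step number. A minimal mock of that contract, assuming a global batch size of 1024:

# Minimal mock (not the real class) of the accumulate-only contract.
class TinyScheduler:
    def __init__(self):
        self.num_steps = 0
    def step(self, increment):
        self.num_steps += increment

sched = TinyScheduler()
for _ in range(3):
    sched.step(increment=1024)  # one iteration at an assumed global batch size of 1024
print(sched.num_steps)          # 3072 samples consumed so far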
......@@ -160,7 +157,7 @@ class AnnealingLR(object):
'decay style')
if 'num_iters' in sd:
self.num_steps = sd['num_iters']
num_steps = sd['num_iters']
else:
self.num_steps = sd['num_steps']
self.step(step_num=self.num_steps)
num_steps = sd['num_steps']
self.step(increment=num_steps)
......@@ -116,6 +116,37 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
test_data_iterator, model,
0, True)
def update_train_iters(args):
# For iteration-based training, we don't need to do anything
if args.train_iters:
return
# Constant batch size with sample-based training.
if args.rampup_batch_size is None:
args.train_iters = args.train_samples // args.global_batch_size
else:
# Sample based training with rampup batch size.
iterations = 0
consumed_samples = 0
# Rampup phase.
while consumed_samples <= int(args.rampup_batch_size[2]):
update_num_microbatches(consumed_samples)
consumed_samples += get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
iterations += 1
# Reset
update_num_microbatches(0)
# Constant phase
# Note that we throw away any partial last batch.
iterations += (args.train_samples - consumed_samples) // \
args.global_batch_size
args.train_iters = iterations
print_rank_0('setting training iterations to {}'.format(args.train_iters))
def get_model(model_provider_func):
"""Build the model."""
......@@ -188,22 +219,33 @@ def get_learning_rate_scheduler(optimizer):
"""Build the learning rate scheduler."""
args = get_args()
# Add linear learning rate scheduler.
if args.lr_decay_iters is not None:
num_iters = args.lr_decay_iters
# Iteration-based training.
if args.train_iters:
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
warmup_steps = args.lr_warmup_iters * args.global_batch_size
decay_steps = args.lr_decay_iters * args.global_batch_size
# Sample-based training.
elif args.train_samples:
# We need to set training iters for later use. Technically
# we need to adjust the training samples too (due to last
# batch being incomplete) but we leave it as is for now.
update_train_iters(args)
if args.lr_decay_samples is None:
args.lr_decay_samples = args.train_samples
warmup_steps = args.lr_warmup_samples
decay_steps = args.lr_decay_samples
else:
num_iters = args.train_iters
num_iters = max(1, num_iters)
init_step = 0
warmup_iter = args.warmup * num_iters
raise Exception(
'either train-iters or train-samples should be provided.')
lr_scheduler = AnnealingLR(
optimizer,
max_lr=args.lr,
min_lr=args.min_lr,
warmup_steps=warmup_iter,
decay_steps=num_iters,
warmup_steps=warmup_steps,
decay_steps=decay_steps,
decay_style=args.lr_decay_style,
num_steps=init_step,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
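Both branches now hand AnnealingLR sample counts: iteration-based settings are scaled by the global batch size, while sample-based settings pass through unchanged. An illustrative check with assumed values:

# Assumed values throughout.
global_batch_size = 1024

# Iteration-based run: convert iteration counts to samples.
lr_warmup_iters, lr_decay_iters = 2_000, 300_000
warmup_steps = lr_warmup_iters * global_batch_size  # 2_048_000 samples
decay_steps = lr_decay_iters * global_batch_size    # 307_200_000 samples

# Sample-based run: the flags are already expressed in samples.
lr_warmup_samples, lr_decay_samples = 2_048_000, 307_200_000
assert (warmup_steps, decay_steps) == (lr_warmup_samples, lr_decay_samples)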
......@@ -568,7 +610,10 @@ def train_step(forward_step_func, data_iterator,
# Update learning rate.
skipped_iter = 0
if not (args.fp16 and optimizer.overflow):
lr_scheduler.step()
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
lr_scheduler.step(increment=increment)
else:
skipped_iter = 1
......@@ -649,8 +694,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string = ' iteration {:8d}/{:8d} |'.format(
iteration, args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
......@@ -837,8 +882,12 @@ def build_train_valid_test_data_iterators(
# Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0:
assert args.train_samples is None, \
'only backward compatibility support for iteration-based training'
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
assert args.train_samples is None, \
'only backward compatibility support for iteration-based training'
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
......@@ -846,10 +895,14 @@ def build_train_valid_test_data_iterators(
if mpu.get_tensor_model_parallel_rank() == 0:
# Number of train/valid/test samples.
train_iters = args.train_iters
eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
if args.train_samples:
train_samples = args.train_samples
else:
train_samples = args.train_iters * args.global_batch_size
eval_iters = (args.train_iters // args.eval_interval + 1) * \
args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_iters * args.global_batch_size,
train_val_test_num_samples = [train_samples,
eval_iters * args.global_batch_size,
test_iters * args.global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
......
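Finally, a short illustration (assumed values) of how the dataset targets are sized after this change, whether the run is specified in samples or in iterations:

# Assumed values, for illustration only.
train_iters, global_batch_size = 292_968, 1024
eval_interval, eval_iters = 1_000, 10

train_samples = train_iters * global_batch_size  # or args.train_samples directly
eval_iters_total = (train_iters // eval_interval + 1) * eval_iters
test_iters = eval_iters

train_val_test_num_samples = [train_samples,
                              eval_iters_total * global_batch_size,
                              test_iters * global_batch_size]
print(train_val_test_num_samples)  # [299999232, 3000320, 10240]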