Commit 22ab91bb authored by mohammad, committed by Deepak Narayanan

Sample-based learning rate computation

parent 6a68502d
@@ -125,6 +125,30 @@ def parse_args(extra_args_provider=None, defaults={},
         else:
             setattr(args, key, defaults[key])
 
+    # Iteration-based training.
+    if args.train_iters:
+        # If we use iteration-based training, make sure the
+        # sample-based options are off.
+        assert args.train_samples is None, \
+            'expected iteration-based training'
+        assert args.lr_decay_samples is None, \
+            'expected iteration-based learning rate decay'
+        assert args.lr_warmup_samples == 0, \
+            'expected iteration-based learning rate warmup'
+        assert args.rampup_batch_size is None, \
+            'expected no batch-size rampup for iteration-based training'
+
+    # Sample-based training.
+    if args.train_samples:
+        # If we use sample-based training, make sure the
+        # iteration-based options are off.
+        assert args.train_iters is None, \
+            'expected sample-based training'
+        assert args.lr_decay_iters is None, \
+            'expected sample-based learning rate decay'
+        assert args.lr_warmup_iters == 0, \
+            'expected sample-based learning rate warmup'
+
     # Check required arguments.
     required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                      'max_position_embeddings']
@@ -269,7 +293,12 @@ def _add_training_args(parser):
                        help='chunk size (number of layers) for checkpointing.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
-                       'training runs.')
+                       'training runs. Note that either train-iters or '
+                       'train-samples should be provided.')
+    group.add_argument('--train-samples', type=int, default=None,
+                       help='Total number of samples to train over all '
+                       'training runs. Note that either train-iters or '
+                       'train-samples should be provided.')
     group.add_argument('--log-interval', type=int, default=100,
                        help='Report loss and timing interval.')
     group.add_argument('--exit-interval', type=int, default=None,
@@ -319,12 +348,18 @@ def _add_learning_rate_args(parser):
     group.add_argument('--lr-decay-iters', type=int, default=None,
                        help='number of iterations to decay learning rate over,'
                        ' If None defaults to `--train-iters`')
+    group.add_argument('--lr-decay-samples', type=int, default=None,
+                       help='number of samples to decay learning rate over,'
+                       ' If None defaults to `--train-samples`')
+    group.add_argument('--lr-warmup-iters', type=int, default=0,
+                       help='number of iterations to linearly warmup '
+                       'learning rate over.')
+    group.add_argument('--lr-warmup-samples', type=int, default=0,
+                       help='number of samples to linearly warmup '
+                       'learning rate over.')
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minimum value for learning rate. The scheduler '
                        'clips values below this threshold.')
-    group.add_argument('--warmup', type=float, default=0.01,
-                       help='Percentage of total iterations to warmup on '
-                       '(.01 = 1 percent of all training iters).')
     group.add_argument('--override-lr-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
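The new arguments make `--train-iters` and `--train-samples` mutually exclusive, and tie the sample-valued knobs (`--lr-decay-samples`, `--lr-warmup-samples`, batch-size rampup) to sample-based runs. A minimal sketch of the two valid flag combinations, assuming illustrative values and a hypothetical `check_training_mode` helper that mirrors the asserts in `parse_args` above (it is not part of Megatron):

```python
from types import SimpleNamespace

def check_training_mode(args):
    # Mirrors the parse_args() checks: exactly one of the two modes is active.
    if args.train_iters is not None:
        assert args.train_samples is None and args.lr_decay_samples is None \
            and args.lr_warmup_samples == 0 and args.rampup_batch_size is None
        return 'iteration-based'
    if args.train_samples is not None:
        assert args.lr_decay_iters is None and args.lr_warmup_iters == 0
        return 'sample-based'
    raise ValueError('either train-iters or train-samples must be provided')

# Iteration-based run: only iteration-valued schedule knobs are set.
iter_args = SimpleNamespace(train_iters=500000, train_samples=None,
                            lr_decay_iters=320000, lr_warmup_iters=3000,
                            lr_decay_samples=None, lr_warmup_samples=0,
                            rampup_batch_size=None)
# Sample-based run: only sample-valued knobs are set; rampup is allowed here.
sample_args = SimpleNamespace(train_iters=None, train_samples=300000000,
                              lr_decay_iters=None, lr_warmup_iters=0,
                              lr_decay_samples=260000000,
                              lr_warmup_samples=3000000,
                              rampup_batch_size=['16', '16', '5000000'])
print(check_training_mode(iter_args))    # iteration-based
print(check_training_mode(sample_args))  # sample-based
```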
@@ -106,20 +106,12 @@ def _build_num_microbatches_calculator(args):
     # Constant num micro-batches.
     if args.rampup_batch_size is None:
-        micro_batch_times_data_parallel = args.micro_batch_size * \
-            args.data_parallel_size
-        assert args.global_batch_size % micro_batch_times_data_parallel == 0, \
-            'global batch size ({}) is not divisible by micro batch size ({})' \
-            ' times data parallel size ({})'.format(args.global_batch_size,
-                                                    args.micro_batch_size,
-                                                    args.data_parallel_size)
-        num_micro_batches = args.global_batch_size // \
-            micro_batch_times_data_parallel
+        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
+            args.global_batch_size, args.micro_batch_size,
+            args.data_parallel_size)
         if args.rank == 0:
             print('setting number of micro-batches to constant {}'.format(
-                num_micro_batches), flush=True)
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR = ConstantNumMicroBatches(
-            num_micro_batches)
+                _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()), flush=True)
     else:
         assert len(args.rampup_batch_size) == 3, 'expected the following ' \
@@ -143,10 +135,8 @@ def _build_num_microbatches_calculator(args):
 class NumMicroBatchesCalculator(ABC):
 
-    def __init__(self, name):
-        self.name = name
+    def __init__(self):
         self.num_micro_batches = None
-        super(NumMicroBatchesCalculator, self).__init__()
 
     def get(self):
         return self.num_micro_batches
@@ -158,11 +148,17 @@ class NumMicroBatchesCalculator(ABC):
 class ConstantNumMicroBatches(NumMicroBatchesCalculator):
 
-    def __init__(self, num_micro_batches=1):
-        super(ConstantNumMicroBatches, self).__init__(
-            'constant: {}'.format(num_micro_batches))
-        assert num_micro_batches >= 1
-        self.num_micro_batches = num_micro_batches
+    def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
+        micro_batch_times_data_parallel = micro_batch_size * \
+            data_parallel_size
+        assert global_batch_size % micro_batch_times_data_parallel == 0, \
+            'global batch size ({}) is not divisible by micro batch size ({})' \
+            ' times data parallel size ({})'.format(global_batch_size,
+                                                    micro_batch_size,
+                                                    data_parallel_size)
+        self.num_micro_batches = global_batch_size // \
+            micro_batch_times_data_parallel
+        assert self.num_micro_batches >= 1
 
     def update(self, consumed_samples):
         pass
@@ -188,10 +184,6 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
         data_parallel_size: data parallel size.
         """
 
-        super(RampupBatchsizeNumMicroBatches, self).__init__(
-            'batch size ramup: {}, {}, {}'.format(
-                start_batch_size, batch_size_increment, ramup_samples))
-
         self.micro_batch_size = micro_batch_size
         self.data_parallel_size = data_parallel_size
         self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
@@ -212,8 +204,9 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
             'size increment ({})'.format(diff_batch_size, batch_size_increment)
         num_increments = diff_batch_size // self.batch_size_increment
-        assert ramup_samples >= 0
-        self.rampup_samples_per_increment = ramup_samples / num_increments
+        self.ramup_samples = ramup_samples
+        assert self.ramup_samples >= 0
+        self.rampup_samples_per_increment = self.ramup_samples / num_increments
 
         # Initialize number of microbatches.
         self.update(0)
@@ -221,11 +214,13 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
     def update(self, consumed_samples):
 
-        steps = int(consumed_samples / self.rampup_samples_per_increment)
-        current_global_batch_size = self.start_batch_size + \
-            steps * self.batch_size_increment
-        current_global_batch_size = min(current_global_batch_size,
-                                        self.global_batch_size)
+        if consumed_samples > self.ramup_samples:
+            current_global_batch_size = self.global_batch_size
+        else:
+            steps = int(consumed_samples / self.rampup_samples_per_increment)
+            current_global_batch_size = self.start_batch_size + \
+                steps * self.batch_size_increment
+            assert current_global_batch_size <= self.global_batch_size
 
         assert current_global_batch_size % \
             self.micro_batch_times_data_parallel_size == 0, 'current global ' \
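After the refactor, `ConstantNumMicroBatches` derives the number of micro-batches from the global batch size, micro batch size, and data-parallel size itself, and the rampup calculator pins the batch size to its final value once the rampup sample budget is exhausted. A self-contained sketch of both computations, with assumed names and numbers rather than the Megatron API:

```python
def constant_num_microbatches(global_batch_size, micro_batch_size,
                              data_parallel_size):
    # Samples processed by one micro-batch step across all data-parallel ranks.
    per_micro_step = micro_batch_size * data_parallel_size
    assert global_batch_size % per_micro_step == 0
    return global_batch_size // per_micro_step      # e.g. 64 // (2 * 4) = 8

def rampup_global_batch_size(consumed_samples, start_batch_size,
                             batch_size_increment, rampup_samples,
                             global_batch_size):
    # Past the rampup window the batch size stays at its final value.
    if consumed_samples > rampup_samples:
        return global_batch_size
    num_increments = (global_batch_size - start_batch_size) // batch_size_increment
    samples_per_increment = rampup_samples / num_increments
    steps = int(consumed_samples / samples_per_increment)
    return start_batch_size + steps * batch_size_increment

print(constant_num_microbatches(64, 2, 4))           # 8 micro-batches
print(rampup_global_batch_size(20, 16, 16, 48, 64))  # 32 after 20 consumed samples
```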
@@ -23,8 +23,7 @@ class AnnealingLR(object):
     """Anneals the learning rate."""
 
     def __init__(self, optimizer, max_lr, min_lr,
-                 warmup_steps, decay_steps,
-                 decay_style, num_steps,
+                 warmup_steps, decay_steps, decay_style,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):
@@ -37,7 +36,7 @@ class AnnealingLR(object):
         assert self.max_lr >= self.min_lr
 
         self.warmup_steps = warmup_steps
-        self.num_steps = num_steps
+        self.num_steps = 0
         self.decay_steps = decay_steps
         assert self.decay_steps > 0
         assert self.warmup_steps < self.decay_steps
@@ -51,7 +50,7 @@ class AnnealingLR(object):
                 'use-checkpoint are set.'
 
         # Set the learning rate
-        self.step(step_num=self.num_steps)
+        self.step(0)
 
         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
@@ -92,11 +91,9 @@ class AnnealingLR(object):
         return self.min_lr + coeff * delta_lr
 
-    def step(self, increment=1, step_num=None):
+    def step(self, increment):
         """Set lr for all parameters groups."""
-        if step_num is None:
-            step_num = self.num_steps + increment
-        self.num_steps = step_num
+        self.num_steps += increment
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
@@ -160,7 +157,7 @@ class AnnealingLR(object):
                 'decay style')
 
         if 'num_iters' in sd:
-            self.num_steps = sd['num_iters']
+            num_steps = sd['num_iters']
         else:
-            self.num_steps = sd['num_steps']
-        self.step(step_num=self.num_steps)
+            num_steps = sd['num_steps']
+        self.step(increment=num_steps)
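With `num_steps` initialized to zero and `step(increment)` simply accumulating whatever increment the caller passes, the scheduler no longer assumes one unit per call: when the increment is the number of samples in the current global batch, `warmup_steps` and `decay_steps` are interpreted in samples. A toy sketch of that contract; the linear warmup/decay formula below is an assumption for illustration, only the accumulation logic mirrors the diff:

```python
class ToyAnnealingLR:
    def __init__(self, max_lr, min_lr, warmup_steps, decay_steps):
        self.max_lr, self.min_lr = max_lr, min_lr
        self.warmup_steps, self.decay_steps = warmup_steps, decay_steps
        self.num_steps = 0           # starts at zero, as in the diff
        self.step(0)                 # set the initial learning rate

    def get_lr(self):
        # Assumed linear warmup followed by linear decay to min_lr.
        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
            return self.max_lr * self.num_steps / self.warmup_steps
        frac = min(self.num_steps, self.decay_steps) / self.decay_steps
        return self.min_lr + (self.max_lr - self.min_lr) * (1.0 - frac)

    def step(self, increment):
        self.num_steps += increment  # increment = samples consumed this iteration
        self.lr = self.get_lr()

sched = ToyAnnealingLR(max_lr=1e-4, min_lr=1e-5,
                       warmup_steps=1000, decay_steps=10000)
sched.step(increment=256)            # one iteration at global batch size 256
print(sched.num_steps, sched.lr)     # 256 samples consumed, still in warmup
```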
@@ -116,6 +116,37 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    test_data_iterator, model,
                                    0, True)
 
+
+def update_train_iters(args):
+
+    # For iteration-based training, we don't need to do anything.
+    if args.train_iters:
+        return
+
+    # Constant batch size with sample-based training.
+    if args.rampup_batch_size is None:
+        args.train_iters = args.train_samples // args.global_batch_size
+
+    else:
+        # Sample-based training with rampup batch size.
+        iterations = 0
+        consumed_samples = 0
+        # Rampup phase.
+        while consumed_samples <= int(args.rampup_batch_size[2]):
+            update_num_microbatches(consumed_samples)
+            consumed_samples += get_num_microbatches() * \
+                                args.micro_batch_size * \
+                                args.data_parallel_size
+            iterations += 1
+        # Reset.
+        update_num_microbatches(0)
+        # Constant phase.
+        # Note that we throw away any partial last batch.
+        iterations += (args.train_samples - consumed_samples) // \
+                      args.global_batch_size
+        args.train_iters = iterations
+
+    print_rank_0('setting training iterations to {}'.format(args.train_iters))
+
+
 def get_model(model_provider_func):
     """Build the model."""
@@ -188,22 +219,33 @@ def get_learning_rate_scheduler(optimizer):
     """Build the learning rate scheduler."""
     args = get_args()
 
-    # Add linear learning rate scheduler.
-    if args.lr_decay_iters is not None:
-        num_iters = args.lr_decay_iters
+    # Iteration-based training.
+    if args.train_iters:
+        if args.lr_decay_iters is None:
+            args.lr_decay_iters = args.train_iters
+        warmup_steps = args.lr_warmup_iters * args.global_batch_size
+        decay_steps = args.lr_decay_iters * args.global_batch_size
+    # Sample-based training.
+    elif args.train_samples:
+        # We need to set training iters for later use. Technically
+        # we need to adjust the training samples too (due to last
+        # batch being incomplete) but we leave it as is for now.
+        update_train_iters(args)
+        if args.lr_decay_samples is None:
+            args.lr_decay_samples = args.train_samples
+        warmup_steps = args.lr_warmup_samples
+        decay_steps = args.lr_decay_samples
     else:
-        num_iters = args.train_iters
-    num_iters = max(1, num_iters)
-    init_step = 0
-    warmup_iter = args.warmup * num_iters
+        raise Exception(
+            'either train-iters or train-samples should be provided.')
 
     lr_scheduler = AnnealingLR(
         optimizer,
         max_lr=args.lr,
         min_lr=args.min_lr,
-        warmup_steps=warmup_iter,
-        decay_steps=num_iters,
+        warmup_steps=warmup_steps,
+        decay_steps=decay_steps,
         decay_style=args.lr_decay_style,
-        num_steps=init_step,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)
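For iteration-based runs the warmup and decay lengths are converted from iterations to samples by multiplying by the (constant) global batch size; sample-based runs pass their sample counts through unchanged. In both cases the scheduler is later advanced by the per-iteration sample count (see `train_step` below), so the units line up. Illustrative arithmetic with assumed values:

```python
global_batch_size = 256

# Iteration-based: convert iteration counts to samples.
lr_warmup_iters, lr_decay_iters = 2000, 300000
warmup_steps = lr_warmup_iters * global_batch_size   # 512000 samples
decay_steps = lr_decay_iters * global_batch_size     # 76800000 samples

# Sample-based: the sample counts are used as-is.
lr_warmup_samples, lr_decay_samples = 512000, 76800000
warmup_steps, decay_steps = lr_warmup_samples, lr_decay_samples

# Each training iteration then advances the scheduler by one global batch:
num_microbatches, micro_batch_size, data_parallel_size = 8, 4, 8
increment = num_microbatches * micro_batch_size * data_parallel_size  # 256
print(warmup_steps, decay_steps, increment)
```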
@@ -568,7 +610,10 @@ def train_step(forward_step_func, data_iterator,
     # Update learning rate.
     skipped_iter = 0
     if not (args.fp16 and optimizer.overflow):
-        lr_scheduler.step()
+        increment = get_num_microbatches() * \
+                    args.micro_batch_size * \
+                    args.data_parallel_size
+        lr_scheduler.step(increment=increment)
     else:
         skipped_iter = 1
@@ -649,8 +694,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     if writer and torch.distributed.get_rank() == 0:
         writer.add_scalar('iteration_time',
                           elapsed_time / args.log_interval, iteration)
-    log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
-                                                   args.train_iters)
+    log_string = ' iteration {:8d}/{:8d} |'.format(
+        iteration, args.train_iters)
     log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
         elapsed_time * 1000.0 / args.log_interval)
     log_string += ' learning rate: {:.3E} |'.format(learning_rate)
@@ -837,8 +882,12 @@ def build_train_valid_test_data_iterators(
     # Backward compatibility, assume fixed batch size.
     if args.iteration > 0 and args.consumed_train_samples == 0:
+        assert args.train_samples is None, \
+            'only backward compatibility support for iteration-based training'
         args.consumed_train_samples = args.iteration * args.global_batch_size
     if args.iteration > 0 and args.consumed_valid_samples == 0:
+        assert args.train_samples is None, \
+            'only backward compatibility support for iteration-based training'
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
             args.eval_iters * args.global_batch_size
@@ -846,10 +895,14 @@ def build_train_valid_test_data_iterators(
     if mpu.get_tensor_model_parallel_rank() == 0:
 
         # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+        if args.train_samples:
+            train_samples = args.train_samples
+        else:
+            train_samples = args.train_iters * args.global_batch_size
+        eval_iters = (args.train_iters // args.eval_interval + 1) * \
+                     args.eval_iters
         test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * args.global_batch_size,
+        train_val_test_num_samples = [train_samples,
                                       eval_iters * args.global_batch_size,
                                       test_iters * args.global_batch_size]
         print_rank_0(' > datasets target sizes (minimum size):')
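Finally, the dataset target sizes: the train target is taken directly from `--train-samples` when it is set and otherwise derived from `--train-iters`, while the validation and test targets are still computed from iteration counts (which is why `update_train_iters` must have filled in `args.train_iters` for sample-based runs). A small sketch with assumed numbers:

```python
global_batch_size, eval_interval, eval_iters = 256, 1000, 100

def target_sizes(train_iters, train_samples=None):
    # Train target: samples directly, or iterations times global batch size.
    train = train_samples if train_samples else train_iters * global_batch_size
    # Valid/test targets are still iteration-derived.
    total_eval_iters = (train_iters // eval_interval + 1) * eval_iters
    return [train, total_eval_iters * global_batch_size,
            eval_iters * global_batch_size]

print(target_sizes(train_iters=100000))                          # iteration-based
print(target_sizes(train_iters=100000, train_samples=25600000))  # sample-based
```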