Commit 9321d5c6 authored by Jared Casper, committed by Deepak Narayanan

Change lr-warmup-percent to lr-warmup-fraction

parent 0c151638
@@ -137,9 +137,9 @@ def parse_args(extra_args_provider=None, defaults={},
             'expected iteration-based learning rate warmup'
         assert args.rampup_batch_size is None, \
             'expected no batch-size rampup for iteration-based training'
-        if args.lr_warmup_percent is not None:
+        if args.lr_warmup_fraction is not None:
             assert args.lr_warmup_iters == 0, \
-                'can only specify one of lr-warmup-percent and lr-warmup-iters'
+                'can only specify one of lr-warmup-fraction and lr-warmup-iters'
 
     # Sample-based training.
     if args.train_samples:
@@ -151,9 +151,9 @@ def parse_args(extra_args_provider=None, defaults={},
             'expected sample-based learning rate decay'
         assert args.lr_warmup_iters == 0, \
             'expected sample-based learning rate warmup'
-        if args.lr_warmup_percent is not None:
+        if args.lr_warmup_fraction is not None:
             assert args.lr_warmup_samples == 0, \
-                'can only specify one of lr-warmup-percent and lr-warmup-samples'
+                'can only specify one of lr-warmup-fraction and lr-warmup-samples'
 
     # Check required arguments.
     required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
@@ -359,9 +359,9 @@ def _add_learning_rate_args(parser):
     group.add_argument('--lr-decay-samples', type=int, default=None,
                        help='number of samples to decay learning rate over,'
                        ' If None defaults to `--train-samples`')
-    group.add_argument('--lr-warmup-percent', type=float, default=None,
-                       help='percentage of lr-warmup-(iters/samples) to use '
-                       'for warmup')
+    group.add_argument('--lr-warmup-fraction', type=float, default=None,
+                       help='fraction of lr-warmup-(iters/samples) to use '
+                       'for warmup (as a float)')
     group.add_argument('--lr-warmup-iters', type=int, default=0,
                        help='number of iterations to linearly warmup '
                        'learning rate over.')
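As a quick illustration (not part of the commit), here is a minimal argparse sketch of how the renamed flag is expected to behave, using the same defaults as the diff above; the parser and group names are hypothetical stand-ins for Megatron-LM's argument setup:

# Minimal sketch, assuming the defaults shown above
# (lr_warmup_fraction=None, lr_warmup_iters=0). Parser and group names are
# hypothetical; only the two flag definitions mirror the commit.
import argparse

parser = argparse.ArgumentParser(description='warmup-flag sketch')
group = parser.add_argument_group(title='learning rate')
group.add_argument('--lr-warmup-fraction', type=float, default=None,
                   help='fraction of lr-warmup-(iters/samples) to use for warmup')
group.add_argument('--lr-warmup-iters', type=int, default=0,
                   help='number of iterations to linearly warmup learning rate over')

args = parser.parse_args(['--lr-warmup-fraction', '0.01'])

# The asserts added to parse_args above keep the two options mutually exclusive:
if args.lr_warmup_fraction is not None:
    assert args.lr_warmup_iters == 0, \
        'can only specify one of lr-warmup-fraction and lr-warmup-iters'

print(args.lr_warmup_fraction)  # 0.01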
@@ -224,8 +224,8 @@ def get_learning_rate_scheduler(optimizer):
         if args.lr_decay_iters is None:
             args.lr_decay_iters = args.train_iters
         decay_steps = args.lr_decay_iters * args.global_batch_size
-        if args.lr_warmup_percent is not None:
-            warmup_steps = args.lr_warmup_percent * decay_steps
+        if args.lr_warmup_fraction is not None:
+            warmup_steps = args.lr_warmup_fraction * decay_steps
         else:
             warmup_steps = args.lr_warmup_iters * args.global_batch_size
     # Sample-based training.
@@ -237,8 +237,8 @@ def get_learning_rate_scheduler(optimizer):
         if args.lr_decay_samples is None:
             args.lr_decay_samples = args.train_samples
         decay_steps = args.lr_decay_samples
-        if args.lr_warmup_percent is not None:
-            warmup_steps = args.lr_warmup_percent * decay_steps
+        if args.lr_warmup_fraction is not None:
+            warmup_steps = args.lr_warmup_fraction * decay_steps
         else:
             warmup_steps = args.lr_warmup_samples
     else:
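The scheduler change itself is mechanical, but the arithmetic it implements may be worth spelling out; the new help text ("fraction ... as a float") suggests the rename is meant to make clear that a value like 0.01 is a fraction (1% of the decay horizon), not a percentage. Below is a hedged sketch of the iteration-based branch above; every concrete number (decay iterations, batch size, the 0.01 fraction) is illustrative rather than taken from the commit:

# Sketch of the warmup arithmetic from the iteration-based branch above.
# All numbers are hypothetical; only the formulas come from the diff.
lr_decay_iters = 300_000      # falls back to train_iters when left unset
global_batch_size = 1024
lr_warmup_fraction = 0.01     # the renamed flag
lr_warmup_iters = 0           # default; only used when the fraction is None

decay_steps = lr_decay_iters * global_batch_size        # 307,200,000 samples
if lr_warmup_fraction is not None:
    # warmup is expressed as a fraction of the full decay horizon
    warmup_steps = lr_warmup_fraction * decay_steps      # 3,072,000 samples
else:
    # otherwise it is an explicit iteration count, converted to samples
    warmup_steps = lr_warmup_iters * global_batch_size

# In the sample-based branch, decay_steps is just lr_decay_samples and the
# else-branch uses lr_warmup_samples directly.
print(decay_steps, warmup_steps)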