Commit 9321d5c6 authored by Jared Casper, committed by Deepak Narayanan

Change lr-warmup-percent to lr-warmup-fraction

parent 0c151638
@@ -137,9 +137,9 @@ def parse_args(extra_args_provider=None, defaults={},
             'expected iteration-based learning rate warmup'
         assert args.rampup_batch_size is None, \
             'expected no batch-size rampup for iteration-based training'
-        if args.lr_warmup_percent is not None:
+        if args.lr_warmup_fraction is not None:
             assert args.lr_warmup_iters == 0, \
-                'can only specify one of lr-warmup-percent and lr-warmup-iters'
+                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

     # Sample-based training.
     if args.train_samples:
@@ -151,9 +151,9 @@ def parse_args(extra_args_provider=None, defaults={},
             'expected sample-based learning rate decay'
         assert args.lr_warmup_iters == 0, \
             'expected sample-based learnig rate warmup'
-        if args.lr_warmup_percent is not None:
+        if args.lr_warmup_fraction is not None:
             assert args.lr_warmup_samples == 0, \
-                'can only specify one of lr-warmup-percent and lr-warmup-samples'
+                'can only specify one of lr-warmup-fraction and lr-warmup-samples'

     # Check required arguments.
     required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
@@ -359,9 +359,9 @@ def _add_learning_rate_args(parser):
     group.add_argument('--lr-decay-samples', type=int, default=None,
                        help='number of samples to decay learning rate over,'
                        ' If None defaults to `--train-samples`')
-    group.add_argument('--lr-warmup-percent', type=float, default=None,
-                       help='percentage of lr-warmup-(iters/samples) to use '
-                       'for warmup')
+    group.add_argument('--lr-warmup-fraction', type=float, default=None,
+                       help='fraction of lr-warmup-(iters/samples) to use '
+                       'for warmup (as a float)')
     group.add_argument('--lr-warmup-iters', type=int, default=0,
                        help='number of iterations to linearly warmup '
                        'learning rate over.')
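For reference, the sketch below shows how the renamed option is intended to be used and validated. It is a minimal, self-contained illustration that reuses the attribute names from the diff (lr_warmup_fraction, lr_warmup_iters, lr_warmup_samples); the build_parser and validate helpers are hypothetical, not the Megatron-LM argument parser itself.

import argparse

def build_parser():
    # Only the options relevant to warmup are reproduced here.
    parser = argparse.ArgumentParser(description='warmup-option sketch')
    parser.add_argument('--train-iters', type=int, default=None)
    parser.add_argument('--train-samples', type=int, default=None)
    parser.add_argument('--lr-warmup-fraction', type=float, default=None,
                        help='fraction of the decay horizon to use for warmup')
    parser.add_argument('--lr-warmup-iters', type=int, default=0)
    parser.add_argument('--lr-warmup-samples', type=int, default=0)
    return parser

def validate(args):
    # The fraction-based option excludes an explicit warmup length,
    # mirroring the assertions in the diff above.
    if args.lr_warmup_fraction is not None:
        if args.train_iters:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'
        if args.train_samples:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-samples'
    return args

args = validate(build_parser().parse_args(
    ['--train-iters', '500000', '--lr-warmup-fraction', '0.01']))
print(args.lr_warmup_fraction)  # 0.01

Passing --train-iters 500000 --lr-warmup-fraction 0.01 succeeds; adding --lr-warmup-iters 2000 on top would trip the assertion, which is exactly the rule the diff enforces.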
@@ -224,8 +224,8 @@ def get_learning_rate_scheduler(optimizer):
         if args.lr_decay_iters is None:
             args.lr_decay_iters = args.train_iters
         decay_steps = args.lr_decay_iters * args.global_batch_size
-        if args.lr_warmup_percent is not None:
-            warmup_steps = args.lr_warmup_percent * decay_steps
+        if args.lr_warmup_fraction is not None:
+            warmup_steps = args.lr_warmup_fraction * decay_steps
         else:
             warmup_steps = args.lr_warmup_iters * args.global_batch_size
     # Sample-based training.
@@ -237,8 +237,8 @@ def get_learning_rate_scheduler(optimizer):
         if args.lr_decay_samples is None:
             args.lr_decay_samples = args.train_samples
         decay_steps = args.lr_decay_samples
-        if args.lr_warmup_percent is not None:
-            warmup_steps = args.lr_warmup_percent * decay_steps
+        if args.lr_warmup_fraction is not None:
+            warmup_steps = args.lr_warmup_fraction * decay_steps
         else:
             warmup_steps = args.lr_warmup_samples
     else:
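The scheduler side of the rename is a straight substitution of the attribute name, but the surrounding arithmetic is worth spelling out: with the fraction-based option the warmup length is derived from the decay horizon, otherwise it comes from an explicit iteration or sample count. Below is a minimal sketch of that logic, assuming the same attribute names as the diff (global_batch_size, lr_decay_iters, lr_decay_samples, and so on) and using made-up values; it is not the get_learning_rate_scheduler function itself.

from types import SimpleNamespace

def warmup_and_decay_steps(args):
    if args.train_iters:
        # Iteration-based training: convert iteration counts to sample counts.
        lr_decay_iters = args.lr_decay_iters or args.train_iters
        decay_steps = lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    else:
        # Sample-based training: counts are already in samples.
        lr_decay_samples = args.lr_decay_samples or args.train_samples
        decay_steps = lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    return warmup_steps, decay_steps

# Example: 500k iterations at global batch size 1024, warming up over
# 1% of the decay horizon.
args = SimpleNamespace(train_iters=500000, train_samples=None,
                       lr_decay_iters=None, lr_decay_samples=None,
                       global_batch_size=1024, lr_warmup_fraction=0.01,
                       lr_warmup_iters=0, lr_warmup_samples=0)
warmup, decay = warmup_and_decay_steps(args)
print(warmup, decay)  # -> 5120000.0 512000000

In other words, --lr-warmup-fraction keeps the warmup proportional to the decay horizon (here 1% of 512M samples), whereas --lr-warmup-iters or --lr-warmup-samples pins it to an absolute count.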