Commit 22de957a authored by A. Unique TensorFlower

Add end learning rate (end_lr) FLAG for linear learning rate decay.

PiperOrigin-RevId: 305948522
parent 9b6507fe
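The linear decay changed below is implemented with tf.keras.optimizers.schedules.PolynomialDecay at its default power=1.0, so the new flag simply sets the floor the schedule decays to. A minimal sketch of that formula (the helper name and values are illustrative, not part of this commit):

def linear_decay_lr(step, init_lr=5e-5, end_lr=0.0, num_train_steps=10000):
  """Illustrative linear decay, i.e. what PolynomialDecay computes with power=1.0."""
  step = min(step, num_train_steps)  # the schedule holds at end_lr past decay_steps
  return end_lr + (init_lr - end_lr) * (1.0 - step / num_train_steps)

# end_lr=0.0 (the flag's default) reproduces the previous hard-coded behaviour;
# a non-zero end_lr keeps the learning rate from decaying all the way to zero.
assert linear_decay_lr(0) == 5e-5
assert linear_decay_lr(10000, end_lr=1e-6) == 1e-6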
@@ -63,6 +63,10 @@ def define_common_bert_flags():
       'inside.')
   flags.DEFINE_float('learning_rate', 5e-5,
                      'The initial learning rate for Adam.')
+  flags.DEFINE_float('end_lr', 0.0,
+                     'The end learning rate for learning rate decay.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
   flags.DEFINE_boolean(
       'scale_loss', False,
       'Whether to divide the loss by number of replica inside the per-replica '
@@ -76,8 +80,6 @@ def define_common_bert_flags():
       'If specified, init_checkpoint flag should not be used.')
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
-  flags.DEFINE_string('optimizer_type', 'adamw',
-                      'The type of optimizer to use for training (adamw|lamb)')
   flags_core.define_log_steps()
...
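For orientation, a hypothetical standalone sketch of how these absl flags are defined and read (the script below is illustrative and not part of this commit; in the real code the definitions live in define_common_bert_flags() and the values are consumed via FLAGS as in the hunks that follow):

from absl import app, flags

# Same definitions as the hunk above: 'end_lr' is new, and 'optimizer_type'
# is moved up next to the other optimizer-related flags.
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
flags.DEFINE_float('end_lr', 0.0, 'The end learning rate for learning rate decay.')
flags.DEFINE_string('optimizer_type', 'adamw',
                    'The type of optimizer to use for training (adamw|lamb)')

FLAGS = flags.FLAGS


def main(_):
  # Example run: python sketch.py --learning_rate=5e-5 --end_lr=1e-7 --optimizer_type=lamb
  print(FLAGS.learning_rate, FLAGS.end_lr, FLAGS.optimizer_type)


if __name__ == '__main__':
  app.run(main)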
@@ -126,7 +126,7 @@ def run_bert_classifier(strategy,
           hub_module_trainable=FLAGS.hub_module_trainable))
   optimizer = optimization.create_optimizer(
       initial_lr, steps_per_epoch * epochs, warmup_steps,
-      FLAGS.optimizer_type)
+      FLAGS.end_lr, FLAGS.optimizer_type)
   classifier_model.optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=common_flags.use_float16(),
...
@@ -92,6 +92,8 @@ def run_customized_training(strategy,
                             epochs,
                             initial_lr,
                             warmup_steps,
+                            end_lr,
+                            optimizer_type,
                             input_files,
                             train_batch_size):
   """Run BERT pretrain model training using low-level API."""
@@ -106,7 +108,7 @@ def run_customized_training(strategy,
       bert_config, max_seq_length, max_predictions_per_seq)
   optimizer = optimization.create_optimizer(
       initial_lr, steps_per_epoch * epochs, warmup_steps,
-      FLAGS.optimizer_type)
+      end_lr, optimizer_type)
   pretrain_model.optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=common_flags.use_float16(),
@@ -152,6 +154,8 @@ def run_bert_pretrain(strategy):
       FLAGS.num_train_epochs,
       FLAGS.learning_rate,
       FLAGS.warmup_steps,
+      FLAGS.end_lr,
+      FLAGS.optimizer_type,
       FLAGS.input_files,
       FLAGS.train_batch_size)
...
@@ -251,6 +251,7 @@ def train_squad(strategy,
   optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                             steps_per_epoch * epochs,
                                             warmup_steps,
+                                            FLAGS.end_lr,
                                             FLAGS.optimizer_type)
   squad_model.optimizer = performance.configure_optimizer(
...
@@ -70,13 +70,14 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
 def create_optimizer(init_lr,
                      num_train_steps,
                      num_warmup_steps,
+                     end_lr=0.0,
                      optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
   lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
       initial_learning_rate=init_lr,
       decay_steps=num_train_steps,
-      end_learning_rate=0.0)
+      end_learning_rate=end_lr)
   if num_warmup_steps:
     lr_schedule = WarmUp(
         initial_learning_rate=init_lr,
...
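To see the effect of the new default argument, here is a minimal sketch (assuming TF 2.x eager execution; values are illustrative, not from this commit) of the decay schedule that create_optimizer now builds, before any warmup wrapping:

import tensorflow as tf

init_lr = 5e-5
end_lr = 1e-6          # new: a non-zero floor; end_lr=0.0 reproduces the old behaviour
num_train_steps = 10000

lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=init_lr,
    decay_steps=num_train_steps,
    end_learning_rate=end_lr)  # was hard-coded to 0.0 before this commit

# With the default power=1.0 the schedule interpolates linearly from init_lr
# down to end_lr over num_train_steps, then stays at end_lr.
print(float(lr_schedule(0)))                     # ~5.0e-05
print(float(lr_schedule(num_train_steps // 2)))  # ~2.55e-05
print(float(lr_schedule(num_train_steps)))       # ~1.0e-06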