Commit 22de957a authored by A. Unique TensorFlower

Add end learning rate (end_lr) FLAG for linear learning rate decay.

PiperOrigin-RevId: 305948522
parent 9b6507fe
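For context, the BERT optimizer code builds a linear learning-rate decay, and before this change the decay always bottomed out at 0.0; the new end_lr flag makes that floor configurable. Below is a minimal sketch of the effect using only stock TensorFlow (the step counts and rates are illustrative, not values from this commit):

import tensorflow as tf

# Linear decay (PolynomialDecay with its default power=1.0) from init_lr
# down to end_lr over num_train_steps.
init_lr, end_lr, num_train_steps = 5e-5, 1e-7, 10000
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=init_lr,
    decay_steps=num_train_steps,
    end_learning_rate=end_lr)

for step in (0, 5000, 10000):
  # Starts at init_lr and reaches end_lr at the final step instead of 0.0.
  print(step, float(lr_schedule(step)))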
@@ -63,6 +63,10 @@ def define_common_bert_flags():
                        'inside.')
   flags.DEFINE_float('learning_rate', 5e-5,
                      'The initial learning rate for Adam.')
+  flags.DEFINE_float('end_lr', 0.0,
+                     'The end learning rate for learning rate decay.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
   flags.DEFINE_boolean(
       'scale_loss', False,
       'Whether to divide the loss by number of replica inside the per-replica '
@@ -76,8 +80,6 @@ def define_common_bert_flags():
       'If specified, init_checkpoint flag should not be used.')
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
-  flags.DEFINE_string('optimizer_type', 'adamw',
-                      'The type of optimizer to use for training (adamw|lamb)')
   flags_core.define_log_steps()
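Since end_lr defaults to 0.0, runs that do not set the flag keep the previous decay-to-zero behaviour. A small self-contained sketch of how an absl flag defined this way is parsed and read (a hypothetical stand-alone script, not the project's actual entry point):

from absl import app, flags

flags.DEFINE_float('learning_rate', 5e-5,
                   'The initial learning rate for Adam.')
flags.DEFINE_float('end_lr', 0.0,
                   'The end learning rate for learning rate decay.')

FLAGS = flags.FLAGS


def main(_):
  # Invoked e.g. as: python sketch.py --end_lr=1e-7
  # With the default of 0.0 the schedule still decays all the way to zero.
  print('learning_rate:', FLAGS.learning_rate, 'end_lr:', FLAGS.end_lr)


if __name__ == '__main__':
  app.run(main)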
@@ -126,7 +126,7 @@ def run_bert_classifier(strategy,
           hub_module_trainable=FLAGS.hub_module_trainable))
   optimizer = optimization.create_optimizer(
       initial_lr, steps_per_epoch * epochs, warmup_steps,
-      FLAGS.optimizer_type)
+      FLAGS.end_lr, FLAGS.optimizer_type)
   classifier_model.optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=common_flags.use_float16(),
@@ -92,6 +92,8 @@ def run_customized_training(strategy,
                              epochs,
                              initial_lr,
                              warmup_steps,
+                             end_lr,
+                             optimizer_type,
                              input_files,
                              train_batch_size):
   """Run BERT pretrain model training using low-level API."""
@@ -106,7 +108,7 @@ def run_customized_training(strategy,
       bert_config, max_seq_length, max_predictions_per_seq)
   optimizer = optimization.create_optimizer(
       initial_lr, steps_per_epoch * epochs, warmup_steps,
-      FLAGS.optimizer_type)
+      end_lr, optimizer_type)
   pretrain_model.optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=common_flags.use_float16(),
@@ -152,6 +154,8 @@ def run_bert_pretrain(strategy):
       FLAGS.num_train_epochs,
       FLAGS.learning_rate,
       FLAGS.warmup_steps,
+      FLAGS.end_lr,
+      FLAGS.optimizer_type,
       FLAGS.input_files,
       FLAGS.train_batch_size)
@@ -251,6 +251,7 @@ def train_squad(strategy,
   optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                              steps_per_epoch * epochs,
                                              warmup_steps,
+                                             FLAGS.end_lr,
                                              FLAGS.optimizer_type)
   squad_model.optimizer = performance.configure_optimizer(
@@ -70,13 +70,14 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
 def create_optimizer(init_lr,
                      num_train_steps,
                      num_warmup_steps,
+                     end_lr=0.0,
                      optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
   lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
       initial_learning_rate=init_lr,
       decay_steps=num_train_steps,
-      end_learning_rate=0.0)
+      end_learning_rate=end_lr)
   if num_warmup_steps:
     lr_schedule = WarmUp(
         initial_learning_rate=init_lr,
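The schedule returned by create_optimizer also wraps the decay in a warmup phase when num_warmup_steps is set; end_lr only moves the floor the post-warmup decay reaches. Note also that end_lr sits immediately before optimizer_type in the new signature, which is why each positional call site above (classifier, pretraining, SQuAD) is updated in the same order. A hedged sketch of the composition follows, using a simplified warmup wrapper in place of the repo's WarmUp class (whose full constructor is not shown in this hunk); the values are illustrative:

import tensorflow as tf


class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Simplified stand-in for a warmup wrapper: linear ramp, then decay."""

  def __init__(self, initial_learning_rate, decay_schedule_fn, warmup_steps):
    super().__init__()
    self.initial_learning_rate = initial_learning_rate
    self.decay_schedule_fn = decay_schedule_fn
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    step = tf.cast(step, tf.float32)
    warmup_lr = self.initial_learning_rate * step / float(self.warmup_steps)
    return tf.cond(step < self.warmup_steps,
                   lambda: warmup_lr,
                   lambda: self.decay_schedule_fn(step))


init_lr, end_lr, num_train_steps, warmup_steps = 5e-5, 1e-7, 10000, 1000
decay = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=init_lr,
    decay_steps=num_train_steps,
    end_learning_rate=end_lr)
schedule = LinearWarmup(init_lr, decay, warmup_steps)

for step in (0, 500, 1000, 5000, 10000):
  # end_lr only changes where the post-warmup decay bottoms out;
  # the warmup ramp from 0 up to init_lr is unaffected.
  print(step, float(schedule(step)))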