Commit da5860f2 authored by A. Unique TensorFlower

Support Lamb optimizer within BERT.

PiperOrigin-RevId: 303356961
parent 0265f59c
@@ -76,6 +76,8 @@ def define_common_bert_flags():
       'If specified, init_checkpoint flag should not be used.')
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
   flags_core.define_log_steps()
......
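The new flag plugs into the flag set built by define_common_bert_flags(). As a minimal, self-contained sketch of how such a flag is defined and read with absl (the script name, main function, and printed message here are illustrative, not part of this commit):

# Sketch: defining and reading an optimizer_type flag with absl
# (illustrative only; the real flag is added to the BERT common flags module above).
from absl import app
from absl import flags

flags.DEFINE_string('optimizer_type', 'adamw',
                    'The type of optimizer to use for training (adamw|lamb)')

FLAGS = flags.FLAGS


def main(_):
  # After absl parses the command line, training code can read the value and
  # pass it through to the optimizer factory.
  print('Selected optimizer:', FLAGS.optimizer_type)


if __name__ == '__main__':
  app.run(main)  # e.g. python my_train.py --optimizer_type=lamb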
@@ -125,7 +125,8 @@ def run_bert_classifier(strategy,
             hub_module_url=FLAGS.hub_module_url,
             hub_module_trainable=FLAGS.hub_module_trainable))
     optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
+        initial_lr, steps_per_epoch * epochs, warmup_steps,
+        FLAGS.optimizer_type)
     classifier_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
......
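The same one-line change appears at each training entry point (the classifier above, pretraining and SQuAD below): the flag value is threaded through as a new fourth argument to optimization.create_optimizer. A hedged sketch of that call pattern, using placeholder hyperparameter values rather than the values the real scripts compute:

# Illustrative call pattern only; the hyperparameter values are placeholders.
from official.nlp import optimization  # import path assumed from this repo

initial_lr = 2e-5        # peak learning rate reached after warmup
steps_per_epoch = 1000   # derived from dataset size and batch size in practice
epochs = 3
warmup_steps = 300

optimizer = optimization.create_optimizer(
    initial_lr,
    steps_per_epoch * epochs,  # total training steps for the decay schedule
    warmup_steps,
    'lamb')                    # the scripts pass FLAGS.optimizer_type here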
@@ -105,7 +105,8 @@ def run_customized_training(strategy,
     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config, max_seq_length, max_predictions_per_seq)
     optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
+        initial_lr, steps_per_epoch * epochs, warmup_steps,
+        FLAGS.optimizer_type)
     pretrain_model.optimizer = performance.configure_optimizer(
         optimizer,
         use_float16=common_flags.use_float16(),
......
@@ -244,7 +244,8 @@ def train_squad(strategy,
         hub_module_trainable=FLAGS.hub_module_trainable)
     optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                               steps_per_epoch * epochs,
-                                              warmup_steps)
+                                              warmup_steps,
+                                              FLAGS.optimizer_type)
     squad_model.optimizer = performance.configure_optimizer(
         optimizer,
......
@@ -20,7 +20,9 @@ from __future__ import print_function
 import re
+from absl import logging
 import tensorflow as tf
+import tensorflow_addons.optimizers as tfa_optimizers
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@@ -65,7 +67,8 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     }
-def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
+def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
+                     optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
   learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
@@ -76,13 +79,28 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
     learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
                               decay_schedule_fn=learning_rate_fn,
                               warmup_steps=num_warmup_steps)
-  optimizer = AdamWeightDecay(
-      learning_rate=learning_rate_fn,
-      weight_decay_rate=0.01,
-      beta_1=0.9,
-      beta_2=0.999,
-      epsilon=1e-6,
-      exclude_from_weight_decay=['layer_norm', 'bias'])
+  if optimizer_type == 'adamw':
+    logging.info('using Adamw optimizer')
+    optimizer = AdamWeightDecay(
+        learning_rate=learning_rate_fn,
+        weight_decay_rate=0.01,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-6,
+        exclude_from_weight_decay=['layer_norm', 'bias'])
+  elif optimizer_type == 'lamb':
+    logging.info('using Lamb optimizer')
+    optimizer = tfa_optimizers.LAMB(
+        learning_rate=learning_rate_fn,
+        weight_decay_rate=0.01,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-6,
+        exclude_from_weight_decay=['layer_norm', 'bias'])
+  else:
+    raise ValueError('Unsupported optimizer type: ', optimizer_type)
   return optimizer
......
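To see the new branch in isolation, here is a self-contained sketch of the dispatch using tensorflow_addons. It is deliberately simplified: a plain float learning rate stands in for the WarmUp + PolynomialDecay schedule built above, and tfa's AdamW is used as a stand-in for the repo's AdamWeightDecay class, so it is not a drop-in replacement for create_optimizer.

# Simplified sketch of the adamw/lamb dispatch (not the repo's create_optimizer).
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers


def build_optimizer(optimizer_type='adamw', learning_rate=1e-4):
  """Returns a Keras optimizer; mirrors the branch structure added above."""
  if optimizer_type == 'adamw':
    # Stand-in for AdamWeightDecay: decoupled weight decay, but tfa's AdamW
    # does not expose exclude_from_weight_decay.
    return tfa_optimizers.AdamW(
        weight_decay=0.01, learning_rate=learning_rate, epsilon=1e-6)
  elif optimizer_type == 'lamb':
    # Same hyperparameters as in the diff above.
    return tfa_optimizers.LAMB(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['layer_norm', 'bias'])
  raise ValueError('Unsupported optimizer type: {}'.format(optimizer_type))


# Usage: compile a toy model with the LAMB optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=build_optimizer('lamb'), loss='mse')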