Commit da5860f2 authored by A. Unique TensorFlower

Support Lamb optimizer within BERT.

PiperOrigin-RevId: 303356961
parent 0265f59c
@@ -76,6 +76,8 @@ def define_common_bert_flags():
       'If specified, init_checkpoint flag should not be used.')
   flags.DEFINE_bool('hub_module_trainable', True,
                     'True to make keras layers in the hub module trainable.')
+  flags.DEFINE_string('optimizer_type', 'adamw',
+                      'The type of optimizer to use for training (adamw|lamb)')
 
   flags_core.define_log_steps()
...
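The new flag above follows the standard absl.flags pattern already used by the BERT scripts. As a self-contained illustration (not part of the commit), the sketch below shows how such a string flag is declared and then read at runtime; the flag name and help text match the diff, while the surrounding script scaffolding is illustrative only.

# Standalone sketch of the absl flag pattern used above; only the flag
# definition itself comes from the commit, the rest is illustrative.
from absl import app
from absl import flags

flags.DEFINE_string('optimizer_type', 'adamw',
                    'The type of optimizer to use for training (adamw|lamb)')

FLAGS = flags.FLAGS


def main(_):
  # Training code would forward this value to optimization.create_optimizer.
  print('Selected optimizer:', FLAGS.optimizer_type)


if __name__ == '__main__':
  app.run(main)

Invoked as, say, python train.py --optimizer_type=lamb (script name illustrative), FLAGS.optimizer_type carries the chosen value into the optimizer construction shown in the hunks below.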
@@ -125,7 +125,8 @@ def run_bert_classifier(strategy,
           hub_module_url=FLAGS.hub_module_url,
           hub_module_trainable=FLAGS.hub_module_trainable))
   optimizer = optimization.create_optimizer(
-      initial_lr, steps_per_epoch * epochs, warmup_steps)
+      initial_lr, steps_per_epoch * epochs, warmup_steps,
+      FLAGS.optimizer_type)
   classifier_model.optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=common_flags.use_float16(),
...
@@ -105,7 +105,8 @@ def run_customized_training(strategy,
   pretrain_model, core_model = bert_models.pretrain_model(
       bert_config, max_seq_length, max_predictions_per_seq)
   optimizer = optimization.create_optimizer(
-      initial_lr, steps_per_epoch * epochs, warmup_steps)
+      initial_lr, steps_per_epoch * epochs, warmup_steps,
+      FLAGS.optimizer_type)
   pretrain_model.optimizer = performance.configure_optimizer(
       optimizer,
       use_float16=common_flags.use_float16(),
...
@@ -244,7 +244,8 @@ def train_squad(strategy,
       hub_module_trainable=FLAGS.hub_module_trainable)
   optimizer = optimization.create_optimizer(FLAGS.learning_rate,
                                             steps_per_epoch * epochs,
-                                            warmup_steps)
+                                            warmup_steps,
+                                            FLAGS.optimizer_type)
   squad_model.optimizer = performance.configure_optimizer(
       optimizer,
...
@@ -20,7 +20,9 @@ from __future__ import print_function
 import re
 
+from absl import logging
 import tensorflow as tf
+import tensorflow_addons.optimizers as tfa_optimizers
 
 
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
@@ -65,7 +67,8 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     }
 
 
-def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
+def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
+                     optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
   learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
@@ -76,13 +79,28 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
     learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
                               decay_schedule_fn=learning_rate_fn,
                               warmup_steps=num_warmup_steps)
-  optimizer = AdamWeightDecay(
-      learning_rate=learning_rate_fn,
-      weight_decay_rate=0.01,
-      beta_1=0.9,
-      beta_2=0.999,
-      epsilon=1e-6,
-      exclude_from_weight_decay=['layer_norm', 'bias'])
+
+  if optimizer_type == 'adamw':
+    logging.info('using Adamw optimizer')
+    optimizer = AdamWeightDecay(
+        learning_rate=learning_rate_fn,
+        weight_decay_rate=0.01,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-6,
+        exclude_from_weight_decay=['layer_norm', 'bias'])
+  elif optimizer_type == 'lamb':
+    logging.info('using Lamb optimizer')
+    optimizer = tfa_optimizers.LAMB(
+        learning_rate=learning_rate_fn,
+        weight_decay_rate=0.01,
+        beta_1=0.9,
+        beta_2=0.999,
+        epsilon=1e-6,
+        exclude_from_weight_decay=['layer_norm', 'bias'])
+  else:
+    raise ValueError('Unsupported optimizer type: ', optimizer_type)
+
   return optimizer
...
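Taken together, the changes let callers select LAMB by passing optimizer_type='lamb' (or the --optimizer_type flag) through to create_optimizer. Below is a minimal usage sketch, not part of the commit; it assumes the TF Model Garden is importable as official.nlp and that tensorflow_addons is installed, since the LAMB branch relies on tfa_optimizers.LAMB.

# Hedged usage sketch for the extended create_optimizer; the import path
# official.nlp.optimization is assumed to resolve to the module patched above.
import tensorflow as tf
from official.nlp import optimization

# 1000 training steps with a 100-step linear warmup; the new argument
# dispatches to tensorflow_addons' LAMB instead of AdamWeightDecay.
optimizer = optimization.create_optimizer(
    init_lr=2e-5,
    num_train_steps=1000,
    num_warmup_steps=100,
    optimizer_type='lamb')

# The returned optimizer drops into an ordinary Keras training setup.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(optimizer=optimizer, loss='mse')

Any string other than 'adamw' or 'lamb' raises the ValueError added in the last hunk.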