Commit 1aaab0b7 authored by Hongkun Yu, committed by A. Unique TensorFlower

lr_schedule is a more reasonable name. It is a callable, not a function.

PiperOrigin-RevId: 303840099
parent 39f05f0e
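As context for the rename: a tf.keras LearningRateSchedule such as the WarmUp class changed below is an object whose __call__ maps a training step to a learning rate, so "schedule" describes it better than "fn". A minimal sketch using the built-in PolynomialDecay (the values here are illustrative, not taken from this change):

import tensorflow as tf

# A schedule is a callable object: schedule(step) -> learning-rate tensor.
decay = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-4, decay_steps=10000, end_learning_rate=0.0)
print(float(decay(0)))      # 1e-4 at step 0
print(float(decay(10000)))  # 0.0 once decay_steps is reached

# WarmUp in the diff below wraps such a schedule and is itself a schedule:
# it returns the warmup rate before warmup_steps and decay(step) afterwards.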
@@ -28,13 +28,12 @@ import tensorflow_addons.optimizers as tfa_optimizers
 class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
   """Applies a warmup schedule on a given learning rate decay schedule."""
 
-  def __init__(
-      self,
-      initial_learning_rate,
-      decay_schedule_fn,
-      warmup_steps,
-      power=1.0,
-      name=None):
+  def __init__(self,
+               initial_learning_rate,
+               decay_schedule_fn,
+               warmup_steps,
+               power=1.0,
+               name=None):
     super(WarmUp, self).__init__()
     self.initial_learning_rate = initial_learning_rate
     self.warmup_steps = warmup_steps
@@ -52,10 +51,11 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     warmup_learning_rate = (
         self.initial_learning_rate *
         tf.math.pow(warmup_percent_done, self.power))
-    return tf.cond(global_step_float < warmup_steps_float,
-                   lambda: warmup_learning_rate,
-                   lambda: self.decay_schedule_fn(step),
-                   name=name)
+    return tf.cond(
+        global_step_float < warmup_steps_float,
+        lambda: warmup_learning_rate,
+        lambda: self.decay_schedule_fn(step),
+        name=name)
 
   def get_config(self):
     return {
@@ -67,23 +67,26 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
     }
 
 
-def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
+def create_optimizer(init_lr,
+                     num_train_steps,
+                     num_warmup_steps,
                      optimizer_type='adamw'):
   """Creates an optimizer with learning rate schedule."""
   # Implements linear decay of the learning rate.
-  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
+  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
       initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=0.0)
   if num_warmup_steps:
-    learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
-                              decay_schedule_fn=learning_rate_fn,
-                              warmup_steps=num_warmup_steps)
+    lr_schedule = WarmUp(
+        initial_learning_rate=init_lr,
+        decay_schedule_fn=lr_schedule,
+        warmup_steps=num_warmup_steps)
 
   if optimizer_type == 'adamw':
     logging.info('using Adamw optimizer')
     optimizer = AdamWeightDecay(
-        learning_rate=learning_rate_fn,
+        learning_rate=lr_schedule,
         weight_decay_rate=0.01,
         beta_1=0.9,
         beta_2=0.999,
@@ -92,7 +95,7 @@ def create_optimizer(init_lr, num_train_steps, num_warmup_steps,
   elif optimizer_type == 'lamb':
     logging.info('using Lamb optimizer')
     optimizer = tfa_optimizers.LAMB(
-        learning_rate=learning_rate_fn,
+        learning_rate=lr_schedule,
         weight_decay_rate=0.01,
         beta_1=0.9,
         beta_2=0.999,
@@ -127,8 +130,8 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
                exclude_from_weight_decay=None,
                name='AdamWeightDecay',
                **kwargs):
-    super(AdamWeightDecay, self).__init__(
-        learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
+    super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
+                                          epsilon, amsgrad, name, **kwargs)
     self.weight_decay_rate = weight_decay_rate
     self._include_in_weight_decay = include_in_weight_decay
     self._exclude_from_weight_decay = exclude_from_weight_decay
@@ -189,15 +192,15 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
     lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
     decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
-      return super(AdamWeightDecay, self)._resource_apply_dense(
-          grad, var, **kwargs)
+      return super(AdamWeightDecay,
+                   self)._resource_apply_dense(grad, var, **kwargs)
 
   def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
     lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
     decay = self._decay_weights_op(var, lr_t, apply_state)
     with tf.control_dependencies([decay]):
-      return super(AdamWeightDecay, self)._resource_apply_sparse(
-          grad, var, indices, **kwargs)
+      return super(AdamWeightDecay,
+                   self)._resource_apply_sparse(grad, var, indices, **kwargs)
 
   def get_config(self):
     config = super(AdamWeightDecay, self).get_config()
...
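For reference, a self-contained sketch of what the renamed lr_schedule feeds into, mirroring the 'lamb' branch of create_optimizer above (hyperparameter values are illustrative; the WarmUp wrapper is omitted here because it is defined in this module):

import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers

# Linear decay of the learning rate to zero over the training run.
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=2e-5,
    decay_steps=10000,
    end_learning_rate=0.0)

# A Keras optimizer accepts the schedule object directly as learning_rate
# and evaluates it at the current iteration on every apply_gradients call.
optimizer = tfa_optimizers.LAMB(
    learning_rate=lr_schedule,
    weight_decay_rate=0.01,
    beta_1=0.9,
    beta_2=0.999)

The optimizer can then be passed to model.compile or used directly with apply_gradients; nothing changes when the schedule is later wrapped in WarmUp, since WarmUp is itself a LearningRateSchedule.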