# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.power = power
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Implements polynomial warmup: while global_step < warmup_steps, the
      # learning rate is `init_lr * (global_step / warmup_steps) ** power`
      # (linear warmup for the default power=1.0); afterwards it follows the
      # wrapped decay schedule.
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self.initial_learning_rate *
          tf.math.pow(warmup_percent_done, self.power))
      return tf.cond(global_step_float < warmup_steps_float,
                     lambda: warmup_learning_rate,
                     lambda: self.decay_schedule_fn(step),
                     name=name)

  def get_config(self):
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'power': self.power,
        'name': self.name
    }


def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
  """Creates an optimizer with a learning rate schedule.

  The learning rate decays linearly from `init_lr` to 0 over `num_train_steps`
  steps; if `num_warmup_steps` is set, the decay is preceded by a linear
  warmup from 0 to `init_lr`.
  """
  # Implements linear decay of the learning rate.
  learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=0.0)
  if num_warmup_steps:
    learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
                              decay_schedule_fn=learning_rate_fn,
                              warmup_steps=num_warmup_steps)
  optimizer = AdamWeightDecay(
      learning_rate=learning_rate_fn,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=['layer_norm', 'bias'])
  return optimizer


class AdamWeightDecay(tf.keras.optimizers.Adam):
  """Adam optimizer that also applies L2 weight decay and clip_by_global_norm.

  Just adding the square of the weights to the loss function is *not* the
  correct way of using L2 regularization/weight decay with Adam, since that
  will interact with the m and v parameters in strange ways. Instead we want
  to decay the weights in a manner that doesn't interact with the m/v
  parameters. This is equivalent to adding the square of the weights to the
  loss with plain (non-momentum) SGD.
  """

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               amsgrad=False,
               weight_decay_rate=0.0,
               exclude_from_weight_decay=None,
               name='AdamWeightDecay',
               **kwargs):
    super(AdamWeightDecay, self).__init__(
        learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
    self._set_hyper('weight_decay_rate', weight_decay_rate)
    self._exclude_from_weight_decay = exclude_from_weight_decay

  @classmethod
  def from_config(cls, config):
    """Creates an optimizer from its config with WarmUp custom object."""
    custom_objects = {'WarmUp': WarmUp}
    return super(AdamWeightDecay, cls).from_config(
        config, custom_objects=custom_objects)

  def _decay_weights_op(self, var, learning_rate):
    do_decay = self._do_use_weight_decay(var.name)
    if do_decay:
      return var.assign_sub(
          learning_rate * var * self._get_hyper('weight_decay_rate'),
          use_locking=self._use_locking)
    return tf.no_op()

  def apply_gradients(self, grads_and_vars, name=None):
    # Clip gradients by global norm before delegating to Adam's update.
    grads, tvars = list(zip(*grads_and_vars))
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
    return super(AdamWeightDecay, self).apply_gradients(
        zip(grads, tvars), name=name)

  def _resource_apply_dense(self, grad, var):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    # Apply the decoupled weight decay to the variable before the Adam update.
    with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
      return super(AdamWeightDecay, self)._resource_apply_dense(grad, var)

  def _resource_apply_sparse(self, grad, var, indices):
    var_dtype = var.dtype.base_dtype
    lr_t = self._decayed_lr(var_dtype)
    with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
      return super(AdamWeightDecay, self)._resource_apply_sparse(
          grad, var, indices)

  def get_config(self):
    config = super(AdamWeightDecay, self).get_config()
    config.update({
        'weight_decay_rate':
            self._serialize_hyperparameter('weight_decay_rate'),
    })
    return config

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if self.weight_decay_rate == 0:
      return False
    if self._exclude_from_weight_decay:
      for r in self._exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True