# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and classes related to optimization (weight updates)."""

import re

from absl import logging
import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.power = power
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Implements polynomial warmup: if global_step < warmup_steps, the
      # learning rate is `(global_step / warmup_steps)**power * init_lr`
      # (linear warmup when power == 1).
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self.initial_learning_rate *
          tf.math.pow(warmup_percent_done, self.power))
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step),
          name=name)

  def get_config(self):
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'power': self.power,
        'name': self.name
    }
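
# Example (illustrative sketch only; the hyperparameter values below are
# placeholders, not defaults defined by this module). `WarmUp` wraps a decay
# schedule and can be passed anywhere a Keras `LearningRateSchedule` is
# accepted:
#
#   decay_fn = tf.keras.optimizers.schedules.PolynomialDecay(
#       initial_learning_rate=1e-4, decay_steps=10000, end_learning_rate=0.0)
#   warmup_fn = WarmUp(initial_learning_rate=1e-4,
#                      decay_schedule_fn=decay_fn,
#                      warmup_steps=1000)
#   optimizer = tf.keras.optimizers.Adam(learning_rate=warmup_fn)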


@gin.configurable
def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     end_lr=0.0,
                     optimizer_type='adamw',
                     beta_1=0.9):
  """Creates an optimizer with learning rate schedule."""
  # Implements linear decay of the learning rate.
  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=end_lr)
  if num_warmup_steps:
    lr_schedule = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=num_warmup_steps)

  if optimizer_type == 'adamw':
    logging.info('using AdamW optimizer')
    optimizer = AdamWeightDecay(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  elif optimizer_type == 'lamb':
    logging.info('using LAMB optimizer')
    optimizer = tfa_optimizers.LAMB(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  else:
    raise ValueError('Unsupported optimizer type: %s' % optimizer_type)

  return optimizer
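
# Example (illustrative sketch; the step counts and learning rate below are
# assumed values chosen for demonstration, not defaults of this function):
#
#   optimizer = create_optimizer(
#       init_lr=5e-5,
#       num_train_steps=100000,
#       num_warmup_steps=10000,
#       optimizer_type='adamw')
#   model.compile(optimizer=optimizer, loss=loss_fn)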


class AdamWeightDecay(tf.keras.optimizers.Adam):
  """Adam enables L2 weight decay and clip_by_global_norm on gradients.

  Just adding the square of the weights to the loss function is *not* the
  correct way of using L2 regularization/weight decay with Adam, since that will
  interact with the m and v parameters in strange ways.

  Instead we want to decay the weights in a manner that doesn't interact with
  the m/v parameters. This is equivalent to adding the square of the weights to
  the loss with plain (non-momentum) SGD.
  """

  def __init__(self,
               learning_rate=0.001,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-7,
               amsgrad=False,
               weight_decay_rate=0.0,
               include_in_weight_decay=None,
               exclude_from_weight_decay=None,
               gradient_clip_norm=1.0,
               name='AdamWeightDecay',
               **kwargs):
    super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
                                          epsilon, amsgrad, name, **kwargs)
    self.weight_decay_rate = weight_decay_rate
    self.gradient_clip_norm = gradient_clip_norm
    self._include_in_weight_decay = include_in_weight_decay
    self._exclude_from_weight_decay = exclude_from_weight_decay
    logging.info('gradient_clip_norm=%f', gradient_clip_norm)

  @classmethod
  def from_config(cls, config):
    """Creates an optimizer from its config with WarmUp custom object."""
    custom_objects = {'WarmUp': WarmUp}
    return super(AdamWeightDecay, cls).from_config(
        config, custom_objects=custom_objects)

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,  # pytype: disable=attribute-error  # typed-keras
                                                apply_state)
    apply_state[(var_device, var_dtype)]['weight_decay_rate'] = tf.constant(
        self.weight_decay_rate, name='adam_weight_decay_rate')

  def _decay_weights_op(self, var, learning_rate, apply_state):
    do_decay = self._do_use_weight_decay(var.name)
    if do_decay:
      return var.assign_sub(
          learning_rate * var *
          apply_state[(var.device, var.dtype.base_dtype)]['weight_decay_rate'],
          use_locking=self._use_locking)
    return tf.no_op()

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    grads, tvars = list(zip(*grads_and_vars))
    if experimental_aggregate_gradients and self.gradient_clip_norm > 0.0:
      # When experimental_aggregate_gradients is False, apply_gradients() no
      # longer implicitly all-reduces gradients; users are expected to
      # all-reduce gradients themselves and pass in the all-reduced
      # grads_and_vars. For now, clip_by_global_norm is applied before the
      # explicit allreduce to keep the math the same as the TF 1 and
      # pre-TF 2.2 implementations.
      (grads, _) = tf.clip_by_global_norm(
          grads, clip_norm=self.gradient_clip_norm)
    return super(AdamWeightDecay, self).apply_gradients(
        zip(grads, tvars),
        name=name,
        experimental_aggregate_gradients=experimental_aggregate_gradients)

  def _get_lr(self, var_device, var_dtype, apply_state):
    """Retrieves the learning rate with the given state."""
    if apply_state is None:
      return self._decayed_lr_t[var_dtype], {}

    apply_state = apply_state or {}
    coefficients = apply_state.get((var_device, var_dtype))
    if coefficients is None:
      coefficients = self._fallback_apply_state(var_device, var_dtype)
      apply_state[(var_device, var_dtype)] = coefficients

    return coefficients['lr_t'], dict(apply_state=apply_state)

  def _resource_apply_dense(self, grad, var, apply_state=None):
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    decay = self._decay_weights_op(var, lr_t, apply_state)
    with tf.control_dependencies([decay]):
      return super(AdamWeightDecay,
                   self)._resource_apply_dense(grad, var, **kwargs)  # pytype: disable=attribute-error  # typed-keras

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    decay = self._decay_weights_op(var, lr_t, apply_state)
    with tf.control_dependencies([decay]):
      return super(AdamWeightDecay,
                   self)._resource_apply_sparse(grad, var, indices, **kwargs)  # pytype: disable=attribute-error  # typed-keras

  def get_config(self):
    config = super(AdamWeightDecay, self).get_config()
    config.update({
        'weight_decay_rate': self.weight_decay_rate,
    })
    return config

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if self.weight_decay_rate == 0:
      return False

    if self._include_in_weight_decay:
      for r in self._include_in_weight_decay:
        if re.search(r, param_name) is not None:
          return True

    if self._exclude_from_weight_decay:
      for r in self._exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True
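
# Example of how the weight-decay filters above behave (illustrative only):
# with exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'], a
# variable named 'encoder/layer_0/output/LayerNorm/gamma:0' is *not* decayed
# because `re.search` matches 'LayerNorm' in its name, whereas
# 'encoder/layer_0/output/dense/kernel:0' matches neither list and is decayed.
# Patterns in `include_in_weight_decay` take precedence over the exclude list.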