# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Legacy functions and classes related to optimization."""

from absl import logging
import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
from official.modeling.optimization import legacy_adamw

AdamWeightDecay = legacy_adamw.AdamWeightDecay


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
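    """Configures the warmup schedule.

    Args:
      initial_learning_rate: Learning rate reached at the end of warmup.
      decay_schedule_fn: Schedule the learning rate follows after
        `warmup_steps`.
      warmup_steps: Number of steps over which the learning rate ramps up.
      power: Power of the polynomial ramp-up (1.0 means linear warmup).
      name: Optional name scope used when computing the learning rate.
    """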
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.power = power
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Implements polynomial warmup, i.e., if global_step < warmup_steps, the
      # learning rate is `(global_step / warmup_steps)**power * init_lr`.
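      # For example, with initial_learning_rate=1e-4, warmup_steps=1000 and
      # power=1.0 (illustrative values, not this module's defaults), the
      # warmup learning rate at step 100 is 1e-4 * (100 / 1000) = 1e-5.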
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self.initial_learning_rate *
          tf.math.pow(warmup_percent_done, self.power))
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step),
          name=name)

  def get_config(self):
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'power': self.power,
        'name': self.name
    }


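# A hedged sketch of how `WarmUp` composes with a decay schedule; this is the
# same pairing that `create_optimizer` below builds, and the numbers are
# purely illustrative:
#
#   decay = tf.keras.optimizers.schedules.PolynomialDecay(
#       initial_learning_rate=1e-4, decay_steps=10000, end_learning_rate=0.0)
#   lr_schedule = WarmUp(
#       initial_learning_rate=1e-4,
#       decay_schedule_fn=decay,
#       warmup_steps=1000)
#
# For the first 1000 steps `lr_schedule` ramps up towards 1e-4; afterwards it
# delegates to `decay`.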
@gin.configurable
def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     end_lr=0.0,
                     optimizer_type='adamw',
                     beta_1=0.9,
                     poly_power=1.0):
  """Creates an optimizer with learning rate schedule."""
  # Implements linear decay of the learning rate.
  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=end_lr,
      power=poly_power)
  if num_warmup_steps:
    lr_schedule = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=num_warmup_steps)

  if optimizer_type == 'adamw':
    logging.info('using AdamW optimizer')
    optimizer = AdamWeightDecay(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  elif optimizer_type == 'lamb':
    logging.info('using LAMB optimizer')
    optimizer = tfa_optimizers.LAMB(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  else:
    raise ValueError(f'Unsupported optimizer type: {optimizer_type}')

  return optimizer
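

# A minimal usage sketch, not part of the original module: the hyperparameter
# values below are illustrative. Running this file directly only constructs an
# optimizer to show the call pattern; importing the module has no side effects.
if __name__ == '__main__':
  demo_optimizer = create_optimizer(
      init_lr=2e-5,           # peak learning rate reached after warmup
      num_train_steps=10000,  # length of the polynomial decay
      num_warmup_steps=1000,  # steps of polynomial (here linear) warmup
      optimizer_type='adamw')
  logging.info('Created optimizer: %s', type(demo_optimizer).__name__)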