# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Legacy functions and classes related to optimization."""

from absl import logging
import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
from official.modeling.optimization import legacy_adamw

# `AdamWeightDecay` is re-exported here so it can be imported directly from
# this legacy module.
AdamWeightDecay = legacy_adamw.AdamWeightDecay


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Applies a warmup schedule on a given learning rate decay schedule."""

  def __init__(self,
               initial_learning_rate,
               decay_schedule_fn,
               warmup_steps,
               power=1.0,
               name=None):
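    """Configures the warmup wrapper.

    Args:
      initial_learning_rate: Peak learning rate reached at the end of warmup.
      decay_schedule_fn: Learning rate schedule applied once warmup finishes.
      warmup_steps: Number of steps over which the learning rate is warmed up.
      power: Exponent of the polynomial warmup curve (1.0 means linear).
      name: Optional name scope for the schedule's operations.
    """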
    super(WarmUp, self).__init__()
    self.initial_learning_rate = initial_learning_rate
    self.warmup_steps = warmup_steps
    self.power = power
    self.decay_schedule_fn = decay_schedule_fn
    self.name = name

  def __call__(self, step):
    with tf.name_scope(self.name or 'WarmUp') as name:
      # Implements polynomial warmup: while global_step < warmup_steps, the
      # learning rate is `init_lr * (global_step / warmup_steps)**power`
      # (a linear ramp for the default power of 1.0).
      global_step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
      warmup_percent_done = global_step_float / warmup_steps_float
      warmup_learning_rate = (
          self.initial_learning_rate *
          tf.math.pow(warmup_percent_done, self.power))
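      # Past the warmup boundary, defer to the wrapped decay schedule; note it
      # is evaluated at the full `step`, not `step - warmup_steps`.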
      return tf.cond(
          global_step_float < warmup_steps_float,
          lambda: warmup_learning_rate,
          lambda: self.decay_schedule_fn(step),
          name=name)

  def get_config(self):
    return {
        'initial_learning_rate': self.initial_learning_rate,
        'decay_schedule_fn': self.decay_schedule_fn,
        'warmup_steps': self.warmup_steps,
        'power': self.power,
        'name': self.name
    }
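
# Usage sketch (illustrative only, not part of the original module): WarmUp
# wraps a Keras decay schedule so the learning rate ramps up over the first
# `warmup_steps` steps and then follows the wrapped schedule.
#
#   decay = tf.keras.optimizers.schedules.PolynomialDecay(
#       initial_learning_rate=2e-5, decay_steps=10000, end_learning_rate=0.0)
#   schedule = WarmUp(
#       initial_learning_rate=2e-5, decay_schedule_fn=decay, warmup_steps=1000)
#   schedule(tf.constant(500))    # Warmup: 2e-5 * (500 / 1000) == 1e-5.
#   schedule(tf.constant(5000))   # Decay:  decay(5000) == 1e-5.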


@gin.configurable
def create_optimizer(init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     end_lr=0.0,
                     optimizer_type='adamw',
                     beta_1=0.9):
  """Creates an optimizer with learning rate schedule."""
  # Implements linear decay of the learning rate.
  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=init_lr,
      decay_steps=num_train_steps,
      end_learning_rate=end_lr)
  if num_warmup_steps:
    lr_schedule = WarmUp(
        initial_learning_rate=init_lr,
        decay_schedule_fn=lr_schedule,
        warmup_steps=num_warmup_steps)

  # Both branches exclude LayerNorm and bias variables from weight decay.
  if optimizer_type == 'adamw':
    logging.info('Using AdamW optimizer.')
    optimizer = AdamWeightDecay(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  elif optimizer_type == 'lamb':
    logging.info('Using LAMB optimizer.')
    optimizer = tfa_optimizers.LAMB(
        learning_rate=lr_schedule,
        weight_decay_rate=0.01,
        beta_1=beta_1,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])
  else:
    raise ValueError(f'Unsupported optimizer type: {optimizer_type}')

  return optimizer
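
# Usage sketch (illustrative only, not part of the original module); `model`
# and `loss_fn` are placeholders for whatever the caller trains:
#
#   optimizer = create_optimizer(
#       init_lr=2e-5,
#       num_train_steps=10000,
#       num_warmup_steps=1000,
#       optimizer_type='adamw')
#   model.compile(optimizer=optimizer, loss=loss_fn)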