# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Abdullah Rashwan's avatar
Abdullah Rashwan committed
15
"""Optimizer factory class."""
Rebecca Chen's avatar
Rebecca Chen committed
16
from typing import Callable, Optional, Union
Abdullah Rashwan's avatar
Abdullah Rashwan committed
17

Le Hou's avatar
Le Hou committed
18
import gin
Abdullah Rashwan's avatar
Abdullah Rashwan committed
19
import tensorflow as tf
Abdullah Rashwan's avatar
Abdullah Rashwan committed
20
21
import tensorflow_addons.optimizers as tfa_optimizers

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
22
from official.modeling.optimization import slide_optimizer
Abdullah Rashwan's avatar
Abdullah Rashwan committed
23
from official.modeling.optimization import ema_optimizer
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
24
from official.modeling.optimization import lars_optimizer
Abdullah Rashwan's avatar
Abdullah Rashwan committed
25
26
27
28
29
30
31
32
from official.modeling.optimization import lr_schedule
from official.modeling.optimization.configs import optimization_config as opt_cfg
from official.nlp import optimization as nlp_optimization

# Registry of supported optimizers. Keys are the optimizer `type` strings from
# the optimization config; values are the classes that
# `OptimizerFactory.build_optimizer` instantiates with the config's fields
# (plus `learning_rate`) as keyword arguments.
OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.SGD,
    'adam': tf.keras.optimizers.Adam,
    'adamw': nlp_optimization.AdamWeightDecay,
    'lamb': tfa_optimizers.LAMB,
    'rmsprop': tf.keras.optimizers.RMSprop,
    'lars': lars_optimizer.LARS,
    'adagrad': tf.keras.optimizers.Adagrad,
    'slide': slide_optimizer.SLIDE,
}

# Registry of supported learning rate schedules. Keys are the learning rate
# `type` strings from the optimization config; values are the schedule classes
# that `OptimizerFactory.build_learning_rate` instantiates with the config's
# fields as keyword arguments. The 'constant' type is handled directly in
# `build_learning_rate` and therefore has no entry here.
LR_CLS = {
    'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
    'polynomial': tf.keras.optimizers.schedules.PolynomialDecay,
    'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
    'cosine': tf.keras.experimental.CosineDecay,
    'power': lr_schedule.DirectPowerDecay,
    'power_linear': lr_schedule.PowerAndLinearDecay,
    'power_with_offset': lr_schedule.PowerDecayWithOffset,
}

# Registry of supported warmup schedules. Keys are the warmup `type` strings
# from the optimization config; values are the classes that
# `OptimizerFactory.build_learning_rate` uses to wrap the base learning rate.
WARMUP_CLS = {
    'linear': lr_schedule.LinearWarmup,
    'polynomial': lr_schedule.PolynomialWarmUp,
}


Hongkun Yu's avatar
Hongkun Yu committed
56
class OptimizerFactory:
  """Optimizer factory class.

  This class builds learning rate and optimizer based on an optimization config.
  To use this class, you need to do the following:
  (1) Define optimization config, this includes optimizer, and learning rate
      schedule.
  (2) Initialize the class using the optimization config.
  (3) Build learning rate.
  (4) Build optimizer.

  This is a typical example for using this class:
  params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {'boundaries': [10000, 20000],
                         'values': [0.1, 0.01, 0.001]}
        },
        'warmup': {
            'type': 'linear',
            'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
        }
    }
  opt_config = OptimizationConfig(params)
  opt_factory = OptimizerFactory(opt_config)
  lr = opt_factory.build_learning_rate()
  optimizer = opt_factory.build_optimizer(lr)
  """

  def __init__(self, config: opt_cfg.OptimizationConfig):
    """Initializing OptimizerFactory.

    Args:
      config: OptimizationConfig instance contain optimization config.

    Raises:
      ValueError: If `config.optimizer.get()` returns None, or if
        `config.learning_rate.type` is None.
    """
    self._config = config
    self._optimizer_config = config.optimizer.get()
    self._optimizer_type = config.optimizer.type

    # EMA wrapping is enabled whenever an `ema` section is present.
    self._use_ema = config.ema is not None
    self._ema_config = config.ema

    if self._optimizer_config is None:
      raise ValueError('Optimizer type must be specified')

    self._lr_config = config.learning_rate.get()
    self._lr_type = config.learning_rate.type

    if self._lr_type is None:
      raise ValueError('Learning rate type must be specified')

    # Warmup is optional; `build_learning_rate` skips it when the config is
    # falsy.
    self._warmup_config = config.warmup.get()
    self._warmup_type = config.warmup.type

  def build_learning_rate(self):
    """Build learning rate.

    Builds learning rate from config. Learning rate schedule is built according
    to the learning rate config. If learning rate type is constant,
    lr_config.learning_rate is used as the base learning rate. If a warmup
    config is set, the base learning rate is wrapped in the warmup class
    registered in WARMUP_CLS under the configured warmup type.

    Returns:
      tf.keras.optimizers.schedules.LearningRateSchedule instance. If
      learning rate type is constant and no warmup is configured,
      lr_config.learning_rate is returned.
    """
    if self._lr_type == 'constant':
      lr = self._lr_config.learning_rate
    else:
      lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())

    if self._warmup_config:
      lr = WARMUP_CLS[self._warmup_type](lr, **self._warmup_config.as_dict())

    return lr

  @gin.configurable
  def build_optimizer(
      self,
      lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
      postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
                                       tf.keras.optimizers.Optimizer]] = None):
    """Build optimizer.

    Builds optimizer from config. It takes learning rate as input, and builds
    the optimizer according to the optimizer config. Typically, the learning
    rate built using self.build_learning_rate() is passed as an argument to
    this method.

    Args:
      lr: A floating point value, or a
        tf.keras.optimizers.schedules.LearningRateSchedule instance.
      postprocessor: An optional function for postprocessing the optimizer. It
        takes an optimizer and returns an optimizer.

    Returns:
      tf.keras.optimizers.Optimizer instance.
    """

    optimizer_dict = self._optimizer_config.as_dict()
    # Drop clipnorm and clipvalue when they are unset (None) so they are not
    # forwarded to the optimizer constructor. Lookups are tolerant of configs
    # that do not define these fields at all.
    for clip_key in ('clipnorm', 'clipvalue'):
      if optimizer_dict.get(clip_key) is None:
        optimizer_dict.pop(clip_key, None)

    optimizer_dict['learning_rate'] = lr

    optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)

    if self._use_ema:
      optimizer = ema_optimizer.ExponentialMovingAverage(
          optimizer, **self._ema_config.as_dict())
    if postprocessor:
      optimizer = postprocessor(optimizer)
    # Guards against a postprocessor that returns something other than an
    # optimizer.
    assert isinstance(optimizer, tf.keras.optimizers.Optimizer), (
        'OptimizerFactory.build_optimizer returning a non-optimizer object: '
        '{}'.format(optimizer))

    return optimizer