# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Optimizer factory class."""
from typing import Callable, List, Optional, Tuple, Type, Union

import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers

from official.modeling.optimization import adafactor_optimizer
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import lars_optimizer
from official.modeling.optimization import legacy_adamw
from official.modeling.optimization import lr_schedule
from official.modeling.optimization import slide_optimizer
from official.modeling.optimization.configs import optimization_config as opt_cfg

# Optimizer classes reachable from both the legacy and the new (experimental)
# optimizer registries below, keyed by the `optimizer.type` config string.
SHARED_OPTIMIZERS = {
    'sgd_experimental': tf.keras.optimizers.experimental.SGD,
    'adam_experimental': tf.keras.optimizers.experimental.Adam,
    'adamw': legacy_adamw.AdamWeightDecay,
    'adamw_experimental': tf.keras.optimizers.experimental.AdamW,
    'lamb': tfa_optimizers.LAMB,
    'lars': lars_optimizer.LARS,
    'slide': slide_optimizer.SLIDE,
    'adafactor': adafactor_optimizer.Adafactor,
}

# Registry consulted by OptimizerFactory.build_optimizer when
# `use_legacy_optimizer=True`: the basic Keras optimizers come from
# `tf.keras.optimizers.legacy`, followed by every shared entry (a shared key
# would override a basic one, as with dict.update).
LEGACY_OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.legacy.SGD,
    'adam': tf.keras.optimizers.legacy.Adam,
    'rmsprop': tf.keras.optimizers.legacy.RMSprop,
    'adagrad': tf.keras.optimizers.legacy.Adagrad,
    **SHARED_OPTIMIZERS,
}

# Registry consulted when `use_legacy_optimizer=False`: the basic Keras
# optimizers come from `tf.keras.optimizers.experimental`, followed by every
# shared entry.
NEW_OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.experimental.SGD,
    'adam': tf.keras.optimizers.experimental.Adam,
    'rmsprop': tf.keras.optimizers.experimental.RMSprop,
    'adagrad': tf.keras.optimizers.experimental.Adagrad,
    **SHARED_OPTIMIZERS,
}

# Learning-rate schedule classes keyed by the `learning_rate.type` config
# string. The 'constant' type is handled directly by the factory and is
# intentionally absent here.
LR_CLS = {
    'stepwise': lr_schedule.PiecewiseConstantDecayWithOffset,
    'polynomial': lr_schedule.PolynomialDecayWithOffset,
    'exponential': lr_schedule.ExponentialDecayWithOffset,
    'cosine': lr_schedule.CosineDecayWithOffset,
    'power': lr_schedule.DirectPowerDecay,
    'power_linear': lr_schedule.PowerAndLinearDecay,
    'power_with_offset': lr_schedule.PowerDecayWithOffset,
    'step_cosine_with_offset': lr_schedule.StepCosineDecayWithOffset,
}

# Warmup wrappers keyed by the `warmup.type` config string; each one is
# constructed with the base learning rate as its first argument.
WARMUP_CLS = {
    'linear': lr_schedule.LinearWarmup,
    'polynomial': lr_schedule.PolynomialWarmUp,
}


def register_optimizer_cls(key: str,
                           optimizer_config_cls: Type[Union[
                               tf.keras.optimizers.Optimizer,
                               tf.keras.optimizers.legacy.Optimizer,
                               tf.keras.optimizers.experimental.Optimizer
                           ]],
                           use_legacy_optimizer: bool = True):
  """Registers a custom optimizer class under `key`.

  Only one registry is updated: `LEGACY_OPTIMIZERS_CLS` when
  `use_legacy_optimizer` is True, otherwise `NEW_OPTIMIZERS_CLS`. The user will
  still need to subclass data classes in configs.optimization_config to be used
  with OptimizerFactory.

  Args:
    key: The string the optimizer class is registered under. It is looked up
      by `OptimizerFactory.build_optimizer` using the config's
      `optimizer.type`.
    optimizer_config_cls: The optimizer class (not an instance) to register.
      It is called with the optimizer config's fields as keyword arguments.
    use_legacy_optimizer: A boolean that indicates if using legacy optimizers.

  Raises:
    ValueError: If `key` is already registered in the selected registry.
  """
  # Pick the target registry once; the name is kept for the error message.
  if use_legacy_optimizer:
    registry, registry_name = LEGACY_OPTIMIZERS_CLS, 'LEGACY_OPTIMIZERS_CLS'
  else:
    registry, registry_name = NEW_OPTIMIZERS_CLS, 'NEW_OPTIMIZERS_CLS'

  if key in registry:
    raise ValueError('%s already registered in %s.' % (key, registry_name))
  registry[key] = optimizer_config_cls


class OptimizerFactory:
  """Optimizer factory class.

  This class builds learning rate and optimizer based on an optimization config.
  To use this class, you need to do the following:
  (1) Define optimization config, this includes optimizer, and learning rate
      schedule.
  (2) Initialize the class using the optimization config.
  (3) Build learning rate.
  (4) Build optimizer.

  This is a typical example for using this class:

  ```
  params = {
      'optimizer': {
          'type': 'sgd',
          'sgd': {'momentum': 0.9}
      },
      'learning_rate': {
          'type': 'stepwise',
          'stepwise': {'boundaries': [10000, 20000],
                       'values': [0.1, 0.01, 0.001]}
      },
      'warmup': {
          'type': 'linear',
          'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
      }
  }
  opt_config = OptimizationConfig(params)
  opt_factory = OptimizerFactory(opt_config)
  lr = opt_factory.build_learning_rate()
  optimizer = opt_factory.build_optimizer(lr)
  ```
  """

  def __init__(self, config: opt_cfg.OptimizationConfig):
    """Initializes OptimizerFactory.

    Args:
      config: OptimizationConfig instance holding the optimizer, learning
        rate, warmup and (optional) EMA sub-configs.

    Raises:
      ValueError: If no optimizer type or no learning rate type is specified.
    """
    self._config = config
    self._optimizer_config = config.optimizer.get()
    self._optimizer_type = config.optimizer.type

    # EMA wrapping is enabled simply by providing an `ema` sub-config.
    self._use_ema = config.ema is not None
    self._ema_config = config.ema

    if self._optimizer_config is None:
      raise ValueError('Optimizer type must be specified')

    self._lr_config = config.learning_rate.get()
    self._lr_type = config.learning_rate.type

    if self._lr_type is None:
      raise ValueError('Learning rate type must be specified')

    self._warmup_config = config.warmup.get()
    self._warmup_type = config.warmup.type

  def build_learning_rate(self):
    """Builds the learning rate.

    Builds the learning rate from config. The schedule class is chosen from
    `LR_CLS` by the learning rate type and constructed with the learning rate
    config's fields. If the learning rate type is 'constant', the plain value
    lr_config.learning_rate is used instead of a schedule. In both cases, when
    a warmup config is present the result is wrapped by the matching
    `WARMUP_CLS` entry.

    Returns:
      tf.keras.optimizers.schedules.LearningRateSchedule instance. If the
      learning rate type is 'constant' and no warmup is configured,
      lr_config.learning_rate is returned.
    """
    if self._lr_type == 'constant':
      lr = self._lr_config.learning_rate
    else:
      lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())

    if self._warmup_config:
      lr = WARMUP_CLS[self._warmup_type](lr, **self._warmup_config.as_dict())

    return lr

  @gin.configurable
  def build_optimizer(
      self,
      lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
      gradient_aggregator: Optional[Callable[
          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
                                                          tf.Tensor]]]] = None,
      gradient_transformers: Optional[List[Callable[
          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
                                                          tf.Tensor]]]]] = None,
      postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
                                       tf.keras.optimizers.Optimizer]] = None,
      use_legacy_optimizer: bool = True):
    """Builds the optimizer.

    Builds the optimizer from config. It takes the learning rate as input and
    builds the optimizer according to the optimizer config. Typically, the
    learning rate built using self.build_learning_rate() is passed as an
    argument to this method.

    Args:
      lr: A floating point value, or a
        tf.keras.optimizers.schedules.LearningRateSchedule instance.
      gradient_aggregator: Optional function to overwrite gradient aggregation.
      gradient_transformers: Optional list of functions to use to transform
        gradients before applying updates to Variables. The functions are
        applied after gradient_aggregator. The functions should accept and
        return a list of (gradient, variable) tuples. clipvalue, clipnorm,
        global_clipnorm should not be set when gradient_transformers is passed.
      postprocessor: An optional function for postprocessing the optimizer. It
        takes an optimizer and returns an optimizer.
      use_legacy_optimizer: A boolean that indicates if using legacy optimizers.

    Returns:
      `tf.keras.optimizers.legacy.Optimizer` or
      `tf.keras.optimizers.experimental.Optimizer` instance.

    Raises:
      KeyError: If the optimizer type is not registered for the selected
        (legacy or new) optimizer path.
      ValueError: If `use_legacy_optimizer` is False and the optimizer config
        has a `decay` field, or EMA is configured.
      TypeError: If the (post-processed) result is not an optimizer.
    """
    optimizer_dict = self._optimizer_config.as_dict()
    # Drop clipping options that are unset instead of passing an explicit
    # None to the optimizer constructor. Missing keys are tolerated so custom
    # optimizer configs without these fields still work.
    for clip_key in ('clipnorm', 'clipvalue', 'global_clipnorm'):
      if optimizer_dict.get(clip_key) is None:
        optimizer_dict.pop(clip_key, None)

    optimizer_dict['learning_rate'] = lr
    if gradient_aggregator is not None:
      optimizer_dict['gradient_aggregator'] = gradient_aggregator
    if gradient_transformers is not None:
      optimizer_dict['gradient_transformers'] = gradient_transformers

    if use_legacy_optimizer:
      optimizer_cls = LEGACY_OPTIMIZERS_CLS[self._optimizer_type]
    else:
      if 'decay' in optimizer_dict:
        raise ValueError(
            '`decay` is deprecated in new Keras optimizer, please reflect the '
            'decay logic in `lr` or set `use_legacy_optimizer=True` to use the '
            'legacy optimizer.')
      # Reject the EMA + new-optimizer combination before constructing an
      # optimizer that would only be thrown away.
      if self._use_ema:
        raise ValueError(
            'EMA can only work with the legacy optimizer, please set '
            '`use_legacy_optimizer=True`.')
      optimizer_cls = NEW_OPTIMIZERS_CLS[self._optimizer_type]
    optimizer = optimizer_cls(**optimizer_dict)

    if self._use_ema:
      optimizer = ema_optimizer.ExponentialMovingAverage(
          optimizer, **self._ema_config.as_dict())
    if postprocessor is not None:
      optimizer = postprocessor(optimizer)

    if isinstance(optimizer, tf.keras.optimizers.Optimizer):
      return optimizer
    # The following check makes sure the function won't break in older TF
    # version because of missing the experimental/legacy package.
    if hasattr(tf.keras.optimizers, 'experimental'):
      if isinstance(optimizer, tf.keras.optimizers.experimental.Optimizer):
        return optimizer
    if hasattr(tf.keras.optimizers, 'legacy'):
      if isinstance(optimizer, tf.keras.optimizers.legacy.Optimizer):
        return optimizer
    raise TypeError('OptimizerFactory.build_optimizer returning a '
                    'non-optimizer object: {}'.format(optimizer))