# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimizer factory class."""
from typing import Callable, List, Optional, Tuple, Type, Union

import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers

from official.modeling.optimization import adafactor_optimizer
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import lars_optimizer
from official.modeling.optimization import lr_schedule
from official.modeling.optimization import slide_optimizer
from official.modeling.optimization.configs import optimization_config as opt_cfg
from official.nlp import optimization as nlp_optimization

# Optimizer registry: maps the optimizer config `type` string to the class that
# OptimizerFactory.build_optimizer instantiates with the config's fields as
# keyword arguments. Further entries can be added with register_optimizer_cls().
OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.SGD,
    'adam': tf.keras.optimizers.Adam,
    'adamw': nlp_optimization.AdamWeightDecay,
    'lamb': tfa_optimizers.LAMB,
    'rmsprop': tf.keras.optimizers.RMSprop,
    'lars': lars_optimizer.LARS,
    'adagrad': tf.keras.optimizers.Adagrad,
    'slide': slide_optimizer.SLIDE,
    'adafactor': adafactor_optimizer.Adafactor,
}

# Learning-rate schedule registry: maps the learning_rate config `type` string
# to a schedule class built from the config's fields as keyword arguments.
# The 'constant' type has no entry: build_learning_rate uses the configured
# value directly instead of looking it up here.
LR_CLS = {
    'stepwise': lr_schedule.PiecewiseConstantDecayWithOffset,
    'polynomial': lr_schedule.PolynomialDecayWithOffset,
    'exponential': lr_schedule.ExponentialDecayWithOffset,
    'cosine': lr_schedule.CosineDecayWithOffset,
    'power': lr_schedule.DirectPowerDecay,
    'power_linear': lr_schedule.PowerAndLinearDecay,
    'power_with_offset': lr_schedule.PowerDecayWithOffset,
    # Class name is spelled exactly as exported by lr_schedule.
    'step_cosine_with_offset': lr_schedule.StepConsineDecayWithOffset,
}

# Warmup registry: maps the warmup config `type` string to a wrapper class that
# build_learning_rate constructs with the base learning rate as its first
# positional argument and the warmup config's fields as keyword arguments.
WARMUP_CLS = {
    'linear': lr_schedule.LinearWarmup,
    'polynomial': lr_schedule.PolynomialWarmUp,
}


def register_optimizer_cls(
    key: str, optimizer_config_cls: Type[tf.keras.optimizers.Optimizer]):
  """Registers a custom optimizer class.

  The user will still need to subclass data classes in
  configs.optimization_config to be used with OptimizerFactory.

  Args:
    key: The string under which `optimizer_config_cls` is registered. It is
      looked up in OPTIMIZERS_CLS using the optimizer config's `type` when
      OptimizerFactory builds the optimizer.
    optimizer_config_cls: A class (not an instance) which inherits
      tf.keras.optimizers.Optimizer.

  Raises:
    ValueError: If `key` is already registered in OPTIMIZERS_CLS.
  """
  # Refuse to silently shadow a built-in or previously registered optimizer.
  if key in OPTIMIZERS_CLS:
    raise ValueError('%s already registered in OPTIMIZERS_CLS.' % key)
  OPTIMIZERS_CLS[key] = optimizer_config_cls


class OptimizerFactory:
  """Optimizer factory class.

  This class builds learning rate and optimizer based on an optimization config.
  To use this class, you need to do the following:
  (1) Define optimization config, this includes optimizer, and learning rate
      schedule.
  (2) Initialize the class using the optimization config.
  (3) Build learning rate.
  (4) Build optimizer.

  This is a typical example for using this class:
  params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {'boundaries': [10000, 20000],
                         'values': [0.1, 0.01, 0.001]}
        },
        'warmup': {
            'type': 'linear',
            'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
        }
    }
  opt_config = OptimizationConfig(params)
  opt_factory = OptimizerFactory(opt_config)
  lr = opt_factory.build_learning_rate()
  optimizer = opt_factory.build_optimizer(lr)
  """

  def __init__(self, config: opt_cfg.OptimizationConfig):
    """Initializes OptimizerFactory.

    Args:
      config: OptimizationConfig instance containing the optimization config.

    Raises:
      ValueError: If no optimizer type or no learning rate type is specified
        in `config`.
    """
    self._config = config
    self._optimizer_config = config.optimizer.get()
    self._optimizer_type = config.optimizer.type

    # EMA wrapping is switched on simply by providing an `ema` sub-config.
    self._use_ema = config.ema is not None
    self._ema_config = config.ema

    if self._optimizer_config is None:
      raise ValueError('Optimizer type must be specified')

    self._lr_config = config.learning_rate.get()
    self._lr_type = config.learning_rate.type

    if self._lr_type is None:
      raise ValueError('Learning rate type must be specified')

    # Warmup is optional; build_learning_rate only applies it when this
    # config is truthy.
    self._warmup_config = config.warmup.get()
    self._warmup_type = config.warmup.type

  def build_learning_rate(self):
    """Builds learning rate.

    Builds learning rate from config. Learning rate schedule is built according
    to the learning rate config: for type 'constant', lr_config.learning_rate
    is used as-is; otherwise the schedule class registered in LR_CLS is
    constructed from the config. If a warmup config is set, the result is
    wrapped by the warmup class registered in WARMUP_CLS.

    Returns:
      tf.keras.optimizers.schedules.LearningRateSchedule instance. If
      learning rate type is constant and no warmup is configured,
      lr_config.learning_rate is returned.
    """
    if self._lr_type == 'constant':
      lr = self._lr_config.learning_rate
    else:
      lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())

    if self._warmup_config:
      # The warmup class takes the base learning rate (value or schedule) as
      # its first positional argument.
      lr = WARMUP_CLS[self._warmup_type](lr, **self._warmup_config.as_dict())

    return lr

  @gin.configurable
  def build_optimizer(
      self,
      lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
      gradient_transformers: Optional[List[Callable[
          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor, tf.Tensor]]
      ]]] = None,
      postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
                                       tf.keras.optimizers.Optimizer]] = None):
    """Builds optimizer.

    Builds optimizer from config. It takes learning rate as input, and builds
    the optimizer according to the optimizer config. Typically, the learning
    rate built using self.build_learning_rate() is passed as an argument to
    this method.

    Args:
      lr: A floating point value, or a
        tf.keras.optimizers.schedules.LearningRateSchedule instance.
      gradient_transformers: Optional list of functions to use to transform
        gradients before applying updates to Variables. The functions are
        applied after gradient_aggregator. The functions should accept and
        return a list of (gradient, variable) tuples. clipvalue, clipnorm,
        global_clipnorm should not be set when gradient_transformers is passed.
      postprocessor: An optional function for postprocessing the optimizer. It
        takes an optimizer and returns an optimizer.

    Returns:
      tf.keras.optimizers.Optimizer instance.

    Raises:
      AssertionError: If the final object (e.g. the postprocessor's output) is
        not a tf.keras.optimizers.Optimizer (only when assertions are enabled).
    """
    optimizer_dict = self._optimizer_config.as_dict()
    # Drop clipping options that are unset (None) so that the optimizer
    # constructor only receives the clipping arguments that were configured.
    for clip_key in ('clipnorm', 'clipvalue', 'global_clipnorm'):
      if optimizer_dict[clip_key] is None:
        del optimizer_dict[clip_key]

    optimizer_dict['learning_rate'] = lr
    if gradient_transformers is not None:
      optimizer_dict['gradient_transformers'] = gradient_transformers

    optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)

    if self._use_ema:
      optimizer = ema_optimizer.ExponentialMovingAverage(
          optimizer, **self._ema_config.as_dict())
    if postprocessor:
      optimizer = postprocessor(optimizer)
    assert isinstance(optimizer, tf.keras.optimizers.Optimizer), (
        'OptimizerFactory.build_optimizer returning a non-optimizer object: '
        '{}'.format(optimizer))

    return optimizer