# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataclasses for optimizer configs."""
import dataclasses
from typing import List, Optional

from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class BaseOptimizerConfig(base_config.Config):
  """Base optimizer config.

  Attributes:
    clipnorm: float >= 0 or None. If not None, gradients will be clipped when
      their L2 norm exceeds this value.
    clipvalue: float >= 0 or None. If not None, gradients will be clipped when
      their absolute value exceeds this value.
    global_clipnorm: float >= 0 or None. If not None, gradients of all weights
      are clipped so that their global norm is no higher than this value.
  """
  clipnorm: Optional[float] = None
  clipvalue: Optional[float] = None
  global_clipnorm: Optional[float] = None
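

# Example usage (illustrative sketch, assuming the standard dataclass
# construction provided by `base_config.Config`): clipping options can be set
# as keyword arguments and read back as attributes.
_example_base = BaseOptimizerConfig(clipnorm=1.0)
assert _example_base.clipnorm == 1.0 and _example_base.clipvalue is None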


@dataclasses.dataclass
class SGDConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.SGD.

  Attributes:
    name: name of the optimizer.
    decay: decay rate for SGD optimizer.
    nesterov: nesterov for SGD optimizer.
    momentum: momentum for SGD optimizer.
  """
  name: str = "SGD"
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0
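

# Example usage (illustrative sketch): a typical momentum-SGD setup. These
# fields mirror `tf.keras.optimizers.SGD` keyword arguments; feeding them to
# the optimizer is assumed to be handled by the optimizer factory in this
# package.
_example_sgd = SGDConfig(momentum=0.9, nesterov=True)
assert _example_sgd.name == "SGD" and _example_sgd.decay == 0.0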


# TODO(b/216129465): Merge this config with SGDConfig after the experimental
# optimizer graduates.
@dataclasses.dataclass
class SGDExperimentalConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes for this class match the arguments of
  `tf.keras.optimizers.experimental.SGD`.

  Attributes:
    name: name of the optimizer.
    nesterov: nesterov for SGD optimizer.
    momentum: momentum for SGD optimizer.
    jit_compile: if True, jit compile will be used.
  """
  name: str = "SGD"
  nesterov: bool = False
  momentum: float = 0.0
  jit_compile: bool = False


@dataclasses.dataclass
class RMSPropConfig(BaseOptimizerConfig):
  """Configuration for RMSProp optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    rho: discounting factor for RMSprop optimizer.
    momentum: momentum for RMSprop optimizer.
    epsilon: epsilon value for RMSprop optimizer, helps with numerical
      stability.
    centered: Whether to normalize gradients or not.
  """
  name: str = "RMSprop"
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False


@dataclasses.dataclass
class AdagradConfig(BaseOptimizerConfig):
  """Configuration for Adagrad optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.Adagrad.

  Attributes:
    name: name of the optimizer.
    initial_accumulator_value: A floating point value. Starting value for the
      accumulators, must be non-negative.
    epsilon: A small floating point value to avoid zero denominator.
  """
  name: str = "Adagrad"
  initial_accumulator_value: float = 0.1
  epsilon: float = 1e-07


@dataclasses.dataclass
class AdamConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.Adam.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
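

# Example usage (illustrative sketch): overriding only epsilon, as is common
# for BERT-style training; all other fields keep their dataclass defaults.
_example_adam = AdamConfig(epsilon=1e-6)
assert _example_adam.beta_1 == 0.9 and not _example_adam.amsgrad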


@dataclasses.dataclass
class AdamWeightDecayConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    include_in_weight_decay: list[str], or None. List of weight names to include
      in weight decay.
    exclude_from_weight_decay: list[str], or None. List of weight names to
      exclude from weight decay.
    gradient_clip_norm: A positive float. Clips the gradients to this maximum
      L2-norm. Defaults to 1.0.
  """
  name: str = "AdamWeightDecay"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
  gradient_clip_norm: float = 1.0
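

# Example usage (illustrative sketch): decoupled weight decay with LayerNorm
# and bias variables excluded, a common transformer recipe. The exact
# name-matching semantics are defined by the optimizer implementation, not by
# this config.
_example_adamw = AdamWeightDecayConfig(
    weight_decay_rate=0.01,
    exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
assert _example_adamw.gradient_clip_norm == 1.0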


@dataclasses.dataclass
class LAMBConfig(BaseOptimizerConfig):
  """Configuration for LAMB optimizer.

  The attributes for this class match the arguments of
  tensorflow_addons.optimizers.LAMB.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in LAMB optimizer.
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    exclude_from_weight_decay: List of regex patterns of variables excluded
      from weight decay. Variables whose names contain a substring matching the
      pattern will be excluded.
    exclude_from_layer_adaptation: List of regex patterns of variables excluded
      from layer adaptation. Variables whose names contain a substring matching
      the pattern will be excluded.
  """
  name: str = "LAMB"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None


@dataclasses.dataclass
class EMAConfig(BaseOptimizerConfig):
  """Exponential moving average optimizer config.

  Attributes:
    name: 'str', name of the optimizer.
    trainable_weights_only: 'bool', if True, only model trainable weights will
      be updated. Otherwise, all model weights will be updated. This mainly
      affects batch normalization parameters.
    average_decay: 'float', average decay value.
    start_step: 'int', start step to apply moving average.
    dynamic_decay: 'bool', whether to apply dynamic decay or not.
  """
  name: str = "ExponentialMovingAverage"
  trainable_weights_only: bool = True
  average_decay: float = 0.99
  start_step: int = 0
  dynamic_decay: bool = True
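

# Example usage (illustrative sketch): start averaging after 1000 steps. With
# `dynamic_decay=True`, the effective decay is typically ramped up with the
# step count, e.g. min(average_decay, (1 + t) / (10 + t)); the exact schedule
# is an assumption here and is determined by the EMA optimizer implementation.
_example_ema = EMAConfig(average_decay=0.9999, start_step=1000)
assert _example_ema.trainable_weights_only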


@dataclasses.dataclass
class LARSConfig(BaseOptimizerConfig):
  """Layer-wise adaptive rate scaling config.

  Attributes:
    name: 'str', name of the optimizer.
    momentum: `float` hyperparameter >= 0 that accelerates gradient descent in
      the relevant direction and dampens oscillations. Defaults to 0.9.
    eeta: `float` LARS coefficient as used in the paper. Defaults to the LARS
      coefficient from the paper. (eeta / weight_decay) determines the highest
      scaling factor in LARS.
    weight_decay_rate: `float` for weight decay.
    nesterov: 'boolean' for whether to use nesterov momentum.
    classic_momentum: `boolean` for whether to use classic (or popular)
      momentum. The learning rate is applied during momentum update in classic
      momentum, but after momentum for popular momentum.
    exclude_from_weight_decay: A list of `string` for variable screening; if
      any of the strings appears in a variable's name, the variable will be
      excluded from weight decay. For example, one could specify the list as
      ['batch_normalization', 'bias'] to exclude BN and bias from weight decay.
    exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
      for layer adaptation. If it is None, it defaults to the same list as
      exclude_from_weight_decay.
  """
  name: str = "LARS"
  momentum: float = 0.9
  eeta: float = 0.001
  weight_decay_rate: float = 0.0
  nesterov: bool = False
  classic_momentum: bool = True
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
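

# Example usage (illustrative sketch): as the docstring notes, eeta /
# weight_decay determines the highest layer-wise scaling factor, so a non-zero
# weight decay keeps that bound finite; BN and bias variables are typically
# excluded from decay.
_example_lars = LARSConfig(
    weight_decay_rate=1e-4,
    exclude_from_weight_decay=["batch_normalization", "bias"])
assert _example_lars.eeta == 0.001 and _example_lars.classic_momentum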


@dataclasses.dataclass
class SLIDEConfig(BaseOptimizerConfig):
  """Configuration for SLIDE optimizer.

  Details coming soon.
  """
  name: str = "SLIDE"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  weight_decay_type: str = "inner"
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
  include_in_sparse_layer_adaptation: Optional[List[str]] = None
  sparse_layer_learning_rate: float = 0.1
  do_gradient_rescaling: bool = True
  norm_type: str = "layer"
  ratio_clip_norm: float = 1e5


@dataclasses.dataclass
class AdafactorConfig(BaseOptimizerConfig):
  """Configuration for Adafactor optimizer.

  The attributes for this class match the arguments of the Adafactor
  implementation.
  """
  name: str = "Adafactor"
  factored: bool = True
  multiply_by_parameter_scale: bool = True
  beta1: Optional[float] = None
  decay_rate: float = 0.8
  step_offset: int = 0
  clipping_threshold: float = 1.0
  min_dim_size_to_factor: int = 128
  epsilon1: float = 1e-30
  epsilon2: float = 1e-3