"vscode:/vscode.git/clone" did not exist on "af5647748a3046467e5b65e839d27393b04274d3"
optimizer_config.py 10.6 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataclasses for optimizer configs."""
from typing import List, Optional

import dataclasses
from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class BaseOptimizerConfig(base_config.Config):
  """Base optimizer config.

  Attributes:
    clipnorm: float >= 0 or None. If not None, gradients will be clipped when
      their L2 norm exceeds this value.
    clipvalue: float >= 0 or None. If not None, gradients will be clipped when
      their absolute value exceeds this value.
    global_clipnorm: float >= 0 or None. If not None, gradients of all weights
      are clipped so that their global norm is no higher than this value.
  """
  clipnorm: Optional[float] = None
  clipvalue: Optional[float] = None
  global_clipnorm: Optional[float] = None
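
# Note (illustrative only): the clipping fields above (clipnorm, clipvalue,
# global_clipnorm) are inherited by every optimizer config below, so a config
# such as
#
#   AdamConfig(global_clipnorm=1.0)
#
# would request global-norm gradient clipping in addition to Adam's own
# hyperparameters. How the clipping is ultimately applied depends on the code
# that consumes these configs.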


@dataclasses.dataclass
class SGDConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.SGD.

  Attributes:
    name: name of the optimizer.
    decay: decay rate for SGD optimizer.
    nesterov: whether to apply Nesterov momentum.
    momentum: momentum for SGD optimizer.
  """
  name: str = "SGD"
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0
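
# Example usage (a minimal sketch; assumes the keyword construction and
# dict-override mechanics provided by `base_config.Config`):
#
#   sgd_cfg = SGDConfig(momentum=0.9, nesterov=True, clipnorm=1.0)
#   sgd_cfg.override({"momentum": 0.95})  # update fields from a plain dict
#   resolved = sgd_cfg.as_dict()          # inspect the resolved values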


# TODO(b/216129465): Merge this config with SGDConfig after the experimental
# optimizer graduates.
@dataclasses.dataclass
class SGDExperimentalConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes for this class match the arguments of
  `tf.keras.optimizers.experimental.SGD`.

  Attributes:
    name: name of the optimizer.
    nesterov: whether to apply Nesterov momentum.
    momentum: momentum for SGD optimizer.
    jit_compile: if True, jit compilation will be used.
  """
  name: str = "SGD"
  nesterov: bool = False
  momentum: float = 0.0
  jit_compile: bool = False


@dataclasses.dataclass
class RMSPropConfig(BaseOptimizerConfig):
  """Configuration for RMSProp optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    rho: discounting factor for RMSprop optimizer.
    momentum: momentum for RMSprop optimizer.
    epsilon: epsilon value for RMSprop optimizer; helps with numerical
      stability.
    centered: Whether to normalize gradients or not.
  """
  name: str = "RMSprop"
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False


@dataclasses.dataclass
class AdagradConfig(BaseOptimizerConfig):
  """Configuration for Adagrad optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.Adagrad.

  Attributes:
    name: name of the optimizer.
    initial_accumulator_value: A floating point value. Starting value for the
      accumulators, must be non-negative.
    epsilon: A small floating point value to avoid zero denominator.
  """
  name: str = "Adagrad"
  initial_accumulator_value: float = 0.1
  epsilon: float = 1e-07


@dataclasses.dataclass
class AdamConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.Adam.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
      the paper "On the Convergence of Adam and beyond".
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
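
# Example (illustrative sketch): the fields mirror `tf.keras.optimizers.Adam`,
# so a transformer-style setup might look like
#
#   adam_cfg = AdamConfig(beta_1=0.9, beta_2=0.98, epsilon=1e-6)
#
# The config is normally consumed by an optimizer factory elsewhere in this
# package rather than being passed to the Keras optimizer directly.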


@dataclasses.dataclass
class AdamExperimentalConfig(BaseOptimizerConfig):
  """Configuration for experimental Adam optimizer.

  The attributes for this class match the arguments of
  `tf.keras.optimizers.experimental.Adam`.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
      the paper "On the Convergence of Adam and beyond".
    jit_compile: if True, jit compilation will be used.
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  jit_compile: bool = False


@dataclasses.dataclass
class AdamWeightDecayConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
      the paper "On the Convergence of Adam and beyond".
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    include_in_weight_decay: list[str], or None. List of weight names to include
      in weight decay.
    exclude_from_weight_decay: list[str], or None. List of weight names to not
      include in weight decay.
    gradient_clip_norm: A positive float. Clips the gradients to this maximum
      L2-norm. Defaults to 1.0.
  """
  name: str = "AdamWeightDecay"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
  gradient_clip_norm: float = 1.0
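
# Example (illustrative sketch): weight decay is often restricted to a subset
# of variables by name substring, e.g. excluding normalization and bias terms
# as in BERT-style training:
#
#   adamw_cfg = AdamWeightDecayConfig(
#       weight_decay_rate=0.01,
#       exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])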


@dataclasses.dataclass
class LAMBConfig(BaseOptimizerConfig):
  """Configuration for LAMB optimizer.

  The attributes for this class match the arguments of
  tensorflow_addons.optimizers.LAMB.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in LAMB optimizer.
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    exclude_from_weight_decay: List of regex patterns of variables excluded from
      weight decay. Variables whose names contain a substring matching the
      pattern will be excluded.
    exclude_from_layer_adaptation: List of regex patterns of variables excluded
      from layer adaptation. Variables whose names contain a substring matching
      the pattern will be excluded.
  """
  name: str = "LAMB"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
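
# Example (illustrative sketch): LAMB is typically configured with weight decay
# on all variables except normalization and bias terms:
#
#   lamb_cfg = LAMBConfig(
#       weight_decay_rate=0.01,
#       exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])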


@dataclasses.dataclass
class EMAConfig(BaseOptimizerConfig):
  """Exponential moving average optimizer config.

  Attributes:
    name: 'str', name of the optimizer.
    trainable_weights_only: 'bool', if True, only model trainable weights will
      be updated. Otherwise, all model weights will be updated. This mainly
      affects batch normalization parameters.
    average_decay: 'float', average decay value.
    start_step: 'int', start step to apply moving average.
    dynamic_decay: 'bool', whether to apply dynamic decay or not.
  """
  name: str = "ExponentialMovingAverage"
  trainable_weights_only: bool = True
  average_decay: float = 0.99
  start_step: int = 0
  dynamic_decay: bool = True
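
# Example (illustrative sketch): an EMAConfig is meant to be paired with a base
# optimizer config so that an exponential moving average of the model weights
# is maintained during training, e.g.
#
#   ema_cfg = EMAConfig(average_decay=0.9999, start_step=1000)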


@dataclasses.dataclass
class LARSConfig(BaseOptimizerConfig):
  """Layer-wise adaptive rate scaling config.

  Attributes:
    name: 'str', name of the optimizer.
    momentum: `float` hyperparameter >= 0 that accelerates gradient descent in
      the relevant direction and dampens oscillations. Defaults to 0.9.
    eeta: `float` LARS coefficient as used in the paper. Defaults to the LARS
      coefficient from the paper. (eeta / weight_decay) determines the highest
      scaling factor in LARS.
    weight_decay_rate: `float` for weight decay.
    nesterov: 'boolean' for whether to use nesterov momentum.
    classic_momentum: `boolean` for whether to use classic (or popular)
      momentum. The learning rate is applied during the momentum update for
      classic momentum, but after the momentum update for popular momentum.
    exclude_from_weight_decay: A list of `string` for variable screening. If
      any of the strings appears in a variable's name, the variable will be
      excluded when computing weight decay. For example, one could specify the
      list as ['batch_normalization', 'bias'] to exclude BN and bias from
      weight decay.
    exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
      for layer adaptation. If None, it defaults to the same list as
      exclude_from_weight_decay.
  """
  name: str = "LARS"
  momentum: float = 0.9
  eeta: float = 0.001
  weight_decay_rate: float = 0.0
  nesterov: bool = False
  classic_momentum: bool = True
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
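
# Example (illustrative sketch): LARS as commonly used for large-batch
# training, with batch-normalization and bias variables excluded from weight
# decay (and, by default, from layer adaptation as well):
#
#   lars_cfg = LARSConfig(
#       momentum=0.9,
#       weight_decay_rate=1e-4,
#       exclude_from_weight_decay=["batch_normalization", "bias"])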


@dataclasses.dataclass
class SLIDEConfig(BaseOptimizerConfig):
  """Configuration for SLIDE optimizer.

  Details coming soon.
  """
  name: str = "SLIDE"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  weight_decay_type: str = "inner"
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
  include_in_sparse_layer_adaptation: Optional[List[str]] = None
  sparse_layer_learning_rate: float = 0.1
  do_gradient_rescaling: bool = True
  norm_type: str = "layer"
  ratio_clip_norm: float = 1e5


@dataclasses.dataclass
class AdafactorConfig(BaseOptimizerConfig):
  """Configuration for Adafactor optimizer.

  The attributes for this class match the arguments of the Adafactor
  implementation.
  """
  name: str = "Adafactor"
  factored: bool = True
  multiply_by_parameter_scale: bool = True
  beta1: Optional[float] = None
  decay_rate: float = 0.8
  step_offset: int = 0
  clipping_threshold: float = 1.0
  min_dim_size_to_factor: int = 128
  epsilon1: float = 1e-30
  epsilon2: float = 1e-3
  weight_decay: Optional[float] = None
  include_in_weight_decay: Optional[str] = None