# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataclasses for optimizer configs."""
from typing import List, Optional

import dataclasses
from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class BaseOptimizerConfig(base_config.Config):
  """Base optimizer config.

  Attributes:
    clipnorm: float >= 0 or None. If not None, gradients are clipped when
      their L2 norm exceeds this value.
    clipvalue: float >= 0 or None. If not None, gradients are clipped when
      their absolute value exceeds this value.
    global_clipnorm: float >= 0 or None. If not None, the gradients of all
      weights are clipped so that their global norm is no higher than this
      value.
  """
  clipnorm: Optional[float] = None
  clipvalue: Optional[float] = None
  global_clipnorm: Optional[float] = None
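
# A minimal usage sketch (not part of the original file): the clipping fields
# above intentionally mirror keyword arguments of tf.keras optimizers, so any
# subclass config can be forwarded to one. `as_dict()` is assumed here from
# the Model Garden's base_config.Config; unset (None) fields are filtered out.
#
#   import tensorflow as tf
#   config = BaseOptimizerConfig(global_clipnorm=1.0)
#   kwargs = {k: v for k, v in config.as_dict().items() if v is not None}
#   optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, **kwargs)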


@dataclasses.dataclass
class SGDConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes for this class matches the arguments of tf.keras.optimizer.SGD.

  Attributes:
    name: name of the optimizer.
    decay: decay rate for SGD optimizer.
    nesterov: nesterov for SGD optimizer.
    momentum: momentum for SGD optimizer.
  """
  name: str = "SGD"
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0
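
# A hedged example (not part of the original file): because the field names
# match the tf.keras.optimizers.SGD constructor arguments, the config can be
# expanded straight into it. `as_dict()` is assumed from base_config.Config.
#
#   import tensorflow as tf
#   sgd_config = SGDConfig(momentum=0.9, nesterov=True)
#   params = {k: v for k, v in sgd_config.as_dict().items() if v is not None}
#   optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, **params)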


@dataclasses.dataclass
class RMSPropConfig(BaseOptimizerConfig):
  """Configuration for RMSProp optimizer.

  The attributes for this class matches the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    rho: discounting factor for RMSprop optimizer.
    momentum: momentum for RMSprop optimizer.
    epsilon: epsilon value for RMSprop optimizer, help with numerical stability.
    centered: Whether to normalize gradients or not.
  """
  name: str = "RMSprop"
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False
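
# A hedged sketch (not part of the original file): configs derived from
# base_config.Config can typically be overridden from a plain dict, e.g. one
# parsed from a YAML experiment file; `override()` is an assumption about the
# Model Garden's hyperparams API.
#
#   rmsprop_config = RMSPropConfig()
#   rmsprop_config.override({"momentum": 0.9, "centered": True})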


@dataclasses.dataclass
class AdamConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer.

  The attributes for this class matches the arguments of
  tf.keras.optimizer.Adam.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2st order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
Hongkun Yu's avatar
Hongkun Yu committed
92
      the paper "On the Convergence of Adam and beyond".
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
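
# A hedged example (not part of the original file): `replace()` is assumed
# from base_config.Config; it returns a modified copy, which keeps a shared
# base config unchanged while experiments vary individual fields.
#
#   adam_config = AdamConfig()
#   amsgrad_config = adam_config.replace(amsgrad=True)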


@dataclasses.dataclass
class AdamWeightDecayConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2st order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
Hongkun Yu's avatar
Hongkun Yu committed
111
      the paper "On the Convergence of Adam and beyond".
Abdullah Rashwan's avatar
Abdullah Rashwan committed
112
113
    weight_decay_rate: float. Weight decay rate. Default to 0.
    include_in_weight_decay: list[str], or None. List of weight names to include
Hongkun Yu's avatar
Hongkun Yu committed
114
      in weight decay.
115
    exclude_from_weight_decay: list[str], or None. List of weight names to not
Hongkun Yu's avatar
Hongkun Yu committed
116
      include in weight decay.
117
118
    gradient_clip_norm: A positive float. Clips the gradients to this maximum
      L2-norm. Default to 1.0.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
119
120
121
122
123
124
125
126
127
  """
  name: str = "AdamWeightDecay"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
  gradient_clip_norm: float = 1.0
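
# A hedged sketch (not part of the original file): the include/exclude lists
# are weight-name patterns consumed by a weight-decay-aware Adam
# implementation (which optimizer class consumes this config is an assumption
# here, not defined in this file). A common setup skips decay on bias and
# LayerNorm variables:
#
#   adamw_config = AdamWeightDecayConfig(
#       weight_decay_rate=0.01,
#       exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])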


@dataclasses.dataclass
class LAMBConfig(BaseOptimizerConfig):
  """Configuration for LAMB optimizer.

  The attributes for this class matches the arguments of
  tensorflow_addons.optimizers.LAMB.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2st order moments.
    epsilon: epsilon value used for numerical stability in LAMB optimizer.
    weight_decay_rate: float. Weight decay rate. Default to 0.
    exclude_from_weight_decay: List of regex patterns of variables excluded from
Hongkun Yu's avatar
Hongkun Yu committed
145
146
      weight decay. Variables whose name contain a substring matching the
      pattern will be excluded.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
147
    exclude_from_layer_adaptation: List of regex patterns of variables excluded
Hongkun Yu's avatar
Hongkun Yu committed
148
149
      from layer adaptation. Variables whose name contain a substring matching
      the pattern will be excluded.
Abdullah Rashwan's avatar
Abdullah Rashwan committed
150
151
152
153
154
155
156
157
  """
  name: str = "LAMB"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
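
# A hedged example (not part of the original file): the docstring above names
# tensorflow_addons.optimizers.LAMB, whose constructor arguments these fields
# mirror, so the config can be expanded into it. `as_dict()` is assumed from
# base_config.Config.
#
#   import tensorflow_addons as tfa
#   lamb_config = LAMBConfig(weight_decay_rate=0.01,
#                            exclude_from_weight_decay=["bias"])
#   params = {k: v for k, v in lamb_config.as_dict().items() if v is not None}
#   optimizer = tfa.optimizers.LAMB(learning_rate=1e-3, **params)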


@dataclasses.dataclass
class EMAConfig(BaseOptimizerConfig):
  """Exponential moving average optimizer config.

  Attributes:
    name: 'str', name of the optimizer.
    average_decay: 'float', average decay value.
    start_step: 'int', start step to apply moving average.
    dynamic_decay: 'bool', whether to apply dynamic decay or not.
  """
  name: str = "ExponentialMovingAverage"
  average_decay: float = 0.99
  start_step: int = 0
  dynamic_decay: bool = True
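
# A hedged note (not part of the original file): unlike the configs above,
# EMA does not map to a standalone tf.keras optimizer; it is normally
# consumed by an averaging wrapper around a base optimizer. The wrapper named
# below is an assumption about the Model Garden, not defined in this file.
#
#   ema_config = EMAConfig(average_decay=0.999, start_step=1000)
#   ema_optimizer = ExponentialMovingAverage(
#       optimizer=base_optimizer, **ema_config.as_dict())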