# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataclasses for optimizer configs."""
from typing import List, Optional

import dataclasses
from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class BaseOptimizerConfig(base_config.Config):
  """Base optimizer config.

  Attributes:
    clipnorm: float >= 0 or None. If not None, gradients will be clipped when
      their L2 norm exceeds this value.
    clipvalue: float >= 0 or None. If not None, gradients will be clipped when
      their absolute value exceeds this value.
  """
  clipnorm: Optional[float] = None
  clipvalue: Optional[float] = None
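
# Illustrative sketch (not part of this module): configs inherit dict-style
# helpers from the hyperparams base class, so fields such as `clipnorm` can be
# set from an override dict. The `override()` and `as_dict()` methods are
# assumed to come from base_config.Config.
#
#   config = BaseOptimizerConfig()
#   config.override({"clipnorm": 1.0})
#   assert config.as_dict() == {"clipnorm": 1.0, "clipvalue": None}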


@dataclasses.dataclass
class SGDConfig(BaseOptimizerConfig):
  """Configuration for SGD optimizer.

  The attributes of this class match the arguments of tf.keras.optimizers.SGD.

  Attributes:
    name: name of the optimizer.
    decay: decay rate for SGD optimizer.
    nesterov: whether to apply Nesterov momentum.
    momentum: momentum for SGD optimizer.
  """
  name: str = "SGD"
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0
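
# Illustrative sketch (not part of this module): mapping an SGDConfig onto
# tf.keras.optimizers.SGD. The learning rate below is a placeholder; learning
# rate schedules are configured elsewhere.
#
#   import tensorflow as tf
#
#   sgd = SGDConfig(momentum=0.9, nesterov=True)
#   optimizer = tf.keras.optimizers.SGD(
#       learning_rate=0.1,
#       momentum=sgd.momentum,
#       nesterov=sgd.nesterov,
#       name=sgd.name)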


@dataclasses.dataclass
class RMSPropConfig(BaseOptimizerConfig):
  """Configuration for RMSProp optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    rho: discounting factor for RMSprop optimizer.
    momentum: momentum for RMSprop optimizer.
    epsilon: epsilon value for the RMSprop optimizer; helps with numerical
      stability.
    centered: whether to normalize gradients by their estimated variance
      (centered RMSprop) rather than the uncentered second moment.
  """
  name: str = "RMSprop"
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False
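
# Illustrative sketch (not part of this module): an RMSPropConfig maps
# one-to-one onto tf.keras.optimizers.RMSprop.
#
#   import tensorflow as tf
#
#   rmsprop = RMSPropConfig(momentum=0.9)
#   optimizer = tf.keras.optimizers.RMSprop(
#       learning_rate=0.001,
#       rho=rmsprop.rho,
#       momentum=rmsprop.momentum,
#       epsilon=rmsprop.epsilon,
#       centered=rmsprop.centered)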


@dataclasses.dataclass
class AdamConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.Adam.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
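
# Illustrative sketch (not part of this module): the inherited `clipnorm` /
# `clipvalue` fields map onto the standard gradient-clipping kwargs accepted
# by tf.keras optimizers.
#
#   import tensorflow as tf
#
#   adam = AdamConfig(clipnorm=1.0)
#   optimizer = tf.keras.optimizers.Adam(
#       learning_rate=1e-3,
#       beta_1=adam.beta_1,
#       beta_2=adam.beta_2,
#       epsilon=adam.epsilon,
#       amsgrad=adam.amsgrad,
#       clipnorm=adam.clipnorm)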


@dataclasses.dataclass
class AdamWeightDecayConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    include_in_weight_decay: list[str], or None. List of weight names to
      include in weight decay.
    exclude_from_weight_decay: list[str], or None. List of weight names to
      exclude from weight decay.
    gradient_clip_norm: float. Global norm used to clip gradients. Defaults
      to 1.0.
  """
  name: str = "AdamWeightDecay"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
  gradient_clip_norm: float = 1.0
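
# Illustrative sketch: this config is meant for an Adam-with-weight-decay
# implementation such as official.nlp.optimization.AdamWeightDecay (an
# assumption; that optimizer is not imported by this module). The string
# lists select variables by name.
#
#   from official.nlp import optimization as nlp_optimization
#
#   awd = AdamWeightDecayConfig(
#       weight_decay_rate=0.01,
#       exclude_from_weight_decay=["LayerNorm", "bias"])
#   optimizer = nlp_optimization.AdamWeightDecay(
#       learning_rate=1e-4,
#       weight_decay_rate=awd.weight_decay_rate,
#       exclude_from_weight_decay=awd.exclude_from_weight_decay,
#       gradient_clip_norm=awd.gradient_clip_norm)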


@dataclasses.dataclass
class LAMBConfig(BaseOptimizerConfig):
  """Configuration for LAMB optimizer.

  The attributes of this class match the arguments of
  tensorflow_addons.optimizers.LAMB.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in LAMB optimizer.
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    exclude_from_weight_decay: List of regex patterns of variables excluded
      from weight decay. Variables whose names contain a substring matching
      the pattern will be excluded.
    exclude_from_layer_adaptation: List of regex patterns of variables
      excluded from layer adaptation. Variables whose names contain a
      substring matching the pattern will be excluded.
  """
  name: str = "LAMB"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
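
# Illustrative sketch (not part of this module): a LAMBConfig maps onto
# tensorflow_addons.optimizers.LAMB, named in the docstring above.
#
#   import tensorflow_addons as tfa
#
#   lamb = LAMBConfig(weight_decay_rate=0.01)
#   optimizer = tfa.optimizers.LAMB(
#       learning_rate=1e-3,
#       beta_1=lamb.beta_1,
#       beta_2=lamb.beta_2,
#       epsilon=lamb.epsilon,
#       weight_decay_rate=lamb.weight_decay_rate,
#       exclude_from_weight_decay=lamb.exclude_from_weight_decay,
#       exclude_from_layer_adaptation=lamb.exclude_from_layer_adaptation)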


@dataclasses.dataclass
class EMAConfig(BaseOptimizerConfig):
  """Exponential moving average optimizer config.

  Attributes:
    name: 'str', name of the optimizer.
    average_decay: 'float', decay rate for the exponential moving average.
    start_step: 'int', start step to apply moving average.
    dynamic_decay: 'bool', whether to adjust the decay rate dynamically during
      early training steps.
  """
  name: str = "ExponentialMovingAverage"
  average_decay: float = 0.99
  start_step: int = 0
  dynamic_decay: bool = True
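
# Illustrative sketch: an EMAConfig parameterizes a moving-average wrapper
# around another optimizer, e.g. an ExponentialMovingAverage optimizer like
# the one in official/modeling/optimization (an assumption; not imported
# here). `base_optimizer` stands for any already-built tf.keras optimizer.
#
#   ema = EMAConfig(average_decay=0.999)
#   optimizer = ema_optimizer.ExponentialMovingAverage(
#       optimizer=base_optimizer,
#       average_decay=ema.average_decay,
#       start_step=ema.start_step,
#       dynamic_decay=ema.dynamic_decay)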