# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataclasses for optimizer configs."""
from typing import List, Optional

import dataclasses
from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class SGDConfig(base_config.Config):
  """Configuration for SGD optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.SGD.

  Attributes:
    name: name of the optimizer.
    decay: learning rate decay for SGD optimizer.
    nesterov: whether to apply Nesterov momentum.
    momentum: momentum value for SGD optimizer.
  """
  name: str = "SGD"
  decay: float = 0.0
  nesterov: bool = False
  momentum: float = 0.0


@dataclasses.dataclass
class RMSPropConfig(base_config.Config):
  """Configuration for RMSProp optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.RMSprop.

  Attributes:
    name: name of the optimizer.
    rho: discounting factor for the gradient history in RMSprop optimizer.
    momentum: momentum value for RMSprop optimizer.
    epsilon: small constant for RMSprop optimizer; helps with numerical
      stability.
    centered: whether to normalize gradients by an estimate of their variance
      (centered RMSprop) or by the uncentered second moment.
  """
  name: str = "RMSprop"
  rho: float = 0.9
  momentum: float = 0.0
  epsilon: float = 1e-7
  centered: bool = False


@dataclasses.dataclass
class AdamConfig(base_config.Config):
  """Configuration for Adam optimizer.

  The attributes for this class match the arguments of
  tf.keras.optimizers.Adam.

  Attributes:
    name: name of the optimizer.
    beta_1: exponential decay rate for the 1st order moment estimates.
    beta_2: exponential decay rate for the 2nd order moment estimates.
    epsilon: small constant used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False


@dataclasses.dataclass
class AdamWeightDecayConfig(base_config.Config):
  """Configuration for Adam optimizer with weight decay.

  Attributes:
    name: name of the optimizer.
    beta_1: exponential decay rate for the 1st order moment estimates.
    beta_2: exponential decay rate for the 2nd order moment estimates.
    epsilon: small constant used for numerical stability in the optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    include_in_weight_decay: list[str], or None. List of weight names to
      include in weight decay.
    exclude_from_weight_decay: list[str], or None. List of weight names to
      exclude from weight decay.
    gradient_clip_norm: float. Clip norm applied to gradients; presumably a
      non-positive value disables clipping — TODO confirm against the
      optimizer implementation consuming this config.
  """
  name: str = "AdamWeightDecay"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  weight_decay_rate: float = 0.0
  include_in_weight_decay: Optional[List[str]] = None
  exclude_from_weight_decay: Optional[List[str]] = None
  gradient_clip_norm: float = 1.0


@dataclasses.dataclass
class LAMBConfig(base_config.Config):
  """Configuration for LAMB optimizer.

  The attributes for this class match the arguments of
  tensorflow_addons.optimizers.LAMB.

  Attributes:
    name: name of the optimizer.
    beta_1: exponential decay rate for the 1st order moment estimates.
    beta_2: exponential decay rate for the 2nd order moment estimates.
    epsilon: small constant used for numerical stability in LAMB optimizer.
    weight_decay_rate: float. Weight decay rate. Defaults to 0.
    exclude_from_weight_decay: List of regex patterns of variables excluded
      from weight decay. Variables whose name contain a substring matching the
      pattern will be excluded.
    exclude_from_layer_adaptation: List of regex patterns of variables excluded
      from layer adaptation. Variables whose name contain a substring matching
      the pattern will be excluded.
  """
  name: str = "LAMB"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-6
  weight_decay_rate: float = 0.0
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None


@dataclasses.dataclass
class EMAConfig(base_config.Config):
  """Exponential moving average optimizer config.

  Attributes:
    name: 'str', name of the optimizer.
    average_decay: 'float', decay rate used when averaging model weights.
    start_step: 'int', step at which to start applying the moving average.
    dynamic_decay: 'bool', whether to ramp the decay rate dynamically
      (e.g. warm it up from the start step) instead of using a fixed value —
      exact schedule is defined by the consuming optimizer; TODO confirm.
  """
  name: str = "ExponentialMovingAverage"
  average_decay: float = 0.99
  start_step: int = 0
  dynamic_decay: bool = True