Commit 3d0e12fd authored by Chen Qian, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 428831391
parent f2f7e39c
@@ -48,6 +48,8 @@ class OptimizerConfig(oneof.OneOfConfig):
  sgd_experimental: opt_cfg.SGDExperimentalConfig = (
      opt_cfg.SGDExperimentalConfig())
  adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
  adam_experimental: opt_cfg.AdamExperimentalConfig = (
      opt_cfg.AdamExperimentalConfig())
  adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
  lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
  rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
...
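The new `adam_experimental` field slots into the optimizer oneof alongside the existing variants, so it is selected the same way: by setting the oneof's `type` key. A minimal sketch, assuming the usual Model Garden module path and the `override()`/`get()` helpers from the hyperparams `Config`/`OneOfConfig` base classes:

```python
# Sketch only: the module path and the override()/get() usage are assumptions
# based on how the other oneof optimizer variants are typically selected.
from official.modeling.optimization.configs import optimization_config

opt = optimization_config.OptimizerConfig()
opt.override({
    'type': 'adam_experimental',                 # choose the new oneof branch
    'adam_experimental': {'jit_compile': True},  # override fields of that branch
})
chosen = opt.get()  # resolves to the AdamExperimentalConfig instance
print(type(chosen).__name__, chosen.jit_compile)
```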
@@ -67,6 +67,7 @@ class SGDExperimentalConfig(BaseOptimizerConfig):
    name: name of the optimizer.
    nesterov: nesterov for SGD optimizer.
    momentum: momentum for SGD optimizer.
    jit_compile: if True, jit compilation will be used.
  """
  name: str = "SGD"
  nesterov: bool = False
@@ -135,6 +136,30 @@ class AdamConfig(BaseOptimizerConfig):
  amsgrad: bool = False


@dataclasses.dataclass
class AdamExperimentalConfig(BaseOptimizerConfig):
  """Configuration for experimental Adam optimizer.

  The attributes for this class match the arguments of
  `tf.keras.optimizers.experimental.Adam`.

  Attributes:
    name: name of the optimizer.
    beta_1: decay rate for 1st order moments.
    beta_2: decay rate for 2nd order moments.
    epsilon: epsilon value used for numerical stability in Adam optimizer.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
    jit_compile: if True, jit compilation will be used.
  """
  name: str = "Adam"
  beta_1: float = 0.9
  beta_2: float = 0.999
  epsilon: float = 1e-07
  amsgrad: bool = False
  jit_compile: bool = False


@dataclasses.dataclass
class AdamWeightDecayConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer with weight decay.
...
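Because the dataclass attributes mirror the constructor arguments of `tf.keras.optimizers.experimental.Adam`, the config's dict form can be unpacked straight into the optimizer, which is essentially what the factory registration below enables. A hedged sketch (the `as_dict()` helper comes from the Model Garden config base class; treat the exact call as an assumption):

```python
import tensorflow as tf
from official.modeling.optimization.configs import optimizer_config as opt_cfg

cfg = opt_cfg.AdamExperimentalConfig(beta_2=0.98, jit_compile=True)

# Attribute names line up with the Keras experimental Adam arguments (plus any
# clipping fields inherited from BaseOptimizerConfig, which the Keras optimizer
# also accepts), so the dict can be unpacked directly.
optimizer = tf.keras.optimizers.experimental.Adam(**cfg.as_dict())
print(optimizer.get_config()['beta_2'])  # 0.98
```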
@@ -30,6 +30,7 @@ OPTIMIZERS_CLS = {
    'sgd': tf.keras.optimizers.SGD,
    'sgd_experimental': tf.keras.optimizers.experimental.SGD,
    'adam': tf.keras.optimizers.Adam,
    'adam_experimental': tf.keras.optimizers.experimental.Adam,
    'adamw': nlp_optimization.AdamWeightDecay,
    'lamb': tfa_optimizers.LAMB,
    'rmsprop': tf.keras.optimizers.RMSprop,
...
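With the registry entry in place, an end-to-end `OptimizationConfig` that names `adam_experimental` should resolve to the Keras experimental Adam class when the optimizer is built. A rough usage sketch, assuming the package-level `OptimizationConfig`/`OptimizerFactory` exports and a constant learning-rate schedule:

```python
from official.modeling import optimization

opt_config = optimization.OptimizationConfig({
    'optimizer': {
        'type': 'adam_experimental',
        'adam_experimental': {'epsilon': 1e-7, 'jit_compile': True},
    },
    'learning_rate': {
        'type': 'constant',
        'constant': {'learning_rate': 1e-3},
    },
})

factory = optimization.OptimizerFactory(opt_config)
lr = factory.build_learning_rate()
optimizer = factory.build_optimizer(lr)
# Expected to be an instance of tf.keras.optimizers.experimental.Adam.
```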