"vscode:/vscode.git/clone" did not exist on "912c44faa9b6a11fe60f1f33c665287078f0531e"
Commit 56761bcb authored by Jaehong Kim, committed by A. Unique TensorFlower

Add AdamW experimental optimizer.

PiperOrigin-RevId: 452047178
parent 9985d123
@@ -51,6 +51,8 @@ class OptimizerConfig(oneof.OneOfConfig):
   adam_experimental: opt_cfg.AdamExperimentalConfig = (
       opt_cfg.AdamExperimentalConfig())
   adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
+  adamw_experimental: opt_cfg.AdamWeightDecayExperimentalConfig = (
+      opt_cfg.AdamWeightDecayExperimentalConfig())
   lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
   rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
   lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
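With the new field registered on the oneof config above, the experimental optimizer is selected by setting the optimizer type to 'adamw_experimental'. The following is a minimal usage sketch, assuming the official.modeling.optimization import path and the dict-override construction style used in the Model Garden; the learning-rate settings are illustrative only.

# Hedged sketch: assumed import path and illustrative values.
from official.modeling import optimization

opt_config = optimization.OptimizationConfig({
    'optimizer': {
        'type': 'adamw_experimental',  # picks AdamWeightDecayExperimentalConfig
        'adamw_experimental': {
            'weight_decay': 0.01,
            'global_clipnorm': 1.0,
        },
    },
    'learning_rate': {
        'type': 'constant',
        'constant': {'learning_rate': 1e-3},
    },
})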
@@ -190,6 +190,32 @@ class AdamWeightDecayConfig(BaseOptimizerConfig):
   gradient_clip_norm: float = 1.0


+@dataclasses.dataclass
+class AdamWeightDecayExperimentalConfig(BaseOptimizerConfig):
+  """Configuration for Adam optimizer with weight decay.
+
+  Attributes:
+    name: name of the optimizer.
+    beta_1: decay rate for 1st order moments.
+    beta_2: decay rate for 2nd order moments.
+    epsilon: epsilon value used for numerical stability in the optimizer.
+    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
+      from the paper "On the Convergence of Adam and Beyond".
+    weight_decay: float. Weight decay rate. Defaults to 0.
+    global_clipnorm: A positive float. Clips the gradients to this maximum
+      L2-norm. Defaults to 1.0.
+    jit_compile: if True, jit compile will be used.
+  """
+  name: str = "AdamWeightDecayExperimental"
+  beta_1: float = 0.9
+  beta_2: float = 0.999
+  epsilon: float = 1e-07
+  amsgrad: bool = False
+  weight_decay: float = 0.0
+  global_clipnorm: float = 1.0
+  jit_compile: bool = False
+
+
 @dataclasses.dataclass
 class LAMBConfig(BaseOptimizerConfig):
   """Configuration for LAMB optimizer.
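Each field of the new dataclass maps directly onto a constructor argument of tf.keras.optimizers.experimental.AdamW (available since TF 2.9), so the config can be expanded into the Keras optimizer as in the sketch below; the learning rate is an assumed placeholder normally supplied by the LR schedule.

# Hedged sketch: field-to-argument mapping; learning_rate is illustrative.
import tensorflow as tf

cfg = AdamWeightDecayExperimentalConfig()
optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=1e-3,
    weight_decay=cfg.weight_decay,
    beta_1=cfg.beta_1,
    beta_2=cfg.beta_2,
    epsilon=cfg.epsilon,
    amsgrad=cfg.amsgrad,
    global_clipnorm=cfg.global_clipnorm,
    jit_compile=cfg.jit_compile,
    name=cfg.name,
)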
@@ -33,6 +33,7 @@ OPTIMIZERS_CLS = {
     'adam': tf.keras.optimizers.Adam,
     # TODO(chenmoneygithub): experimental.Adam
     'adamw': legacy_adamw.AdamWeightDecay,
+    'adamw_experimental': tf.keras.optimizers.experimental.AdamW,
     'lamb': tfa_optimizers.LAMB,
     'rmsprop': tf.keras.optimizers.RMSprop,
     'lars': lars_optimizer.LARS,
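With 'adamw_experimental' added to this registry, the existing optimizer factory can build the new Keras optimizer from the config selected earlier. A hedged sketch, assuming the Model Garden's OptimizerFactory API (build_learning_rate / build_optimizer):

# Hedged sketch: assumes the OptimizerFactory API; opt_config is the
# OptimizationConfig from the earlier example.
opt_factory = optimization.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)  # tf.keras.optimizers.experimental.AdamW instance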