Commit 95d1b298 authored by Hao Wu's avatar Hao Wu Committed by A. Unique TensorFlower

Add adagrad optimizer support.

PiperOrigin-RevId: 370551493
parent 71c7b7f9
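
For context, a minimal sketch of how the new option is selected end to end through the config and factory touched below. The 'adagrad' type and the AdagradConfig fields come from this commit; the 'constant' learning-rate config and the build_optimizer call are assumed from the surrounding repo and are shown only for illustration:

from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

# Select the newly registered 'adagrad' oneof type; the nested dict mirrors
# the AdagradConfig fields added in this change.
params = {
    'optimizer': {
        'type': 'adagrad',
        'adagrad': {
            'initial_accumulator_value': 0.1,
            'epsilon': 1e-07,
        }
    },
    'learning_rate': {
        'type': 'constant',
        'constant': {
            'learning_rate': 0.1
        }
    }
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)  # -> tf.keras.optimizers.Adagrad instance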
@@ -40,6 +40,7 @@ class OptimizerConfig(oneof.OneOfConfig):
    lamb: lamb optimizer.
    rmsprop: rmsprop optimizer.
    lars: lars optimizer.
    adagrad: adagrad optimizer.
  """
  type: Optional[str] = None
  sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
@@ -48,6 +49,7 @@ class OptimizerConfig(oneof.OneOfConfig):
  lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
  rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
  lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
  adagrad: opt_cfg.AdagradConfig = opt_cfg.AdagradConfig()


@dataclasses.dataclass
@@ -99,8 +101,8 @@ class OptimizationConfig(base_config.Config):
  Attributes:
    optimizer: optimizer oneof config.
    ema: optional exponential moving average optimizer config; if specified,
      the ema optimizer will be used.
    learning_rate: learning rate oneof config.
    warmup: warmup oneof config.
  """
...
@@ -75,6 +75,24 @@ class RMSPropConfig(BaseOptimizerConfig):
  centered: bool = False
@dataclasses.dataclass
class AdagradConfig(BaseOptimizerConfig):
  """Configuration for Adagrad optimizer.

  The attributes of this class match the arguments of
  tf.keras.optimizers.Adagrad.

  Attributes:
    name: name of the optimizer.
    initial_accumulator_value: A floating point value. Starting value for the
      accumulators; must be non-negative.
    epsilon: A small floating point value to avoid a zero denominator.
  """
  name: str = "Adagrad"
  initial_accumulator_value: float = 0.1
  epsilon: float = 1e-07

@dataclasses.dataclass
class AdamConfig(BaseOptimizerConfig):
  """Configuration for Adam optimizer.
@@ -178,23 +196,22 @@ class LARSConfig(BaseOptimizerConfig):
  Attributes:
    name: 'str', name of the optimizer.
    momentum: `float` hyperparameter >= 0 that accelerates gradient descent in
      the relevant direction and dampens oscillations. Defaults to 0.9.
    eeta: `float` LARS coefficient as used in the paper. Defaults to the LARS
      coefficient from the paper. (eeta / weight_decay) determines the highest
      scaling factor in LARS.
    weight_decay_rate: `float` for weight decay.
    nesterov: 'boolean' for whether to use nesterov momentum.
    classic_momentum: `boolean` for whether to use classic (or popular)
      momentum. The learning rate is applied during the momentum update in
      classic momentum, but after the momentum update for popular momentum.
    exclude_from_weight_decay: A list of `string` for variable screening; if any
      of the strings appears in a variable's name, the variable will be excluded
      when computing weight decay. For example, one could specify
      ['batch_normalization', 'bias'] to exclude BN and bias from weight decay.
    exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but for
      layer adaptation. If it is None, it defaults to the same value as
      exclude_from_weight_decay.
  """
  name: str = "LARS"
...
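The AdagradConfig docstring above states that its attributes match the arguments of tf.keras.optimizers.Adagrad. A standalone sketch of that mapping (plain Keras, independent of the factory; the learning_rate value here is arbitrary and is normally supplied by the factory's learning-rate schedule):

import tensorflow as tf

# Each AdagradConfig field maps onto the corresponding Keras constructor argument.
optimizer = tf.keras.optimizers.Adagrad(
    learning_rate=0.1,              # provided separately via the LR schedule
    initial_accumulator_value=0.1,  # AdagradConfig.initial_accumulator_value
    epsilon=1e-07,                  # AdagradConfig.epsilon
    name='Adagrad')                 # AdagradConfig.name
print(optimizer.get_config())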
@@ -15,7 +15,6 @@
"""Optimizer factory class."""
from typing import Callable, Union

import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
@@ -33,6 +32,7 @@ OPTIMIZERS_CLS = {
    'lamb': tfa_optimizers.LAMB,
    'rmsprop': tf.keras.optimizers.RMSprop,
    'lars': lars_optimizer.LARS,
    'adagrad': tf.keras.optimizers.Adagrad,
}
LR_CLS = {
...
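The one-line registry addition above is what lets the factory resolve the new 'adagrad' type by dictionary lookup. A hedged sketch of that lookup pattern (the kwargs-unpacking step is illustrative, not the repo's exact factory code):

import tensorflow as tf

# A trimmed stand-in for the OPTIMIZERS_CLS registry shown in the diff.
OPTIMIZERS_CLS = {
    'rmsprop': tf.keras.optimizers.RMSprop,
    'adagrad': tf.keras.optimizers.Adagrad,
}

optimizer_type = 'adagrad'
optimizer_kwargs = {'initial_accumulator_value': 0.1, 'epsilon': 1e-07}
optimizer = OPTIMIZERS_CLS[optimizer_type](learning_rate=0.1, **optimizer_kwargs)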
@@ -23,9 +23,8 @@ from official.modeling.optimization.configs import optimization_config
class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'),
                            ('lars'), ('adagrad'))
  def test_optimizers(self, optimizer_type):
    params = {
        'optimizer': {
@@ -50,10 +49,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    self.assertIsInstance(optimizer, optimizer_cls)
    self.assertEqual(expected_optimizer_config, optimizer.get_config())
  @parameterized.parameters((None, None), (1.0, None), (None, 1.0))
  def test_gradient_clipping(self, clipnorm, clipvalue):
    params = {
        'optimizer': {
@@ -359,8 +355,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
            }
        }
    }
    expected_lr_step_values = [[0, 1.0], [1, 1.0], [40, 1. / 40.],
                               [60, 1. / 60. * 0.8]]
    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()
...