Commit 831b3dfd authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 320495344
parent 5d09fb27
@@ -20,6 +20,20 @@ import dataclasses
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class ConstantLrConfig(base_config.Config):
"""Configuration for constant learning rate.
This class is a container for the constant learning rate config.
Attributes:
name: The name of the learning rate schedule. Defaults to Constant.
learning_rate: A float. The learning rate. Defaults to 0.1.
"""
name: str = 'Constant'
learning_rate: float = 0.1
@dataclasses.dataclass
class StepwiseLrConfig(base_config.Config):
"""Configuration for stepwise learning rate decay.
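For orientation, here is a minimal sketch of constructing the new ConstantLrConfig directly. The import path is an assumption (it mirrors the lr_cfg alias used by the oneof config below); the field names and defaults come from the dataclass above, and the dict-style override follows the same pattern the tests below use for OptimizationConfig(params).

from official.modeling.optimization.configs import learning_rate_config as lr_cfg  # assumed path

# Defaults declared by the dataclass above: name='Constant', learning_rate=0.1.
const_lr = lr_cfg.ConstantLrConfig()
assert const_lr.learning_rate == 0.1

# Overrides are passed as a dict, matching how the tests construct configs.
const_lr = lr_cfg.ConstantLrConfig({'learning_rate': 0.05})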
@@ -55,12 +55,14 @@ class LrConfig(oneof.OneOfConfig):
Attributes:
type: 'str', type of lr schedule to be used, one of the fields below.
constant: constant learning rate config.
stepwise: stepwise learning rate config.
exponential: exponential learning rate config.
polynomial: polynomial learning rate config.
cosine: cosine learning rate config.
"""
type: Optional[str] = None
constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
stepwise: lr_cfg.StepwiseLrConfig = lr_cfg.StepwiseLrConfig()
exponential: lr_cfg.ExponentialLrConfig = lr_cfg.ExponentialLrConfig()
polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
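Since LrConfig is a oneof config, the type string selects which of the fields above is active; a short sketch of how it is consumed follows (the get() accessor is the same one OptimizerFactory calls further below, and the dict-style construction is an assumption based on the tests):

lr_config = LrConfig({'type': 'constant'})
active = lr_config.get()       # the ConstantLrConfig instance selected by `type`
print(active.learning_rate)    # 0.1, the dataclass default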
@@ -28,13 +28,11 @@ class SGDConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for SGD optimizer.
decay: decay rate for SGD optimizer.
nesterov: whether to use Nesterov momentum in the SGD optimizer.
momentum: momentum for SGD optimizer.
"""
name: str = "SGD"
learning_rate: float = 0.01
decay: float = 0.0
nesterov: bool = False
momentum: float = 0.0
@@ -49,14 +47,12 @@ class RMSPropConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for RMSprop optimizer.
rho: discounting factor for RMSprop optimizer.
momentum: momentum for RMSprop optimizer.
epsilon: epsilon value for the RMSprop optimizer; helps with numerical stability.
centered: Whether to normalize gradients or not.
"""
name: str = "RMSprop"
learning_rate: float = 0.001
rho: float = 0.9
momentum: float = 0.0
epsilon: float = 1e-7
@@ -72,7 +68,6 @@ class AdamConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in Adam optimizer.
@@ -80,7 +75,6 @@ class AdamConfig(base_config.Config):
the paper "On the Convergence of Adam and Beyond".
"""
name: str = "Adam"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-07
@@ -93,7 +87,6 @@ class AdamWeightDecayConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for the optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in the optimizer.
@@ -106,7 +99,6 @@ class AdamWeightDecayConfig(base_config.Config):
include in weight decay.
"""
name: str = "AdamWeightDecay"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-07
@@ -125,7 +117,6 @@ class LAMBConfig(base_config.Config):
Attributes:
name: name of the optimizer.
learning_rate: learning_rate for Adam optimizer.
beta_1: decay rate for 1st order moments.
beta_2: decay rate for 2nd order moments.
epsilon: epsilon value used for numerical stability in LAMB optimizer.
@@ -139,7 +130,6 @@ class LAMBConfig(base_config.Config):
be excluded.
"""
name: str = "LAMB"
learning_rate: float = 0.001
beta_1: float = 0.9
beta_2: float = 0.999
epsilon: float = 1e-6
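With learning_rate removed from every optimizer config above, the rate is now supplied only through the learning_rate section of the optimization params, matching the updated docstrings and tests in this change:

params = {
    'optimizer': {
        'type': 'sgd',
        'sgd': {'momentum': 0.9}              # no 'learning_rate' key here any more
    },
    'learning_rate': {
        'type': 'constant',
        'constant': {'learning_rate': 0.1}
    }
}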
@@ -60,7 +60,7 @@ class OptimizerFactory(object):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
@@ -88,12 +88,15 @@ class OptimizerFactory(object):
self._optimizer_config = config.optimizer.get()
self._optimizer_type = config.optimizer.type
if self._optimizer_config is None:
if self._optimizer_type is None:
raise ValueError('Optimizer type must be specified')
self._lr_config = config.learning_rate.get()
self._lr_type = config.learning_rate.type
if self._lr_type is None:
raise ValueError('Learning rate type must be specified')
self._warmup_config = config.warmup.get()
self._warmup_type = config.warmup.type
@@ -101,18 +104,15 @@ class OptimizerFactory(object):
"""Build learning rate.
Builds learning rate from config. Learning rate schedule is built according
to the learning rate config. If there is no learning rate config, optimizer
learning rate is returned.
to the learning rate config. If learning rate type is constant,
lr_config.learning_rate is returned.
Returns:
tf.keras.optimizers.schedules.LearningRateSchedule instance. If no
learning rate schedule defined, optimizer_config.learning_rate is
returned.
tf.keras.optimizers.schedules.LearningRateSchedule instance. If
learning rate type is constant, lr_config.learning_rate is returned.
"""
# TODO(arashwan): Explore if we want to only allow explicit const lr sched.
if not self._lr_config:
lr = self._optimizer_config.learning_rate
if self._lr_type == 'constant':
lr = self._lr_config.learning_rate
else:
lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
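To make the new branch concrete, a small standalone sketch follows; LR_CLS is the module's mapping from schedule type names to tf.keras schedule classes, stubbed here for illustration only:

import tensorflow as tf

# Stub of the real LR_CLS mapping, for illustration only.
LR_CLS = {'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay}

def build_lr_sketch(lr_type, lr_config_dict):
    """Returns a plain float for 'constant', otherwise a LearningRateSchedule."""
    if lr_type == 'constant':
        return lr_config_dict['learning_rate']
    return LR_CLS[lr_type](**lr_config_dict)

print(build_lr_sketch('constant', {'learning_rate': 0.1}))            # 0.1
print(build_lr_sketch('stepwise', {'boundaries': [10000, 20000],
                                   'values': [0.1, 0.01, 0.001]}))    # schedule object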
@@ -35,10 +35,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}
}
optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
expected_optimizer_config = optimizer_cls().get_config()
expected_optimizer_config['learning_rate'] = 0.1
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
@@ -48,11 +55,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_missing_types(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'momentum': 0.9}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
params = {
'learning_rate': {
'type': 'stepwise',
'stepwise': {'boundaries': [10000, 20000],
'values': [0.1, 0.01, 0.001]}
}
}
with self.assertRaises(ValueError):
optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params))
def test_stepwise_lr_schedule(self):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
@@ -79,7 +107,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
@@ -112,7 +140,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'exponential',
@@ -142,7 +170,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'polynomial',
@@ -166,7 +194,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'cosine',
@@ -192,7 +220,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
},
'warmup': {
'type': 'linear',
@@ -216,7 +250,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
params = {
'optimizer': {
'type': 'sgd',
'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
'sgd': {'momentum': 0.9}
},
'learning_rate': {
'type': 'stepwise',
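Putting the pieces together, here is a hedged end-to-end sketch under the new config layout; the import paths and the build_learning_rate / build_optimizer method names are assumptions inferred from the factory docstring and tests above, not shown verbatim in this diff:

from official.modeling.optimization import optimizer_factory            # assumed path
from official.modeling.optimization.configs import optimization_config  # assumed path

params = {
    'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.1}},
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)

lr = opt_factory.build_learning_rate()        # assumed method; returns the float 0.1 here
optimizer = opt_factory.build_optimizer(lr)   # assumed method; tf.keras SGD with momentum=0.9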