"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c7617e482a522173ea6f922223aa010058552af8"
Commit 0a6f6426 authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 316962972
parent 57e7ca73
@@ -39,12 +39,14 @@ class OptimizerConfig(oneof.OneOfConfig):
     adam: adam optimizer config.
     adamw: adam with weight decay.
     lamb: lamb optimizer.
+    rmsprop: rmsprop optimizer.
   """
   type: Optional[str] = None
   sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
   adam: opt_cfg.AdamConfig = opt_cfg.AdamConfig()
   adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
   lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
+  rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()


 @dataclasses.dataclass
...
@@ -40,6 +40,29 @@ class SGDConfig(base_config.Config):
   momentum: float = 0.0


+@dataclasses.dataclass
+class RMSPropConfig(base_config.Config):
+  """Configuration for RMSProp optimizer.
+
+  The attributes for this class match the arguments of
+  tf.keras.optimizers.RMSprop.
+
+  Attributes:
+    name: name of the optimizer.
+    learning_rate: learning rate for the RMSprop optimizer.
+    rho: discounting factor for the RMSprop optimizer.
+    momentum: momentum for the RMSprop optimizer.
+    epsilon: epsilon value for the RMSprop optimizer; helps with numerical stability.
+    centered: whether to normalize gradients by their estimated variance.
+  """
+  name: str = "RMSprop"
+  learning_rate: float = 0.001
+  rho: float = 0.9
+  momentum: float = 0.0
+  epsilon: float = 1e-7
+  centered: bool = False
+
+
 @dataclasses.dataclass
 class AdamConfig(base_config.Config):
   """Configuration for Adam optimizer.
...
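For reference, a minimal sketch (not part of this commit) of how the new rmsprop branch could be selected and tuned through the same nested params dict used in the updated tests below; the field names come from RMSPropConfig above, and the 'optimizer'/'rmsprop' nesting mirrors the old SGD test.

# Illustrative only: override a few RMSPropConfig defaults via the params dict.
# Any field not listed keeps the dataclass default defined above.
params = {
    'optimizer': {
        'type': 'rmsprop',
        'rmsprop': {
            'learning_rate': 0.01,  # overrides the 0.001 default
            'rho': 0.95,            # discounting factor for the moving average
            'centered': True,       # normalize gradients by their estimated variance
        }
    }
}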
@@ -14,7 +14,6 @@
 # limitations under the License.
 # ==============================================================================
 """Optimizer factory class."""
-
 from typing import Union

 import tensorflow as tf
@@ -29,7 +28,8 @@ OPTIMIZERS_CLS = {
     'sgd': tf.keras.optimizers.SGD,
     'adam': tf.keras.optimizers.Adam,
     'adamw': nlp_optimization.AdamWeightDecay,
-    'lamb': tfa_optimizers.LAMB
+    'lamb': tfa_optimizers.LAMB,
+    'rmsprop': tf.keras.optimizers.RMSprop
 }

 LR_CLS = {
...
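With the factory change above, the 'rmsprop' type string now resolves to tf.keras.optimizers.RMSprop. A rough end-to-end sketch, following the pattern used by the parameterized test below (assumes the default learning-rate configuration):

import tensorflow as tf

from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

# Build an RMSprop optimizer entirely from config, as the test does.
opt_config = optimization_config.OptimizationConfig(
    {'optimizer': {'type': 'rmsprop'}})
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
assert isinstance(optimizer, tf.keras.optimizers.RMSprop)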
@@ -15,84 +15,37 @@
 # ==============================================================================
 """Tests for optimizer_factory.py."""
+from absl.testing import parameterized
 import tensorflow as tf
-import tensorflow_addons.optimizers as tfa_optimizers

 from official.modeling.optimization import optimizer_factory
 from official.modeling.optimization.configs import optimization_config
-from official.nlp import optimization as nlp_optimization


-class OptimizerFactoryTest(tf.test.TestCase):
-
-  def test_sgd_optimizer(self):
-    params = {
-        'optimizer': {
-            'type': 'sgd',
-            'sgd': {'learning_rate': 0.1, 'momentum': 0.9}
-        }
-    }
-    expected_optimizer_config = {
-        'name': 'SGD',
-        'learning_rate': 0.1,
-        'decay': 0.0,
-        'momentum': 0.9,
-        'nesterov': False
-    }
-    opt_config = optimization_config.OptimizationConfig(params)
-    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
-    lr = opt_factory.build_learning_rate()
-    optimizer = opt_factory.build_optimizer(lr)
-    self.assertIsInstance(optimizer, tf.keras.optimizers.SGD)
-    self.assertEqual(expected_optimizer_config, optimizer.get_config())
-
-  def test_adam_optimizer(self):
-    # Define adam optimizer with default values.
-    params = {
-        'optimizer': {
-            'type': 'adam'
-        }
-    }
-    expected_optimizer_config = tf.keras.optimizers.Adam().get_config()
-    opt_config = optimization_config.OptimizationConfig(params)
-    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
-    lr = opt_factory.build_learning_rate()
-    optimizer = opt_factory.build_optimizer(lr)
-    self.assertIsInstance(optimizer, tf.keras.optimizers.Adam)
-    self.assertEqual(expected_optimizer_config, optimizer.get_config())
-
-  def test_adam_weight_decay_optimizer(self):
+class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters(
+      ('sgd'),
+      ('rmsprop'),
+      ('adam'),
+      ('adamw'),
+      ('lamb'))
+  def test_optimizers(self, optimizer_type):
     params = {
         'optimizer': {
-            'type': 'adamw'
+            'type': optimizer_type
         }
     }
-    expected_optimizer_config = nlp_optimization.AdamWeightDecay().get_config()
-    opt_config = optimization_config.OptimizationConfig(params)
-    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
-    lr = opt_factory.build_learning_rate()
-    optimizer = opt_factory.build_optimizer(lr)
-    self.assertIsInstance(optimizer, nlp_optimization.AdamWeightDecay)
-    self.assertEqual(expected_optimizer_config, optimizer.get_config())
-
-  def test_lamb_optimizer(self):
-    params = {
-        'optimizer': {
-            'type': 'lamb'
-        }
-    }
-    expected_optimizer_config = tfa_optimizers.LAMB().get_config()
+    optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
+    expected_optimizer_config = optimizer_cls().get_config()
+
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()
     optimizer = opt_factory.build_optimizer(lr)
-    self.assertIsInstance(optimizer, tfa_optimizers.LAMB)
+    self.assertIsInstance(optimizer, optimizer_cls)
     self.assertEqual(expected_optimizer_config, optimizer.get_config())

   def test_stepwise_lr_schedule(self):
...
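The parameterized test compares the optimizer built from default configs against optimizer_cls().get_config(), so for the new entry it relies on RMSPropConfig's defaults matching Keras' own RMSprop defaults. A quick standalone check of that assumption (key names taken from the TF 2.x get_config() output):

import tensorflow as tf

# Sanity check: Keras RMSprop defaults line up with RMSPropConfig's defaults.
cfg = tf.keras.optimizers.RMSprop().get_config()
assert abs(cfg['learning_rate'] - 0.001) < 1e-12
assert cfg['rho'] == 0.9
assert cfg['momentum'] == 0.0
assert cfg['centered'] is False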