Commit 68e5b49c authored by Chen Qian, committed by A. Unique TensorFlower


Add an argument `use_legacy_optimizer` to build_optimizer to explicitly control which optimizer to return.

This is to stay compatible with an upcoming Keras optimizer migration.
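
A minimal usage sketch of the new flag, mirroring the tests added in this change (the SGD/constant-LR config below is only illustrative):

from official.modeling.optimization import optimizer_factory
from official.modeling.optimization.configs import optimization_config

# Illustrative config: constant learning rate with a plain SGD optimizer.
params = {
    'optimizer': {'type': 'sgd'},
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.1}},
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()

# Default (legacy) path: returns a tf.keras.optimizers.legacy.SGD instance.
legacy_opt = opt_factory.build_optimizer(lr, use_legacy_optimizer=True)

# New path: returns a tf.keras.optimizers.experimental.SGD instance.
new_opt = opt_factory.build_optimizer(lr, use_legacy_optimizer=False)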

PiperOrigin-RevId: 471276487
parent 56c8a503
@@ -27,21 +27,34 @@ from official.modeling.optimization import legacy_adamw
from official.modeling.optimization import lr_schedule
from official.modeling.optimization.configs import optimization_config as opt_cfg
OPTIMIZERS_CLS = {
'sgd': tf.keras.optimizers.SGD,
# Optimizer CLS to be used in both legacy and new path.
SHARED_OPTIMIZERS = {
# TODO(chenmoneygithub): experimental.SGD
'adam': tf.keras.optimizers.Adam,
# TODO(chenmoneygithub): experimental.Adam
'adamw': legacy_adamw.AdamWeightDecay,
'adamw_experimental': tf.keras.optimizers.experimental.AdamW,
'lamb': tfa_optimizers.LAMB,
'rmsprop': tf.keras.optimizers.RMSprop,
'lars': lars_optimizer.LARS,
'adagrad': tf.keras.optimizers.Adagrad,
'slide': slide_optimizer.SLIDE,
'adafactor': adafactor_optimizer.Adafactor,
}
LEGACY_OPTIMIZERS_CLS = {
'sgd': tf.keras.optimizers.legacy.SGD,
'adam': tf.keras.optimizers.legacy.Adam,
'rmsprop': tf.keras.optimizers.legacy.RMSprop,
'adagrad': tf.keras.optimizers.legacy.Adagrad,
}
LEGACY_OPTIMIZERS_CLS.update(SHARED_OPTIMIZERS)
NEW_OPTIMIZERS_CLS = {
'sgd': tf.keras.optimizers.experimental.SGD,
'adam': tf.keras.optimizers.experimental.Adam,
'rmsprop': tf.keras.optimizers.experimental.RMSprop,
'adagrad': tf.keras.optimizers.experimental.Adagrad,
}
NEW_OPTIMIZERS_CLS.update(SHARED_OPTIMIZERS)
LR_CLS = {
'stepwise': lr_schedule.PiecewiseConstantDecayWithOffset,
'polynomial': lr_schedule.PolynomialDecayWithOffset,
@@ -60,7 +73,12 @@ WARMUP_CLS = {
def register_optimizer_cls(key: str,
optimizer_config_cls: tf.keras.optimizers.Optimizer):
optimizer_config_cls: Union[
tf.keras.optimizers.Optimizer,
tf.keras.optimizers.legacy.Optimizer,
tf.keras.optimizers.experimental.Optimizer
],
use_legacy_optimizer: bool = True):
"""Register customize optimizer cls.
The user will still need to subclass data classes in
@@ -69,10 +87,16 @@ def register_optimizer_cls(key: str,
Args:
key: A string to that the optimizer_config_cls is registered with.
optimizer_config_cls: A class which inherits tf.keras.optimizers.Optimizer.
use_legacy_optimizer: A boolean that indicates if using legacy optimizers.
"""
if key in OPTIMIZERS_CLS:
raise ValueError('%s already registered in OPTIMIZER_CLS.' % key)
OPTIMIZERS_CLS[key] = optimizer_config_cls
if use_legacy_optimizer:
if key in LEGACY_OPTIMIZERS_CLS:
raise ValueError('%s already registered in LEGACY_OPTIMIZERS_CLS.' % key)
LEGACY_OPTIMIZERS_CLS[key] = optimizer_config_cls
else:
if key in NEW_OPTIMIZERS_CLS:
raise ValueError('%s already registered in NEW_OPTIMIZERS_CLS.' % key)
NEW_OPTIMIZERS_CLS[key] = optimizer_config_cls
class OptimizerFactory:
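
For reference, a hedged sketch of how a custom optimizer would now be registered under each path; MySGD, MyExperimentalSGD, and the 'my_sgd'/'my_exp_sgd' keys are hypothetical names used only for illustration:

import tensorflow as tf
from official.modeling.optimization import optimizer_factory

class MySGD(tf.keras.optimizers.legacy.SGD):  # hypothetical custom optimizer
  pass

class MyExperimentalSGD(tf.keras.optimizers.experimental.SGD):  # hypothetical
  pass

# Default: registers into LEGACY_OPTIMIZERS_CLS.
optimizer_factory.register_optimizer_cls('my_sgd', MySGD)

# Registers into NEW_OPTIMIZERS_CLS instead.
optimizer_factory.register_optimizer_cls(
    'my_exp_sgd', MyExperimentalSGD, use_legacy_optimizer=False)

# Re-registering an existing key in either map raises ValueError.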
@@ -168,7 +192,8 @@ class OptimizerFactory:
[List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
tf.Tensor]]]]] = None,
postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
tf.keras.optimizers.Optimizer]] = None):
tf.keras.optimizers.Optimizer]] = None,
use_legacy_optimizer: bool = True):
"""Build optimizer.
Builds optimizer from config. It takes learning rate as input, and builds
@@ -186,9 +211,10 @@
global_clipnorm should not be set when gradient_transformers is passed.
postprocessor: An optional function for postprocessing the optimizer. It
takes an optimizer and returns an optimizer.
use_legacy_optimizer: A boolean that indicates if using legacy optimizers.
Returns:
`tf.keras.optimizers.Optimizer` or
`tf.keras.optimizers.legacy.Optimizer` or
`tf.keras.optimizers.experimental.Optimizer` instance.
"""
@@ -207,9 +233,16 @@
if gradient_transformers is not None:
optimizer_dict['gradient_transformers'] = gradient_transformers
optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
if use_legacy_optimizer:
optimizer = LEGACY_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
else:
optimizer = NEW_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
if self._use_ema:
if not use_legacy_optimizer:
raise ValueError(
'EMA can only work with the legacy optimizer, please set '
'`use_legacy_optimizer=True`.')
optimizer = ema_optimizer.ExponentialMovingAverage(
optimizer, **self._ema_config.as_dict())
if postprocessor:
......
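
The EMA guard above keeps EMA configs on the legacy path. A short sketch of the expected behavior, assuming the optimization config exposes an `ema` section with `average_decay` (these field names are an assumption, not confirmed by this diff):

# Continuing the illustrative setup from the sketch above.
params = {
    'optimizer': {'type': 'sgd'},
    'ema': {'average_decay': 0.999},  # assumed EMA config fields
    'learning_rate': {'type': 'constant', 'constant': {'learning_rate': 0.1}},
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()

# Legacy path: the optimizer is wrapped in ema_optimizer.ExponentialMovingAverage.
ema_opt = opt_factory.build_optimizer(lr, use_legacy_optimizer=True)

# New path: raises ValueError ('EMA can only work with the legacy optimizer ...').
# opt_factory.build_optimizer(lr, use_legacy_optimizer=False)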
@@ -37,7 +37,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
}
}
}
optimizer_cls = optimizer_factory.OPTIMIZERS_CLS[optimizer_type]
optimizer_cls = optimizer_factory.LEGACY_OPTIMIZERS_CLS[optimizer_type]
expected_optimizer_config = optimizer_cls().get_config()
expected_optimizer_config['learning_rate'] = 0.1
@@ -49,6 +49,33 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
@parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'),
('lars'), ('adagrad'))
def test_new_optimizers(self, optimizer_type):
params = {
'optimizer': {
'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}
}
optimizer_cls = optimizer_factory.NEW_OPTIMIZERS_CLS[optimizer_type]
expected_optimizer_config = optimizer_cls().get_config()
expected_optimizer_config['learning_rate'] = 0.1
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(
lr, postprocessor=lambda x: x, use_legacy_optimizer=False)
self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
def test_gradient_aggregator(self):
params = {
'optimizer': {
@@ -491,7 +518,7 @@ class OptimizerFactoryRegistryTest(tf.test.TestCase):
pass
optimizer_factory.register_optimizer_cls('test', MyClass)
self.assertIn('test', optimizer_factory.OPTIMIZERS_CLS)
self.assertIn('test', optimizer_factory.LEGACY_OPTIMIZERS_CLS)
with self.assertRaisesRegex(ValueError, 'test already registered.*'):
optimizer_factory.register_optimizer_cls('test', MyClass)
......