Commit c50daa27 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[optim] Expose registry for users.

PiperOrigin-RevId: 402689877
parent 66e25b31
...@@ -21,3 +21,4 @@ from official.modeling.optimization.configs.optimizer_config import * ...@@ -21,3 +21,4 @@ from official.modeling.optimization.configs.optimizer_config import *
from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
from official.modeling.optimization.lr_schedule import * from official.modeling.optimization.lr_schedule import *
from official.modeling.optimization.optimizer_factory import OptimizerFactory from official.modeling.optimization.optimizer_factory import OptimizerFactory
from official.modeling.optimization.optimizer_factory import register_optimizer_cls
...@@ -56,6 +56,22 @@ WARMUP_CLS = { ...@@ -56,6 +56,22 @@ WARMUP_CLS = {
} }
def register_optimizer_cls(
key: str, optimizer_config_cls: tf.keras.optimizers.Optimizer):
"""Register customize optimizer cls.
The user will still need to subclass data classes in
configs.optimization_config to be used with OptimizerFactory.
Args:
key: A string to that the optimizer_config_cls is registered with.
optimizer_config_cls: A class which inherits tf.keras.optimizers.Optimizer.
"""
if key in OPTIMIZERS_CLS:
raise ValueError('%s already registered in OPTIMIZER_CLS.' % key)
OPTIMIZERS_CLS[key] = optimizer_config_cls
class OptimizerFactory: class OptimizerFactory:
"""Optimizer factory class. """Optimizer factory class.
......
...@@ -427,5 +427,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -427,5 +427,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
for step, value in expected_lr_step_values: for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value) self.assertAlmostEqual(lr(step).numpy(), value)
class OptimizerFactoryRegistryTest(tf.test.TestCase):
def test_registry(self):
class MyClass():
pass
optimizer_factory.register_optimizer_cls('test', MyClass)
self.assertIn('test', optimizer_factory.OPTIMIZERS_CLS)
with self.assertRaisesRegex(ValueError, 'test already registered.*'):
optimizer_factory.register_optimizer_cls('test', MyClass)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment