Commit fdbf4946 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[optim] Expose registry for users.

PiperOrigin-RevId: 402689877
parent eb0ed8a1
......@@ -21,3 +21,4 @@ from official.modeling.optimization.configs.optimizer_config import *
from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
from official.modeling.optimization.lr_schedule import *
from official.modeling.optimization.optimizer_factory import OptimizerFactory
from official.modeling.optimization.optimizer_factory import register_optimizer_cls
......@@ -56,6 +56,22 @@ WARMUP_CLS = {
}
def register_optimizer_cls(
key: str, optimizer_config_cls: tf.keras.optimizers.Optimizer):
"""Register customize optimizer cls.
The user will still need to subclass data classes in
configs.optimization_config to be used with OptimizerFactory.
Args:
key: A string to that the optimizer_config_cls is registered with.
optimizer_config_cls: A class which inherits tf.keras.optimizers.Optimizer.
"""
if key in OPTIMIZERS_CLS:
raise ValueError('%s already registered in OPTIMIZER_CLS.' % key)
OPTIMIZERS_CLS[key] = optimizer_config_cls
class OptimizerFactory:
"""Optimizer factory class.
......
......@@ -427,5 +427,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
for step, value in expected_lr_step_values:
self.assertAlmostEqual(lr(step).numpy(), value)
class OptimizerFactoryRegistryTest(tf.test.TestCase):
def test_registry(self):
class MyClass():
pass
optimizer_factory.register_optimizer_cls('test', MyClass)
self.assertIn('test', optimizer_factory.OPTIMIZERS_CLS)
with self.assertRaisesRegex(ValueError, 'test already registered.*'):
optimizer_factory.register_optimizer_cls('test', MyClass)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment