Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c50daa27
Commit
c50daa27
authored
Oct 12, 2021
by
Frederick Liu
Committed by
A. Unique TensorFlower
Oct 12, 2021
Browse files
[optim] Expose registry for users.
PiperOrigin-RevId: 402689877
parent
66e25b31
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
0 deletions
+29
-0
official/modeling/optimization/__init__.py
official/modeling/optimization/__init__.py
+1
-0
official/modeling/optimization/optimizer_factory.py
official/modeling/optimization/optimizer_factory.py
+16
-0
official/modeling/optimization/optimizer_factory_test.py
official/modeling/optimization/optimizer_factory_test.py
+12
-0
No files found.
official/modeling/optimization/__init__.py
View file @
c50daa27
...
...
@@ -21,3 +21,4 @@ from official.modeling.optimization.configs.optimizer_config import *
from
official.modeling.optimization.ema_optimizer
import
ExponentialMovingAverage
from
official.modeling.optimization.lr_schedule
import
*
from
official.modeling.optimization.optimizer_factory
import
OptimizerFactory
from
official.modeling.optimization.optimizer_factory
import
register_optimizer_cls
official/modeling/optimization/optimizer_factory.py
View file @
c50daa27
...
...
@@ -56,6 +56,22 @@ WARMUP_CLS = {
}
def
register_optimizer_cls
(
key
:
str
,
optimizer_config_cls
:
tf
.
keras
.
optimizers
.
Optimizer
):
"""Register customize optimizer cls.
The user will still need to subclass data classes in
configs.optimization_config to be used with OptimizerFactory.
Args:
key: A string to that the optimizer_config_cls is registered with.
optimizer_config_cls: A class which inherits tf.keras.optimizers.Optimizer.
"""
if
key
in
OPTIMIZERS_CLS
:
raise
ValueError
(
'%s already registered in OPTIMIZER_CLS.'
%
key
)
OPTIMIZERS_CLS
[
key
]
=
optimizer_config_cls
class
OptimizerFactory
:
"""Optimizer factory class.
...
...
official/modeling/optimization/optimizer_factory_test.py
View file @
c50daa27
...
...
@@ -427,5 +427,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
for
step
,
value
in
expected_lr_step_values
:
self
.
assertAlmostEqual
(
lr
(
step
).
numpy
(),
value
)
class
OptimizerFactoryRegistryTest
(
tf
.
test
.
TestCase
):
def
test_registry
(
self
):
class
MyClass
():
pass
optimizer_factory
.
register_optimizer_cls
(
'test'
,
MyClass
)
self
.
assertIn
(
'test'
,
optimizer_factory
.
OPTIMIZERS_CLS
)
with
self
.
assertRaisesRegex
(
ValueError
,
'test already registered.*'
):
optimizer_factory
.
register_optimizer_cls
(
'test'
,
MyClass
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment