Commit 2560387f authored by Le Hou, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 342548143
parent e9057c4d
@@ -14,9 +14,10 @@
 # limitations under the License.
 # ==============================================================================
 """Optimizer factory class."""
-from typing import Union
+from typing import Callable, Union

+import gin
 import tensorflow as tf
 import tensorflow_addons.optimizers as tfa_optimizers
@@ -126,9 +127,12 @@ class OptimizerFactory(object):
     return lr

+  @gin.configurable
   def build_optimizer(
-      self, lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule,
-                      float]):
+      self,
+      lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
+      postprocessor: Callable[[tf.keras.optimizers.Optimizer],
+                              tf.keras.optimizers.Optimizer] = None):
     """Build optimizer.

     Builds optimizer from config. It takes learning rate as input, and builds
@@ -138,6 +142,8 @@ class OptimizerFactory(object):
     Args:
       lr: A floating point value, or a
         tf.keras.optimizers.schedules.LearningRateSchedule instance.
+      postprocessor: An optional function for postprocessing the optimizer. It
+        takes an optimizer and returns an optimizer.

     Returns:
       tf.keras.optimizers.Optimizer instance.
@@ -157,5 +163,10 @@ class OptimizerFactory(object):
     if self._use_ema:
       optimizer = ema_optimizer.ExponentialMovingAverage(
           optimizer, **self._ema_config.as_dict())
+    if postprocessor:
+      optimizer = postprocessor(optimizer)
+    assert isinstance(optimizer, tf.keras.optimizers.Optimizer), (
+        'OptimizerFactory.build_optimizer returning a non-optimizer object: '
+        '{}'.format(optimizer))
     return optimizer
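
For illustration, here is one way a caller might use the new hook, assuming opt_factory and lr are built as in the test change below. Wrapping the optimizer with tensorflow_addons' MovingAverage is a hypothetical choice for this sketch, not something this change prescribes:

import tensorflow_addons.optimizers as tfa_optimizers

def wrap_with_moving_average(optimizer):
  # Hypothetical postprocessor: receives the optimizer built from config and
  # must return a tf.keras.optimizers.Optimizer, or the new assert in
  # build_optimizer fires.
  return tfa_optimizers.MovingAverage(optimizer, average_decay=0.999)

optimizer = opt_factory.build_optimizer(
    lr, postprocessor=wrap_with_moving_average)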
@@ -44,7 +44,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()
-    optimizer = opt_factory.build_optimizer(lr)
+    optimizer = opt_factory.build_optimizer(lr, postprocessor=lambda x: x)
     self.assertIsInstance(optimizer, optimizer_cls)
     self.assertEqual(expected_optimizer_config, optimizer.get_config())
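
Because build_optimizer is now @gin.configurable, the postprocessor can also be injected from a gin config rather than passed explicitly. A minimal sketch, assuming gin registers the method under its bare function name build_optimizer; the no-op postprocessor here is hypothetical:

import gin

@gin.configurable
def identity_postprocessor(optimizer):
  # Hypothetical no-op postprocessor registered with gin.
  return optimizer

# @identity_postprocessor passes the function itself as the binding, which
# overrides the None default of the postprocessor parameter.
gin.parse_config("""
build_optimizer.postprocessor = @identity_postprocessor
""")

# Later calls pick up the binding without an explicit argument:
# optimizer = opt_factory.build_optimizer(lr)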