"vscode:/vscode.git/clone" did not exist on "a8e3e285efe6cc7430059dfb97334a49426c8d54"
Commit 2560387f authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 342548143
parent e9057c4d
......@@ -14,9 +14,10 @@
# limitations under the License.
# ==============================================================================
"""Optimizer factory class."""
from typing import Union
from typing import Callable, Union
import gin
import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
......@@ -126,9 +127,12 @@ class OptimizerFactory(object):
return lr
@gin.configurable
def build_optimizer(
self, lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule,
float]):
self,
lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
postprocessor: Callable[[tf.keras.optimizers.Optimizer],
tf.keras.optimizers.Optimizer] = None):
"""Build optimizer.
Builds optimizer from config. It takes learning rate as input, and builds
......@@ -138,6 +142,8 @@ class OptimizerFactory(object):
Args:
lr: A floating point value, or a
tf.keras.optimizers.schedules.LearningRateSchedule instance.
postprocessor: An optional function for postprocessing the optimizer. It
takes an optimizer and returns an optimizer.
Returns:
tf.keras.optimizers.Optimizer instance.
......@@ -157,5 +163,10 @@ class OptimizerFactory(object):
if self._use_ema:
optimizer = ema_optimizer.ExponentialMovingAverage(
optimizer, **self._ema_config.as_dict())
if postprocessor:
optimizer = postprocessor(optimizer)
assert isinstance(optimizer, tf.keras.optimizers.Optimizer), (
'OptimizerFactory.build_optimizer returning a non-optimizer object: '
'{}'.format(optimizer))
return optimizer
......@@ -44,7 +44,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(lr)
optimizer = opt_factory.build_optimizer(lr, postprocessor=lambda x: x)
self.assertIsInstance(optimizer, optimizer_cls)
self.assertEqual(expected_optimizer_config, optimizer.get_config())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment