Commit 98f2d335 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 438823508
parent db3eead9
@@ -57,8 +57,8 @@ WARMUP_CLS = {
 }


-def register_optimizer_cls(
-    key: str, optimizer_config_cls: tf.keras.optimizers.Optimizer):
+def register_optimizer_cls(key: str,
+                           optimizer_config_cls: tf.keras.optimizers.Optimizer):
   """Register customize optimizer cls.

   The user will still need to subclass data classes in
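For context, a minimal sketch of how this registration hook is typically used, assuming the standard model-garden import path used by the tests in this change; the optimizer subclass and registry key here are made-up placeholders, not part of the commit:

import tensorflow as tf

from official.modeling.optimization import optimizer_factory


class MySGD(tf.keras.optimizers.SGD):
  """Made-up optimizer subclass, used only to illustrate registration."""
  pass


# Registers the class under a new key; registering an already-used key
# raises ValueError (see OptimizerFactoryRegistryTest below).
optimizer_factory.register_optimizer_cls('my_sgd', MySGD)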
@@ -156,9 +156,12 @@ class OptimizerFactory:
   def build_optimizer(
       self,
       lr: Union[tf.keras.optimizers.schedules.LearningRateSchedule, float],
+      gradient_aggregator: Optional[Callable[
+          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
+                                                          tf.Tensor]]]] = None,
       gradient_transformers: Optional[List[Callable[
-          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor, tf.Tensor]]
-      ]]] = None,
+          [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
+                                                          tf.Tensor]]]]] = None,
       postprocessor: Optional[Callable[[tf.keras.optimizers.Optimizer],
                                        tf.keras.optimizers.Optimizer]] = None):
     """Build optimizer.
@@ -170,6 +173,7 @@ class OptimizerFactory:
     Args:
       lr: A floating point value, or a
         tf.keras.optimizers.schedules.LearningRateSchedule instance.
+      gradient_aggregator: Optional function to overwrite gradient aggregation.
       gradient_transformers: Optional list of functions to use to transform
         gradients before applying updates to Variables. The functions are
         applied after gradient_aggregator. The functions should accept and
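Both hooks follow the signature spelled out in the annotations above: they take and return a list of (gradient, variable) tuples. A minimal sketch of conforming callables, with purely illustrative names that are not part of this change:

from typing import List, Tuple

import tensorflow as tf

GradsAndVars = List[Tuple[tf.Tensor, tf.Tensor]]


def passthrough_aggregator(grads_and_vars: GradsAndVars) -> GradsAndVars:
  # Aggregator sketch: leave gradients as-is; a real aggregator would combine
  # per-replica gradients before the transformers run.
  return list(grads_and_vars)


def halve_gradients(grads_and_vars: GradsAndVars) -> GradsAndVars:
  # Transformer sketch: scale every gradient by 0.5 before the update.
  return [(0.5 * g, v) for g, v in grads_and_vars]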
@@ -193,6 +197,8 @@ class OptimizerFactory:
         del optimizer_dict['global_clipnorm']

     optimizer_dict['learning_rate'] = lr

+    if gradient_aggregator is not None:
+      optimizer_dict['gradient_aggregator'] = gradient_aggregator
     if gradient_transformers is not None:
       optimizer_dict['gradient_transformers'] = gradient_transformers
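The dict assembled here is ultimately passed as keyword arguments to the registered Keras optimizer class, and in the TF2 optimizer API these hooks are ordinary constructor kwargs. So the wiring is roughly equivalent to the sketch below; the learning rate and 'adam' choice are placeholders:

import tensorflow as tf

# Rough equivalent of what the factory ends up constructing for 'adam'.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-3,            # the `lr` passed to build_optimizer
    gradient_aggregator=None,      # custom aggregation callable, if any
    gradient_transformers=None,    # list of transform callables, if any
)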
@@ -49,6 +49,39 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     self.assertIsInstance(optimizer, optimizer_cls)
     self.assertEqual(expected_optimizer_config, optimizer.get_config())

+  def test_gradient_aggregator(self):
+    params = {
+        'optimizer': {
+            'type': 'adam',
+        },
+        'learning_rate': {
+            'type': 'constant',
+            'constant': {
+                'learning_rate': 1.0
+            }
+        }
+    }
+    opt_config = optimization_config.OptimizationConfig(params)
+    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+    lr = opt_factory.build_learning_rate()
+
+    # Dummy function to zero out gradients.
+    zero_grads = lambda gv: [(tf.zeros_like(g), v) for g, v in gv]
+
+    optimizer = opt_factory.build_optimizer(lr, gradient_aggregator=zero_grads)
+
+    var0 = tf.Variable([1.0, 2.0])
+    var1 = tf.Variable([3.0, 4.0])
+    grads0 = tf.constant([1.0, 1.0])
+    grads1 = tf.constant([1.0, 1.0])
+    grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
+    optimizer.apply_gradients(grads_and_vars)
+
+    self.assertAllClose(np.array([1.0, 2.0]), var0.numpy())
+    self.assertAllClose(np.array([3.0, 4.0]), var1.numpy())
+
   @parameterized.parameters((None, None), (1.0, None), (None, 1.0))
   def test_gradient_clipping(self, clipnorm, clipvalue):
     params = {
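As a hypothetical companion to the new aggregator test above, the same factory call could wire the gradient_transformers argument instead, reusing opt_factory and lr as built in that test; the clip_by_norm threshold is arbitrary and this snippet is not part of the commit:

# Transformers run after aggregation and also take/return (grad, var) pairs.
clip_fn = lambda gv: [(tf.clip_by_norm(g, 1.0), v) for g, v in gv]
optimizer = opt_factory.build_optimizer(lr, gradient_transformers=[clip_fn])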
@@ -418,7 +451,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
         }
       }
     }
-    expected_lr_step_values = [[0, 0.0], [5000, 1e-4/2.0], [10000, 1e-4],
+    expected_lr_step_values = [[0, 0.0], [5000, 1e-4 / 2.0], [10000, 1e-4],
                                [20000, 9.994863e-05], [499999, 5e-05]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
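The check against these expected (step, value) pairs sits just outside the hunk; it presumably follows the usual pattern in this test file, along the lines of this sketch:

# Evaluate the built schedule at each step and compare against the table.
lr = opt_factory.build_learning_rate()
for step, value in expected_lr_step_values:
  self.assertAlmostEqual(value, lr(step).numpy())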
@@ -434,10 +467,12 @@ class OptimizerFactoryRegistryTest(tf.test.TestCase):
     class MyClass():
       pass

     optimizer_factory.register_optimizer_cls('test', MyClass)
     self.assertIn('test', optimizer_factory.OPTIMIZERS_CLS)
     with self.assertRaisesRegex(ValueError, 'test already registered.*'):
       optimizer_factory.register_optimizer_cls('test', MyClass)


 if __name__ == '__main__':
   tf.test.main()