Commit 00f96640 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds check of existence of the experimental optimizer package.

PiperOrigin-RevId: 450601824
parent 55b43f27
...@@ -213,10 +213,17 @@ class OptimizerFactory: ...@@ -213,10 +213,17 @@ class OptimizerFactory:
optimizer, **self._ema_config.as_dict()) optimizer, **self._ema_config.as_dict())
if postprocessor: if postprocessor:
optimizer = postprocessor(optimizer) optimizer = postprocessor(optimizer)
assert isinstance( if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
optimizer, (tf.keras.optimizers.Optimizer, # tf.keras.optimizers.experimental only exist in tf-nightly.
tf.keras.optimizers.experimental.Optimizer) # The following check makes sure the function wont' break in older TF
), ('OptimizerFactory.build_optimizer returning a non-optimizer object: ' # version because of missing the experimental package.
'{}'.format(optimizer)) if hasattr(tf.keras.optimizers, 'experimental'):
if not isinstance(optimizer,
tf.keras.optimizers.experimental.Optimizer):
raise TypeError('OptimizerFactory.build_optimizer returning a '
'non-optimizer object: {}'.format(optimizer))
else:
raise TypeError('OptimizerFactory.build_optimizer returning a '
'non-optimizer object: {}'.format(optimizer))
return optimizer return optimizer
...@@ -140,6 +140,25 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -140,6 +140,25 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
optimizer_factory.OptimizerFactory( optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params)) optimization_config.OptimizationConfig(params))
def test_wrong_return_type(self):
optimizer_type = 'sgd'
params = {
'optimizer': {
'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
with self.assertRaises(TypeError):
_ = opt_factory.build_optimizer(0.1, postprocessor=lambda x: None)
# TODO(b/187559334) refactor lr_schedule tests into `lr_schedule_test.py`. # TODO(b/187559334) refactor lr_schedule tests into `lr_schedule_test.py`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment