"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "bc849cc9c654a42526903ffc3dd6cbb8abddc787"
Commit 00f96640 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Adds check of existence of the experimental optimizer package.

PiperOrigin-RevId: 450601824
parent 55b43f27
...@@ -213,10 +213,17 @@ class OptimizerFactory: ...@@ -213,10 +213,17 @@ class OptimizerFactory:
optimizer, **self._ema_config.as_dict()) optimizer, **self._ema_config.as_dict())
if postprocessor: if postprocessor:
optimizer = postprocessor(optimizer) optimizer = postprocessor(optimizer)
assert isinstance( if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
optimizer, (tf.keras.optimizers.Optimizer, # tf.keras.optimizers.experimental only exist in tf-nightly.
tf.keras.optimizers.experimental.Optimizer) # The following check makes sure the function wont' break in older TF
), ('OptimizerFactory.build_optimizer returning a non-optimizer object: ' # version because of missing the experimental package.
'{}'.format(optimizer)) if hasattr(tf.keras.optimizers, 'experimental'):
if not isinstance(optimizer,
tf.keras.optimizers.experimental.Optimizer):
raise TypeError('OptimizerFactory.build_optimizer returning a '
'non-optimizer object: {}'.format(optimizer))
else:
raise TypeError('OptimizerFactory.build_optimizer returning a '
'non-optimizer object: {}'.format(optimizer))
return optimizer return optimizer
...@@ -140,6 +140,25 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -140,6 +140,25 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
optimizer_factory.OptimizerFactory( optimizer_factory.OptimizerFactory(
optimization_config.OptimizationConfig(params)) optimization_config.OptimizationConfig(params))
def test_wrong_return_type(self):
optimizer_type = 'sgd'
params = {
'optimizer': {
'type': optimizer_type
},
'learning_rate': {
'type': 'constant',
'constant': {
'learning_rate': 0.1
}
}
}
opt_config = optimization_config.OptimizationConfig(params)
opt_factory = optimizer_factory.OptimizerFactory(opt_config)
with self.assertRaises(TypeError):
_ = opt_factory.build_optimizer(0.1, postprocessor=lambda x: None)
# TODO(b/187559334) refactor lr_schedule tests into `lr_schedule_test.py`. # TODO(b/187559334) refactor lr_schedule tests into `lr_schedule_test.py`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment