You need to sign in or sign up before continuing.
Commit 3d10262e authored by Chen Qian's avatar Chen Qian Committed by A. Unique TensorFlower
Browse files

Raise an explicit error if `decay` is set and new Keras optimizer is used.

PiperOrigin-RevId: 481980126
parent 6e2129f6
...@@ -236,6 +236,11 @@ class OptimizerFactory:
    if use_legacy_optimizer:
      optimizer = LEGACY_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
    else:
if 'decay' in optimizer_dict:
raise ValueError(
'`decay` is deprecated in new Keras optimizer, please reflect the '
'decay logic in `lr` or set `use_legacy_optimizer=True` to use the '
'legacy optimizer.')
      optimizer = NEW_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
    if self._use_ema:
......
...@@ -68,6 +68,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    expected_optimizer_config['learning_rate'] = 0.1
    opt_config = optimization_config.OptimizationConfig(params)
if optimizer_type == 'sgd':
# Delete unsupported arg `decay` from SGDConfig.
delattr(opt_config.optimizer.sgd, 'decay')
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()
    optimizer = opt_factory.build_optimizer(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment