Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3d10262e
Commit
3d10262e
authored
Oct 18, 2022
by
Chen Qian
Committed by
A. Unique TensorFlower
Oct 18, 2022
Browse files
Raise an explicit error if `decay` is set and new Keras optimizer is used.
PiperOrigin-RevId: 481980126
parent
6e2129f6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
0 deletions
+8
-0
official/modeling/optimization/optimizer_factory.py
official/modeling/optimization/optimizer_factory.py
+5
-0
official/modeling/optimization/optimizer_factory_test.py
official/modeling/optimization/optimizer_factory_test.py
+3
-0
No files found.
official/modeling/optimization/optimizer_factory.py
View file @
3d10262e
...
@@ -236,6 +236,11 @@ class OptimizerFactory:
...
@@ -236,6 +236,11 @@ class OptimizerFactory:
if
use_legacy_optimizer
:
if
use_legacy_optimizer
:
optimizer
=
LEGACY_OPTIMIZERS_CLS
[
self
.
_optimizer_type
](
**
optimizer_dict
)
optimizer
=
LEGACY_OPTIMIZERS_CLS
[
self
.
_optimizer_type
](
**
optimizer_dict
)
else
:
else
:
if
'decay'
in
optimizer_dict
:
raise
ValueError
(
'`decay` is deprecated in new Keras optimizer, please reflect the '
'decay logic in `lr` or set `use_legacy_optimizer=True` to use the '
'legacy optimizer.'
)
optimizer
=
NEW_OPTIMIZERS_CLS
[
self
.
_optimizer_type
](
**
optimizer_dict
)
optimizer
=
NEW_OPTIMIZERS_CLS
[
self
.
_optimizer_type
](
**
optimizer_dict
)
if
self
.
_use_ema
:
if
self
.
_use_ema
:
...
...
official/modeling/optimization/optimizer_factory_test.py
View file @
3d10262e
...
@@ -68,6 +68,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -68,6 +68,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
expected_optimizer_config
[
'learning_rate'
]
=
0.1
expected_optimizer_config
[
'learning_rate'
]
=
0.1
opt_config
=
optimization_config
.
OptimizationConfig
(
params
)
opt_config
=
optimization_config
.
OptimizationConfig
(
params
)
if
optimizer_type
==
'sgd'
:
# Delete unsupported arg `decay` from SGDConfig.
delattr
(
opt_config
.
optimizer
.
sgd
,
'decay'
)
opt_factory
=
optimizer_factory
.
OptimizerFactory
(
opt_config
)
opt_factory
=
optimizer_factory
.
OptimizerFactory
(
opt_config
)
lr
=
opt_factory
.
build_learning_rate
()
lr
=
opt_factory
.
build_learning_rate
()
optimizer
=
opt_factory
.
build_optimizer
(
optimizer
=
opt_factory
.
build_optimizer
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment