Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2560387f
You need to sign in or sign up before continuing.
Commit
2560387f
authored
Nov 15, 2020
by
Le Hou
Committed by
A. Unique TensorFlower
Nov 15, 2020
Browse files
Internal change
PiperOrigin-RevId: 342548143
parent
e9057c4d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
4 deletions
+15
-4
official/modeling/optimization/optimizer_factory.py
official/modeling/optimization/optimizer_factory.py
+14
-3
official/modeling/optimization/optimizer_factory_test.py
official/modeling/optimization/optimizer_factory_test.py
+1
-1
No files found.
official/modeling/optimization/optimizer_factory.py
View file @
2560387f
...
@@ -14,9 +14,10 @@
...
@@ -14,9 +14,10 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Optimizer factory class."""
"""Optimizer factory class."""
from
typing
import
Union
from
typing
import
Callable
,
Union
import
gin
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_addons.optimizers
as
tfa_optimizers
import
tensorflow_addons.optimizers
as
tfa_optimizers
...
@@ -126,9 +127,12 @@ class OptimizerFactory(object):
...
@@ -126,9 +127,12 @@ class OptimizerFactory(object):
return
lr
return
lr
@
gin
.
configurable
def
build_optimizer
(
def
build_optimizer
(
self
,
lr
:
Union
[
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
,
self
,
float
]):
lr
:
Union
[
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
,
float
],
postprocessor
:
Callable
[[
tf
.
keras
.
optimizers
.
Optimizer
],
tf
.
keras
.
optimizers
.
Optimizer
]
=
None
):
"""Build optimizer.
"""Build optimizer.
Builds optimizer from config. It takes learning rate as input, and builds
Builds optimizer from config. It takes learning rate as input, and builds
...
@@ -138,6 +142,8 @@ class OptimizerFactory(object):
...
@@ -138,6 +142,8 @@ class OptimizerFactory(object):
Args:
Args:
lr: A floating point value, or a
lr: A floating point value, or a
tf.keras.optimizers.schedules.LearningRateSchedule instance.
tf.keras.optimizers.schedules.LearningRateSchedule instance.
postprocessor: An optional function for postprocessing the optimizer. It
takes an optimizer and returns an optimizer.
Returns:
Returns:
tf.keras.optimizers.Optimizer instance.
tf.keras.optimizers.Optimizer instance.
...
@@ -157,5 +163,10 @@ class OptimizerFactory(object):
...
@@ -157,5 +163,10 @@ class OptimizerFactory(object):
if
self
.
_use_ema
:
if
self
.
_use_ema
:
optimizer
=
ema_optimizer
.
ExponentialMovingAverage
(
optimizer
=
ema_optimizer
.
ExponentialMovingAverage
(
optimizer
,
**
self
.
_ema_config
.
as_dict
())
optimizer
,
**
self
.
_ema_config
.
as_dict
())
if
postprocessor
:
optimizer
=
postprocessor
(
optimizer
)
assert
isinstance
(
optimizer
,
tf
.
keras
.
optimizers
.
Optimizer
),
(
'OptimizerFactory.build_optimizer returning a non-optimizer object: '
'{}'
.
format
(
optimizer
))
return
optimizer
return
optimizer
official/modeling/optimization/optimizer_factory_test.py
View file @
2560387f
...
@@ -44,7 +44,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -44,7 +44,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
opt_config
=
optimization_config
.
OptimizationConfig
(
params
)
opt_config
=
optimization_config
.
OptimizationConfig
(
params
)
opt_factory
=
optimizer_factory
.
OptimizerFactory
(
opt_config
)
opt_factory
=
optimizer_factory
.
OptimizerFactory
(
opt_config
)
lr
=
opt_factory
.
build_learning_rate
()
lr
=
opt_factory
.
build_learning_rate
()
optimizer
=
opt_factory
.
build_optimizer
(
lr
)
optimizer
=
opt_factory
.
build_optimizer
(
lr
,
postprocessor
=
lambda
x
:
x
)
self
.
assertIsInstance
(
optimizer
,
optimizer_cls
)
self
.
assertIsInstance
(
optimizer
,
optimizer_cls
)
self
.
assertEqual
(
expected_optimizer_config
,
optimizer
.
get_config
())
self
.
assertEqual
(
expected_optimizer_config
,
optimizer
.
get_config
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment