Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bed0e3de
Commit
bed0e3de
authored
Sep 20, 2021
by
Frederick Liu
Committed by
A. Unique TensorFlower
Sep 20, 2021
Browse files
Internal change
PiperOrigin-RevId: 397848064
parent
3318370e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
2 deletions
+14
-2
official/modeling/optimization/optimizer_factory.py
official/modeling/optimization/optimizer_factory.py
+14
-2
No files found.
official/modeling/optimization/optimizer_factory.py
View file @
bed0e3de
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
"""Optimizer factory class."""
"""Optimizer factory class."""
from
typing
import
Callable
,
Optional
,
Union
from
typing
import
Callable
,
Optional
,
Union
,
List
,
Tuple
import
gin
import
gin
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -139,6 +139,9 @@ class OptimizerFactory:
...
@@ -139,6 +139,9 @@ class OptimizerFactory:
def
build_optimizer
(
def
build_optimizer
(
self
,
self
,
lr
:
Union
[
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
,
float
],
lr
:
Union
[
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
,
float
],
gradient_transformers
:
Optional
[
List
[
Callable
[
[
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]],
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]
]]]
=
None
,
postprocessor
:
Optional
[
Callable
[[
tf
.
keras
.
optimizers
.
Optimizer
],
postprocessor
:
Optional
[
Callable
[[
tf
.
keras
.
optimizers
.
Optimizer
],
tf
.
keras
.
optimizers
.
Optimizer
]]
=
None
):
tf
.
keras
.
optimizers
.
Optimizer
]]
=
None
):
"""Build optimizer.
"""Build optimizer.
...
@@ -150,6 +153,11 @@ class OptimizerFactory:
...
@@ -150,6 +153,11 @@ class OptimizerFactory:
Args:
Args:
lr: A floating point value, or a
lr: A floating point value, or a
tf.keras.optimizers.schedules.LearningRateSchedule instance.
tf.keras.optimizers.schedules.LearningRateSchedule instance.
gradient_transformers: Optional list of functions to use to transform
gradients before applying updates to Variables. The functions are
applied after gradient_aggregator. The functions should accept and
return a list of (gradient, variable) tuples. clipvalue, clipnorm,
global_clipnorm should not be set when gradient_transformers is passed.
postprocessor: An optional function for postprocessing the optimizer. It
postprocessor: An optional function for postprocessing the optimizer. It
takes an optimizer and returns an optimizer.
takes an optimizer and returns an optimizer.
...
@@ -158,13 +166,17 @@ class OptimizerFactory:
...
@@ -158,13 +166,17 @@ class OptimizerFactory:
"""
"""
optimizer_dict
=
self
.
_optimizer_config
.
as_dict
()
optimizer_dict
=
self
.
_optimizer_config
.
as_dict
()
## Delete clipnorm
and
clipvalue if None
## Delete clipnorm
,
clipvalue
, global_clipnorm
if None
if
optimizer_dict
[
'clipnorm'
]
is
None
:
if
optimizer_dict
[
'clipnorm'
]
is
None
:
del
optimizer_dict
[
'clipnorm'
]
del
optimizer_dict
[
'clipnorm'
]
if
optimizer_dict
[
'clipvalue'
]
is
None
:
if
optimizer_dict
[
'clipvalue'
]
is
None
:
del
optimizer_dict
[
'clipvalue'
]
del
optimizer_dict
[
'clipvalue'
]
if
optimizer_dict
[
'global_clipnorm'
]
is
None
:
del
optimizer_dict
[
'global_clipnorm'
]
optimizer_dict
[
'learning_rate'
]
=
lr
optimizer_dict
[
'learning_rate'
]
=
lr
if
gradient_transformers
is
not
None
:
optimizer_dict
[
'gradient_transformers'
]
=
gradient_transformers
optimizer
=
OPTIMIZERS_CLS
[
self
.
_optimizer_type
](
**
optimizer_dict
)
optimizer
=
OPTIMIZERS_CLS
[
self
.
_optimizer_type
](
**
optimizer_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment