Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
98f2d335
Commit
98f2d335
authored
Apr 01, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 438823508
parent
db3eead9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
5 deletions
+46
-5
official/modeling/optimization/optimizer_factory.py
official/modeling/optimization/optimizer_factory.py
+10
-4
official/modeling/optimization/optimizer_factory_test.py
official/modeling/optimization/optimizer_factory_test.py
+36
-1
No files found.
official/modeling/optimization/optimizer_factory.py
View file @
98f2d335
...
@@ -57,8 +57,8 @@ WARMUP_CLS = {
...
@@ -57,8 +57,8 @@ WARMUP_CLS = {
}
}
def
register_optimizer_cls
(
def
register_optimizer_cls
(
key
:
str
,
key
:
str
,
optimizer_config_cls
:
tf
.
keras
.
optimizers
.
Optimizer
):
optimizer_config_cls
:
tf
.
keras
.
optimizers
.
Optimizer
):
"""Register customize optimizer cls.
"""Register customize optimizer cls.
The user will still need to subclass data classes in
The user will still need to subclass data classes in
...
@@ -156,9 +156,12 @@ class OptimizerFactory:
...
@@ -156,9 +156,12 @@ class OptimizerFactory:
def
build_optimizer
(
def
build_optimizer
(
self
,
self
,
lr
:
Union
[
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
,
float
],
lr
:
Union
[
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
,
float
],
gradient_aggregator
:
Optional
[
Callable
[
[
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]],
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]]]
=
None
,
gradient_transformers
:
Optional
[
List
[
Callable
[
gradient_transformers
:
Optional
[
List
[
Callable
[
[
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]],
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]
[
List
[
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]]],
List
[
Tuple
[
tf
.
Tensor
,
]]]
=
None
,
tf
.
Tensor
]]
]]]
=
None
,
postprocessor
:
Optional
[
Callable
[[
tf
.
keras
.
optimizers
.
Optimizer
],
postprocessor
:
Optional
[
Callable
[[
tf
.
keras
.
optimizers
.
Optimizer
],
tf
.
keras
.
optimizers
.
Optimizer
]]
=
None
):
tf
.
keras
.
optimizers
.
Optimizer
]]
=
None
):
"""Build optimizer.
"""Build optimizer.
...
@@ -170,6 +173,7 @@ class OptimizerFactory:
...
@@ -170,6 +173,7 @@ class OptimizerFactory:
Args:
Args:
lr: A floating point value, or a
lr: A floating point value, or a
tf.keras.optimizers.schedules.LearningRateSchedule instance.
tf.keras.optimizers.schedules.LearningRateSchedule instance.
gradient_aggregator: Optional function to overwrite gradient aggregation.
gradient_transformers: Optional list of functions to use to transform
gradient_transformers: Optional list of functions to use to transform
gradients before applying updates to Variables. The functions are
gradients before applying updates to Variables. The functions are
applied after gradient_aggregator. The functions should accept and
applied after gradient_aggregator. The functions should accept and
...
@@ -193,6 +197,8 @@ class OptimizerFactory:
...
@@ -193,6 +197,8 @@ class OptimizerFactory:
del
optimizer_dict
[
'global_clipnorm'
]
del
optimizer_dict
[
'global_clipnorm'
]
optimizer_dict
[
'learning_rate'
]
=
lr
optimizer_dict
[
'learning_rate'
]
=
lr
if
gradient_aggregator
is
not
None
:
optimizer_dict
[
'gradient_aggregator'
]
=
gradient_aggregator
if
gradient_transformers
is
not
None
:
if
gradient_transformers
is
not
None
:
optimizer_dict
[
'gradient_transformers'
]
=
gradient_transformers
optimizer_dict
[
'gradient_transformers'
]
=
gradient_transformers
...
...
official/modeling/optimization/optimizer_factory_test.py
View file @
98f2d335
...
@@ -49,6 +49,39 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -49,6 +49,39 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertIsInstance
(
optimizer
,
optimizer_cls
)
self
.
assertIsInstance
(
optimizer
,
optimizer_cls
)
self
.
assertEqual
(
expected_optimizer_config
,
optimizer
.
get_config
())
self
.
assertEqual
(
expected_optimizer_config
,
optimizer
.
get_config
())
def test_gradient_aggregator(self):
  """Checks `build_optimizer` threads `gradient_aggregator` to the optimizer.

  The aggregator below replaces every gradient with zeros, so applying
  gradients must leave the variables untouched.
  """
  params = {
      'optimizer': {
          'type': 'adam',
      },
      'learning_rate': {
          'type': 'constant',
          'constant': {
              'learning_rate': 1.0
          }
      }
  }
  opt_config = optimization_config.OptimizationConfig(params)
  opt_factory = optimizer_factory.OptimizerFactory(opt_config)
  lr = opt_factory.build_learning_rate()

  # Dummy function to zero out gradients.
  zero_grads = lambda gv: [(tf.zeros_like(g), v) for g, v in gv]
  optimizer = opt_factory.build_optimizer(lr, gradient_aggregator=zero_grads)

  var0 = tf.Variable([1.0, 2.0])
  var1 = tf.Variable([3.0, 4.0])
  grads0 = tf.constant([1.0, 1.0])
  grads1 = tf.constant([1.0, 1.0])
  grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
  optimizer.apply_gradients(grads_and_vars)

  # Gradients were zeroed by the aggregator, so the update is a no-op.
  self.assertAllClose(np.array([1.0, 2.0]), var0.numpy())
  self.assertAllClose(np.array([3.0, 4.0]), var1.numpy())
@
parameterized
.
parameters
((
None
,
None
),
(
1.0
,
None
),
(
None
,
1.0
))
@
parameterized
.
parameters
((
None
,
None
),
(
1.0
,
None
),
(
None
,
1.0
))
def
test_gradient_clipping
(
self
,
clipnorm
,
clipvalue
):
def
test_gradient_clipping
(
self
,
clipnorm
,
clipvalue
):
params
=
{
params
=
{
...
@@ -418,7 +451,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -418,7 +451,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
}
}
}
}
}
}
expected_lr_step_values
=
[[
0
,
0.0
],
[
5000
,
1e-4
/
2.0
],
[
10000
,
1e-4
],
expected_lr_step_values
=
[[
0
,
0.0
],
[
5000
,
1e-4
/
2.0
],
[
10000
,
1e-4
],
[
20000
,
9.994863e-05
],
[
499999
,
5e-05
]]
[
20000
,
9.994863e-05
],
[
499999
,
5e-05
]]
opt_config
=
optimization_config
.
OptimizationConfig
(
params
)
opt_config
=
optimization_config
.
OptimizationConfig
(
params
)
opt_factory
=
optimizer_factory
.
OptimizerFactory
(
opt_config
)
opt_factory
=
optimizer_factory
.
OptimizerFactory
(
opt_config
)
...
@@ -434,10 +467,12 @@ class OptimizerFactoryRegistryTest(tf.test.TestCase):
...
@@ -434,10 +467,12 @@ class OptimizerFactoryRegistryTest(tf.test.TestCase):
class
MyClass
():
class
MyClass
():
pass
pass
optimizer_factory
.
register_optimizer_cls
(
'test'
,
MyClass
)
optimizer_factory
.
register_optimizer_cls
(
'test'
,
MyClass
)
self
.
assertIn
(
'test'
,
optimizer_factory
.
OPTIMIZERS_CLS
)
self
.
assertIn
(
'test'
,
optimizer_factory
.
OPTIMIZERS_CLS
)
with
self
.
assertRaisesRegex
(
ValueError
,
'test already registered.*'
):
with
self
.
assertRaisesRegex
(
ValueError
,
'test already registered.*'
):
optimizer_factory
.
register_optimizer_cls
(
'test'
,
MyClass
)
optimizer_factory
.
register_optimizer_cls
(
'test'
,
MyClass
)
# Standard TensorFlow test entry point: discover and run the tests above
# when this module is executed as a script.
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment