ModelZoo / ResNet50_tensorflow

Commit 39f98e30
Authored Dec 15, 2020 by Le Hou, committed by A. Unique TensorFlower on Dec 15, 2020

Internal change

PiperOrigin-RevId: 347634117
Parent: 21ab9cf7

Showing 5 changed files with 107 additions and 0 deletions (+107, -0)
official/modeling/optimization/configs/learning_rate_config.py   +21  -0
official/modeling/optimization/configs/optimization_config.py     +4  -0
official/modeling/optimization/lr_schedule.py                    +55  -0
official/modeling/optimization/optimizer_factory.py               +1  -0
official/modeling/optimization/optimizer_factory_test.py         +26  -0
official/modeling/optimization/configs/learning_rate_config.py

@@ -146,6 +146,27 @@ class DirectPowerLrConfig(base_config.Config):
   power: float = -0.5
 
 
+@dataclasses.dataclass
+class PowerAndLinearDecayLrConfig(base_config.Config):
+  """Configuration for the PowerAndLinearDecay learning rate schedule.
+
+  This class configures a schedule that follows lr * (step)^power for the
+  first total_decay_steps * (1 - linear_decay_fraction) steps, and follows
+  lr * (step)^power * (total_decay_steps - step) / (total_decay_steps *
+  linear_decay_fraction) for the rest of the steps.
+
+  Attributes:
+    name: The name of the learning rate schedule. Defaults to
+      PowerAndLinearDecay.
+    initial_learning_rate: A float. The initial learning rate. Defaults to
+      None.
+    total_decay_steps: An int. The total number of steps for power + linear
+      decay. Defaults to None.
+    power: A float. Defaults to -0.5, for sqrt decay.
+    linear_decay_fraction: A float. The learning rate is additionally
+      multiplied by a linear decay over the last `linear_decay_fraction`
+      fraction of `total_decay_steps`. Defaults to 0.1.
+  """
+  name: str = 'PowerAndLinearDecay'
+  initial_learning_rate: Optional[float] = None
+  total_decay_steps: Optional[int] = None
+  power: float = -0.5
+  linear_decay_fraction: float = 0.1
+
+
 @dataclasses.dataclass
 class LinearWarmupConfig(base_config.Config):
   """Configuration for linear warmup schedule config.
...
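A minimal sketch of constructing the new config, using the field values exercised by the unit test added below (the import path assumes the Model Garden layout at this revision):

from official.modeling.optimization.configs import learning_rate_config as lr_cfg

cfg = lr_cfg.PowerAndLinearDecayLrConfig(
    initial_learning_rate=1.0,
    total_decay_steps=100,
    power=-1.0,
    linear_decay_fraction=0.5)

# Piecewise behaviour implied by the docstring above:
#   step 40 (< 100 * (1 - 0.5) = 50): lr = 1.0 * 40**-1 = 0.025
#   step 60 (>= 50):                  lr = 1.0 * 60**-1 * (100 - 60) / (100 * 0.5)
#                                        = (1/60) * 0.8 ≈ 0.01333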
official/modeling/optimization/configs/optimization_config.py

@@ -61,6 +61,8 @@ class LrConfig(oneof.OneOfConfig):
     polynomial: polynomial learning rate config.
     cosine: cosine learning rate config.
     power: step^power learning rate config.
+    power_linear: learning rate config of step^power followed by
+      step^power*linear.
   """
   type: Optional[str] = None
   constant: lr_cfg.ConstantLrConfig = lr_cfg.ConstantLrConfig()
...
@@ -69,6 +71,8 @@ class LrConfig(oneof.OneOfConfig):
   polynomial: lr_cfg.PolynomialLrConfig = lr_cfg.PolynomialLrConfig()
   cosine: lr_cfg.CosineLrConfig = lr_cfg.CosineLrConfig()
   power: lr_cfg.DirectPowerLrConfig = lr_cfg.DirectPowerLrConfig()
+  power_linear: lr_cfg.PowerAndLinearDecayLrConfig = (
+      lr_cfg.PowerAndLinearDecayLrConfig())
 
 
 @dataclasses.dataclass
...
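Since LrConfig is a oneof.OneOfConfig, setting type to 'power_linear' selects the new sub-config. A minimal sketch, mirroring the params dict used in the test added below:

from official.modeling.optimization.configs import optimization_config

opt_config = optimization_config.OptimizationConfig({
    'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
    'learning_rate': {
        'type': 'power_linear',
        'power_linear': {
            'initial_learning_rate': 1.0,
            'power': -1.0,
            'linear_decay_fraction': 0.5,
            'total_decay_steps': 100,
        },
    },
})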
official/modeling/optimization/lr_schedule.py

@@ -188,3 +188,58 @@ class DirectPowerDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
         "power": self._power,
         "name": self._name,
     }
+
+
+class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
+  """Learning rate schedule of step^power, times a linear decay at the end.
+
+  Follows lr * (step)^power for the first total_decay_steps *
+  (1 - linear_decay_fraction) steps, and follows lr * (step)^power *
+  (total_decay_steps - step) / (total_decay_steps * linear_decay_fraction)
+  for the rest of the steps.
+  """
+
+  def __init__(self,
+               initial_learning_rate: float,
+               total_decay_steps: int,
+               power: float = 1.0,
+               linear_decay_fraction: float = 0.1,
+               name: str = "PowerAndLinearDecay"):
+    """Initializes the configuration of the learning rate schedule.
+
+    Args:
+      initial_learning_rate: A float, the initial learning rate.
+      total_decay_steps: The total number of steps for power + linear decay.
+      power: A float, the exponent of the power decay.
+      linear_decay_fraction: A float; in the last `linear_decay_fraction`
+        fraction of steps, the learning rate is additionally multiplied by a
+        linear decay.
+      name: Optional, the name of the schedule.
+    """
+    super(PowerAndLinearDecay, self).__init__()
+    self._initial_learning_rate = initial_learning_rate
+    self._total_decay_steps = total_decay_steps
+    self._power = power
+    self._linear_decay_fraction = linear_decay_fraction
+    self._name = name
+
+  def __call__(self, step):
+    with tf.name_scope(self._name or "PowerAndLinearDecay"):
+      step = tf.cast(step, tf.float32)
+      learning_rate = self._initial_learning_rate
+      learning_rate *= tf.math.pow(step, self._power)
+      if self._linear_decay_fraction > 0:
+        learning_rate *= tf.minimum(
+            1.0, (self._total_decay_steps - step) /
+            (self._total_decay_steps * self._linear_decay_fraction))
+        learning_rate = tf.maximum(0.0, learning_rate)
+      return learning_rate
+
+  def get_config(self):
+    """Gets the configuration of the learning rate schedule."""
+    return {
+        "initial_learning_rate": self._initial_learning_rate,
+        "total_decay_steps": self._total_decay_steps,
+        "power": self._power,
+        "linear_decay_fraction": self._linear_decay_fraction,
+        "name": self._name,
+    }
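A minimal usage sketch of the schedule itself, assuming TF 2.x eager execution; the expected values follow the formula in the class docstring and match the test expectations below:

import tensorflow as tf
from official.modeling.optimization import lr_schedule

schedule = lr_schedule.PowerAndLinearDecay(
    initial_learning_rate=1.0,
    total_decay_steps=100,
    power=-1.0,
    linear_decay_fraction=0.5)

print(schedule(1).numpy())   # 1.0       (1/1, pure power decay)
print(schedule(40).numpy())  # 0.025     (1/40, still before the linear tail)
print(schedule(60).numpy())  # ~0.01333  (1/60 * (100 - 60) / (100 * 0.5))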
official/modeling/optimization/optimizer_factory.py

@@ -40,6 +40,7 @@ LR_CLS = {
     'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
     'cosine': tf.keras.experimental.CosineDecay,
     'power': lr_schedule.DirectPowerDecay,
+    'power_linear': lr_schedule.PowerAndLinearDecay,
 }
 
 WARMUP_CLS = {
...
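Presumably OptimizerFactory.build_learning_rate() resolves the configured learning-rate type through the LR_CLS mapping above; a rough sketch of that lookup, with the keyword arguments taken from the new config fields:

# Hypothetical illustration: the 'power_linear' key now maps to the new class.
lr_cls = LR_CLS['power_linear']    # lr_schedule.PowerAndLinearDecay
lr = lr_cls(initial_learning_rate=1.0, total_decay_steps=100,
            power=-1.0, linear_decay_fraction=0.5)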
official/modeling/optimization/optimizer_factory_test.py

@@ -340,6 +340,32 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     for step, value in expected_lr_step_values:
       self.assertAlmostEqual(lr(step).numpy(), value)
 
+  def test_power_linear_lr_schedule(self):
+    params = {
+        'optimizer': {
+            'type': 'sgd',
+            'sgd': {
+                'momentum': 0.9
+            }
+        },
+        'learning_rate': {
+            'type': 'power_linear',
+            'power_linear': {
+                'initial_learning_rate': 1.0,
+                'power': -1.0,
+                'linear_decay_fraction': 0.5,
+                'total_decay_steps': 100,
+            }
+        }
+    }
+    expected_lr_step_values = [[1, 1.0], [40, 1. / 40.], [60, 1. / 60. * 0.8]]
+    opt_config = optimization_config.OptimizationConfig(params)
+    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+    lr = opt_factory.build_learning_rate()
+
+    for step, value in expected_lr_step_values:
+      self.assertAlmostEqual(lr(step).numpy(), value)
+
 
 if __name__ == '__main__':
   tf.test.main()