ModelZoo / ResNet50_tensorflow · Commit 5c96ad96

Adds the offset argument to the supported learning rate.

PiperOrigin-RevId: 381301573

Authored Jun 24, 2021 by Yeqing Li; committed by A. Unique TensorFlower, Jun 24, 2021.
Parent: af924a4c

Showing 4 changed files with 116 additions and 4 deletions (+116 −4).
official/modeling/optimization/configs/learning_rate_config.py  +8 −0
official/modeling/optimization/lr_schedule.py  +69 −0
official/modeling/optimization/lr_schedule_test.py  +35 −0
official/modeling/optimization/optimizer_factory.py  +4 −4
official/modeling/optimization/configs/learning_rate_config.py (view file @ 5c96ad96)

@@ -56,10 +56,12 @@ class StepwiseLrConfig(base_config.Config):
       values[0] [boundaries[0], boundaries[1]] -> values[1]
       [boundaries[n-1], boundaries[n]] -> values[n] [boundaries[n],
       end] -> values[n+1] Defaults to None.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'PiecewiseConstantDecay'
   boundaries: Optional[List[int]] = None
   values: Optional[List[float]] = None
+  offset: int = 0


 @dataclasses.dataclass
@@ -76,12 +78,14 @@ class ExponentialLrConfig(base_config.Config):
     decay_rate: A float. Defaults to None.
     staircase: A boolean; if true, the learning rate is decreased at discrete
       intervals. Defaults to False.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'ExponentialDecay'
   initial_learning_rate: Optional[float] = None
   decay_steps: Optional[int] = None
   decay_rate: Optional[float] = None
   staircase: Optional[bool] = None
+  offset: int = 0


 @dataclasses.dataclass
@@ -99,6 +103,7 @@ class PolynomialLrConfig(base_config.Config):
     power: A float. The power of the polynomial. Defaults to linear, 1.0.
     cycle: A boolean, whether or not it should cycle beyond decay_steps.
       Defaults to False.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'PolynomialDecay'
   initial_learning_rate: Optional[float] = None
@@ -106,6 +111,7 @@ class PolynomialLrConfig(base_config.Config):
   end_learning_rate: float = 0.0001
   power: float = 1.0
   cycle: bool = False
+  offset: int = 0


 @dataclasses.dataclass
@@ -122,11 +128,13 @@ class CosineLrConfig(base_config.Config):
       to None.
     alpha: A float. Minimum learning rate value as a fraction of
       initial_learning_rate.
+    offset: An int. The offset applied to steps. Defaults to 0.
   """
   name: str = 'CosineDecay'
   initial_learning_rate: Optional[float] = None
   decay_steps: Optional[int] = None
   alpha: float = 0.0
+  offset: int = 0


 @dataclasses.dataclass
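For context, here is how the new field might be used once these configs land; a minimal sketch, assuming a setup where the first 500 steps are covered by a separate warmup schedule (the trainer wiring itself is not part of this diff):

    # Sketch: delay cosine decay until a 500-step warmup has finished.
    from official.modeling.optimization.configs import learning_rate_config

    lr_config = learning_rate_config.CosineLrConfig(
        initial_learning_rate=0.1,
        decay_steps=10000,
        alpha=0.0,
        offset=500,  # decay is computed on (step - 500)
    )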
official/modeling/optimization/lr_schedule.py (view file @ 5c96ad96)

@@ -19,6 +19,75 @@ from typing import Mapping, Any, Union, Optional
 import tensorflow as tf


+def _make_offset_wrapper(new_class_name: str, base_lr_class):
+  """Generates an offset wrapper of a learning rate schedule.
+
+  It returns a subclass of `base_lr_class` that takes an `offset` argument in
+  its constructor. When the new class instance is called, the behavior is:
+    new_class_object(step) = base_lr_class_object(step - offset)
+
+  Example:
+    CosineDecayWithOffset = _make_offset_wrapper(
+        'CosineDecayWithOffset', tf.keras.experimental.CosineDecay)
+    # Use the lr:
+    lr = CosineDecayWithOffset(offset=100, initial_learning_rate=0.1,
+                               decay_steps=1000)
+    lr(101)  # equals tf.keras.experimental.CosineDecay(...)(101 - 100)
+
+  Args:
+    new_class_name: the name of the new class.
+    base_lr_class: the base learning rate schedule class. Should be a subclass
+      of tf.keras.optimizers.schedules.LearningRateSchedule.
+
+  Returns:
+    A new class (subclass of the base_lr_class) that can take an offset.
+  """
+  assert issubclass(base_lr_class,
+                    tf.keras.optimizers.schedules.LearningRateSchedule), (
+                        "base_lr_class should be subclass of keras "
+                        f"LearningRateSchedule, got {base_lr_class}")
+
+  # pylint: disable=protected-access,pointless-statement
+  def offset_learning_rate_init(self, offset=0, **kwargs):
+    """Construct learning rate schedule object.
+
+    When this object is called, its behavior is
+      self.__call__(step) == base_lr_class.__call__(step - offset)
+
+    Args:
+      self: this object.
+      offset: The offset when computing the learning rate schedule.
+      **kwargs: Pass through to base learning rate class constructor.
+    """
+    base_lr_class.__init__(self, **kwargs)
+    self._offset = offset
+
+  def offset_learning_rate_call(self, step):
+    step = tf.cast(step - self._offset, tf.float32)
+    return base_lr_class.__call__(self, step)
+
+  # pylint: enable=protected-access,pointless-statement
+  return type(
+      new_class_name, (base_lr_class,), {
+          "base_lr_class": base_lr_class,
+          "__init__": offset_learning_rate_init,
+          "__call__": offset_learning_rate_call
+      })
+
+
+PiecewiseConstantDecayWithOffset = _make_offset_wrapper(
+    "PiecewiseConstantDecayWithOffset",
+    tf.keras.optimizers.schedules.PiecewiseConstantDecay)
+PolynomialDecayWithOffset = _make_offset_wrapper(
+    "PolynomialDecayWithOffset", tf.keras.optimizers.schedules.PolynomialDecay)
+ExponentialDecayWithOffset = _make_offset_wrapper(
+    "ExponentialDecayWithOffset",
+    tf.keras.optimizers.schedules.ExponentialDecay)
+CosineDecayWithOffset = _make_offset_wrapper("CosineDecayWithOffset",
+                                             tf.keras.experimental.CosineDecay)
+
+
 class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
   """Linear warmup schedule."""
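The generated wrappers can be exercised directly; a minimal sketch of the `step - offset` shift, assuming eager execution and the module path used elsewhere in this commit:

    import tensorflow as tf
    from official.modeling.optimization import lr_schedule

    lr = lr_schedule.CosineDecayWithOffset(
        offset=100, initial_learning_rate=0.1, decay_steps=1000)
    base = tf.keras.experimental.CosineDecay(
        initial_learning_rate=0.1, decay_steps=1000)

    # The wrapped schedule at `step` equals the base schedule at `step - offset`.
    assert lr(600).numpy() == base(500).numpy()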
official/modeling/optimization/lr_schedule_test.py (view file @ 5c96ad96)

@@ -70,5 +70,40 @@ class PowerAndLinearDecayTest(tf.test.TestCase, parameterized.TestCase):
       self.assertAlmostEqual(lr(step).numpy(), value)


+class OffsetLearningRateTest(tf.test.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters(
+      dict(class_name=lr_schedule.PiecewiseConstantDecayWithOffset),
+      dict(class_name=lr_schedule.PolynomialDecayWithOffset),
+      dict(class_name=lr_schedule.ExponentialDecayWithOffset),
+      dict(class_name=lr_schedule.CosineDecayWithOffset),
+  )
+  def test_generated_docstring(self, class_name):
+    self.assertNotEmpty(class_name.__init__.__doc__)
+
+  @parameterized.parameters(
+      dict(
+          class_name=lr_schedule.PiecewiseConstantDecayWithOffset,
+          kwarg=dict(boundaries=[50, 80], values=[1.0, 0.5, 0.1])),
+      dict(
+          class_name=lr_schedule.PolynomialDecayWithOffset,
+          kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
+      dict(
+          class_name=lr_schedule.ExponentialDecayWithOffset,
+          kwarg=dict(
+              initial_learning_rate=1.0, decay_steps=100, decay_rate=0.5)),
+      dict(
+          class_name=lr_schedule.CosineDecayWithOffset,
+          kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
+  )
+  def test_offset(self, class_name, kwarg):
+    offset = 10
+    offset_lr = class_name(offset=offset, **kwarg)
+    base_lr = class_name.base_lr_class(**kwarg)
+    self.assertIsInstance(offset_lr, class_name)
+    for step in range(10, 101, 10):
+      self.assertEqual(offset_lr(step), base_lr(step - offset))
+
+
 if __name__ == '__main__':
   tf.test.main()
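Note that `test_offset` reaches the original schedule through the `base_lr_class` attribute the wrapper factory attaches to each generated class; a small sketch of the same comparison outside the test harness (step values are illustrative):

    from official.modeling.optimization import lr_schedule

    lr = lr_schedule.PiecewiseConstantDecayWithOffset(
        offset=10, boundaries=[50, 80], values=[1.0, 0.5, 0.1])
    base = lr_schedule.PiecewiseConstantDecayWithOffset.base_lr_class(
        boundaries=[50, 80], values=[1.0, 0.5, 0.1])

    # Boundaries effectively shift right by the offset.
    assert lr(60).numpy() == base(50).numpy()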
official/modeling/optimization/optimizer_factory.py (view file @ 5c96ad96)

@@ -38,10 +38,10 @@ OPTIMIZERS_CLS = {
 }

 LR_CLS = {
-    'stepwise': tf.keras.optimizers.schedules.PiecewiseConstantDecay,
-    'polynomial': tf.keras.optimizers.schedules.PolynomialDecay,
-    'exponential': tf.keras.optimizers.schedules.ExponentialDecay,
-    'cosine': tf.keras.experimental.CosineDecay,
+    'stepwise': lr_schedule.PiecewiseConstantDecayWithOffset,
+    'polynomial': lr_schedule.PolynomialDecayWithOffset,
+    'exponential': lr_schedule.ExponentialDecayWithOffset,
+    'cosine': lr_schedule.CosineDecayWithOffset,
     'power': lr_schedule.DirectPowerDecay,
     'power_linear': lr_schedule.PowerAndLinearDecay,
     'power_with_offset': lr_schedule.PowerDecayWithOffset,
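Downstream, a config's schedule-type string now resolves to the offset-aware class; a minimal sketch of the lookup (the factory's full build path is outside this diff, so the surrounding code is an assumption):

    from official.modeling.optimization import lr_schedule

    LR_CLS = {'cosine': lr_schedule.CosineDecayWithOffset}
    lr = LR_CLS['cosine'](offset=500, initial_learning_rate=0.1,
                          decay_steps=10000)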