Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3a6079ce
Commit
3a6079ce
authored
Nov 29, 2020
by
Zhichao Lu
Committed by
TF Object Detection Team
Nov 29, 2020
Browse files
Add an exponential decay learning rate schedule with warmup.
PiperOrigin-RevId: 344743049
parent
067d35f9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
92 additions
and
12 deletions
+92
-12
research/object_detection/utils/learning_schedules.py
research/object_detection/utils/learning_schedules.py
+70
-12
research/object_detection/utils/learning_schedules_test.py
research/object_detection/utils/learning_schedules_test.py
+22
-0
No files found.
research/object_detection/utils/learning_schedules.py
View file @
3a6079ce
...
@@ -23,6 +23,14 @@ from six.moves import zip
...
@@ -23,6 +23,14 @@ from six.moves import zip
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
def
_learning_rate_return_value
(
eager_decay_rate
):
"""Helper function to return proper learning rate based on tf version."""
if
tf
.
executing_eagerly
():
return
eager_decay_rate
else
:
return
eager_decay_rate
()
def
exponential_decay_with_burnin
(
global_step
,
def
exponential_decay_with_burnin
(
global_step
,
learning_rate_base
,
learning_rate_base
,
learning_rate_decay_steps
,
learning_rate_decay_steps
,
...
@@ -76,10 +84,65 @@ def exponential_decay_with_burnin(global_step,
...
@@ -76,10 +84,65 @@ def exponential_decay_with_burnin(global_step,
tf
.
constant
(
burnin_learning_rate
),
tf
.
constant
(
burnin_learning_rate
),
post_burnin_learning_rate
),
min_learning_rate
,
name
=
'learning_rate'
)
post_burnin_learning_rate
),
min_learning_rate
,
name
=
'learning_rate'
)
if
tf
.
executing_eagerly
():
return
_learning_rate_return_value
(
eager_decay_rate
)
return
eager_decay_rate
else
:
return
eager_decay_rate
()
def
exponential_decay_with_warmup
(
global_step
,
learning_rate_base
,
learning_rate_decay_steps
,
learning_rate_decay_factor
,
warmup_learning_rate
=
0.0
,
warmup_steps
=
0
,
min_learning_rate
=
0.0
,
staircase
=
True
):
"""Exponential decay schedule with warm up period.
Args:
global_step: int tensor representing global step.
learning_rate_base: base learning rate.
learning_rate_decay_steps: steps to take between decaying the learning rate.
Note that this includes the number of burn-in steps.
learning_rate_decay_factor: multiplicative factor by which to decay learning
rate.
warmup_learning_rate: initial learning rate during warmup period.
warmup_steps: number of steps to use warmup learning rate.
min_learning_rate: the minimum learning rate.
staircase: whether use staircase decay.
Returns:
If executing eagerly:
returns a no-arg callable that outputs the (scalar)
float tensor learning rate given the current value of global_step.
If in a graph:
immediately returns a (scalar) float tensor representing learning rate.
"""
def
eager_decay_rate
():
"""Callable to compute the learning rate."""
post_warmup_learning_rate
=
tf
.
train
.
exponential_decay
(
learning_rate_base
,
global_step
-
warmup_steps
,
learning_rate_decay_steps
,
learning_rate_decay_factor
,
staircase
=
staircase
)
if
callable
(
post_warmup_learning_rate
):
post_warmup_learning_rate
=
post_warmup_learning_rate
()
if
learning_rate_base
<
warmup_learning_rate
:
raise
ValueError
(
'learning_rate_base must be larger or equal to '
'warmup_learning_rate.'
)
slope
=
(
learning_rate_base
-
warmup_learning_rate
)
/
warmup_steps
warmup_rate
=
slope
*
tf
.
cast
(
global_step
,
tf
.
float32
)
+
warmup_learning_rate
learning_rate
=
tf
.
where
(
tf
.
less
(
tf
.
cast
(
global_step
,
tf
.
int32
),
tf
.
constant
(
warmup_steps
)),
warmup_rate
,
tf
.
maximum
(
post_warmup_learning_rate
,
min_learning_rate
),
name
=
'learning_rate'
)
return
learning_rate
return
_learning_rate_return_value
(
eager_decay_rate
)
def
cosine_decay_with_warmup
(
global_step
,
def
cosine_decay_with_warmup
(
global_step
,
...
@@ -142,10 +205,7 @@ def cosine_decay_with_warmup(global_step,
...
@@ -142,10 +205,7 @@ def cosine_decay_with_warmup(global_step,
return
tf
.
where
(
global_step
>
total_steps
,
0.0
,
learning_rate
,
return
tf
.
where
(
global_step
>
total_steps
,
0.0
,
learning_rate
,
name
=
'learning_rate'
)
name
=
'learning_rate'
)
if
tf
.
executing_eagerly
():
return
_learning_rate_return_value
(
eager_decay_rate
)
return
eager_decay_rate
else
:
return
eager_decay_rate
()
def
manual_stepping
(
global_step
,
boundaries
,
rates
,
warmup
=
False
):
def
manual_stepping
(
global_step
,
boundaries
,
rates
,
warmup
=
False
):
...
@@ -212,7 +272,5 @@ def manual_stepping(global_step, boundaries, rates, warmup=False):
...
@@ -212,7 +272,5 @@ def manual_stepping(global_step, boundaries, rates, warmup=False):
[
0
]
*
num_boundaries
))
[
0
]
*
num_boundaries
))
return
tf
.
reduce_sum
(
rates
*
tf
.
one_hot
(
rate_index
,
depth
=
num_boundaries
),
return
tf
.
reduce_sum
(
rates
*
tf
.
one_hot
(
rate_index
,
depth
=
num_boundaries
),
name
=
'learning_rate'
)
name
=
'learning_rate'
)
if
tf
.
executing_eagerly
():
return
eager_decay_rate
return
_learning_rate_return_value
(
eager_decay_rate
)
else
:
return
eager_decay_rate
()
research/object_detection/utils/learning_schedules_test.py
View file @
3a6079ce
...
@@ -50,6 +50,28 @@ class LearningSchedulesTest(test_case.TestCase):
...
@@ -50,6 +50,28 @@ class LearningSchedulesTest(test_case.TestCase):
exp_rates
=
[.
5
,
.
5
,
1
,
1
,
1
,
.
1
,
.
1
,
.
1
,
.
05
]
exp_rates
=
[.
5
,
.
5
,
1
,
1
,
1
,
.
1
,
.
1
,
.
1
,
.
05
]
self
.
assertAllClose
(
output_rates
,
exp_rates
,
rtol
=
1e-4
)
self
.
assertAllClose
(
output_rates
,
exp_rates
,
rtol
=
1e-4
)
def
testExponentialDecayWithWarmup
(
self
):
def
graph_fn
(
global_step
):
learning_rate_base
=
1.0
learning_rate_decay_steps
=
3
learning_rate_decay_factor
=
.
1
warmup_learning_rate
=
.
5
warmup_steps
=
2
min_learning_rate
=
.
05
learning_rate
=
learning_schedules
.
exponential_decay_with_warmup
(
global_step
,
learning_rate_base
,
learning_rate_decay_steps
,
learning_rate_decay_factor
,
warmup_learning_rate
,
warmup_steps
,
min_learning_rate
)
assert
learning_rate
.
op
.
name
.
endswith
(
'learning_rate'
)
return
(
learning_rate
,)
output_rates
=
[
self
.
execute
(
graph_fn
,
[
np
.
array
(
i
).
astype
(
np
.
int64
)])
for
i
in
range
(
9
)
]
exp_rates
=
[.
5
,
.
75
,
1
,
1
,
1
,
.
1
,
.
1
,
.
1
,
.
05
]
self
.
assertAllClose
(
output_rates
,
exp_rates
,
rtol
=
1e-4
)
def
testCosineDecayWithWarmup
(
self
):
def
testCosineDecayWithWarmup
(
self
):
def
graph_fn
(
global_step
):
def
graph_fn
(
global_step
):
learning_rate_base
=
1.0
learning_rate_base
=
1.0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment