ModelZoo / ResNet50_tensorflow

Commit c9ab5a7a
Authored Mar 03, 2021 by Le Hou; committed by A. Unique TensorFlower, Mar 03, 2021

Fix corner cases where LR schedules output Inf.

PiperOrigin-RevId: 360719881

Parent: b3b0664b
Changes: 2 files, 16 additions and 6 deletions

  official/modeling/optimization/lr_schedule.py             +12 -3
  official/modeling/optimization/optimizer_factory_test.py   +4 -3
official/modeling/optimization/lr_schedule.py
@@ -120,7 +120,14 @@ class PolynomialWarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
       # learning rate will be `global_step/num_warmup_steps * init_lr`.
       global_step_float = tf.cast(step, tf.float32)
       warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)
-      warmup_percent_done = global_step_float / warmup_steps_float
+
+      if self._warmup_steps <= 0:
+        warmup_percent_done = 1.0
+      else:
+        # A zero `step` may cause Inf. So make `step` positive.
+        step_non_zero = tf.math.maximum(global_step_float, 1.0)
+        warmup_percent_done = step_non_zero / warmup_steps_float
+
       warmup_learning_rate = (
           self._initial_learning_rate *
           tf.math.pow(warmup_percent_done, self._power))
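Why the guard works: tf.math.pow follows IEEE-754 semantics, so a zero base raised to a negative exponent yields Inf, and a zero `warmup_steps` would additionally make the division 0/0, i.e. NaN. A minimal standalone sketch of the patched warmup logic, using illustrative names (`warmup_fraction` is not part of the repo's API):

    import tensorflow as tf

    def warmup_fraction(step, warmup_steps, power):
      # Toy version of the guarded PolynomialWarmUp computation above.
      step_float = tf.cast(step, tf.float32)
      warmup_steps_float = tf.cast(warmup_steps, tf.float32)
      if warmup_steps <= 0:
        # No warmup configured: skip the division (0 / 0 would be NaN).
        warmup_percent_done = 1.0
      else:
        # Clamp the step so pow(0., negative_power) = Inf cannot happen.
        step_non_zero = tf.math.maximum(step_float, 1.0)
        warmup_percent_done = step_non_zero / warmup_steps_float
      return tf.math.pow(warmup_percent_done, power)

    # With a negative power, step 0 now stays finite:
    print(warmup_fraction(0, 500, -0.5).numpy())  # ~22.36, was Inf before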
@@ -226,8 +233,10 @@ class PowerAndLinearDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
     with tf.name_scope(self._name or "PowerAndLinearDecay"):
       step = tf.cast(step, tf.float32)
       learning_rate = self._initial_learning_rate
-      learning_rate *= tf.math.pow(step, self._power)
-      if self._linear_decay_fraction > 0:
+      # A zero `step` may cause Inf. So make `step` positive.
+      step_non_zero = tf.math.maximum(step, 1.0)
+      learning_rate *= tf.math.pow(step_non_zero, self._power)
+      if self._total_decay_steps * self._linear_decay_fraction > 0:
         learning_rate *= tf.minimum(
             1.0, (self._total_decay_steps - step) /
             (self._total_decay_steps * self._linear_decay_fraction))
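The corner case this hunk guards against reproduces in two lines, the Inf that PowerAndLinearDecay hit at step 0 whenever `power` is negative:

    import tensorflow as tf

    print(tf.math.pow(tf.constant(0.0), -0.5).numpy())                        # inf
    print(tf.math.pow(tf.math.maximum(tf.constant(0.0), 1.0), -0.5).numpy())  # 1.0

Note the second change in this hunk as well: the guard on the linear-decay branch now tests the full product `self._total_decay_steps * self._linear_decay_fraction`, which is exactly the denominator of the division that follows, so a zero `total_decay_steps` can no longer trigger a division by zero.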
official/modeling/optimization/optimizer_factory_test.py
@@ -313,7 +313,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     lr = opt_factory.build_learning_rate()
     for step, value in expected_lr_step_values:
-      self.assertAlmostEqual(lr(step).numpy(), value)
+      self.assertAlmostEqual(lr(step).numpy(), value, places=6)

   def test_power_lr_schedule(self):
     params = {
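The added `places=6` loosens assertAlmostEqual from its unittest default of 7 decimal places, presumably so float32 schedule outputs such as 1/250 compare cleanly against the Python-float expectations.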
@@ -331,7 +331,7 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
             }
         }
     }
-    expected_lr_step_values = [[1, 1.0], [250, 1. / 250.]]
+    expected_lr_step_values = [[0, 1.0], [1, 1.0], [250, 1. / 250.]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()
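The new `[0, 1.0]` entry follows directly from the clamp: assuming, as the asserted values suggest, a power schedule with initial learning rate 1.0 and power -1, the post-fix LR is max(step, 1)**power. A quick sketch (illustrative, not the test's actual config):

    import tensorflow as tf

    def power_lr(step, initial_lr=1.0, power=-1.0):
      # Post-fix power schedule: clamp the step before the negative power.
      step_non_zero = tf.math.maximum(tf.cast(step, tf.float32), 1.0)
      return initial_lr * tf.math.pow(step_non_zero, power)

    assert power_lr(0).numpy() == 1.0   # previously Inf at step 0
    assert power_lr(1).numpy() == 1.0
    assert abs(power_lr(250).numpy() - 1. / 250.) < 1e-6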
@@ -357,7 +357,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
             }
         }
     }
-    expected_lr_step_values = [[1, 1.0], [40, 1. / 40.], [60, 1. / 60. * 0.8]]
+    expected_lr_step_values = [[0, 1.0], [1, 1.0], [40, 1. / 40.],
+                               [60, 1. / 60. * 0.8]]
    opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()
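For this power-and-linear-decay case, the step-60 expectation decomposes as the power term 1/60 times the linear term 0.8. One parameterization consistent with that factor (the actual values live in the elided `params` dict above) is total_decay_steps = 100 and linear_decay_fraction = 0.5, since min(1, (100 - 60) / (100 * 0.5)) = 0.8.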