ModelZoo / ResNet50_tensorflow

Commit e932712b
authored Dec 03, 2018 by Priya Gupta

Change LR schedule to adjust according to batch size

parent 746a927c
Showing 2 changed files with 33 additions and 20 deletions (+33, -20).
official/resnet/keras/keras_cifar_main.py     +10, -9
official/resnet/keras/keras_imagenet_main.py  +23, -11
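The change replaces the hard-coded base learning rates (previously 0.1 * NUM_GPUS for CIFAR-10 and 0.4 for ImageNet) with a linear scaling rule driven by the global batch size: batch_size is threaded through learning_rate_schedule and LearningRateBatchScheduler, and the initial rate becomes BASE_LEARNING_RATE * batch_size / 128 for CIFAR-10 and BASE_LEARNING_RATE * batch_size / 256 for ImageNet. For example, with BASE_LEARNING_RATE = 0.128 and a hypothetical global batch size of 1024, the initial ImageNet learning rate works out to 0.128 * 1024 / 256 = 0.512.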

official/resnet/keras/keras_cifar_main.py  (view file @ e932712b)

@@ -83,10 +83,10 @@ class TimeHistory(tf.keras.callbacks.Callback):
 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (0.1, 91), (0.01, 136), (0.001, 182)
 ]
-NUM_GPUS = flags_core.get_num_gpus(flags.FLAGS)
-BASE_LEARNING_RATE = 0.1 * NUM_GPUS
+BASE_LEARNING_RATE = 0.1

-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
+def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.

   The learning rate starts at 0, then it increases linearly per step.

@@ -115,11 +115,11 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
   # break
   # return learning_rate
-  epoch = current_epoch + float(current_batch) / batches_per_epoch
-  learning_rate = BASE_LEARNING_RATE
+  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 128
+  learning_rate = initial_learning_rate
   for mult, start_epoch in LR_SCHEDULE:
-    if epoch >= start_epoch:
-      learning_rate = BASE_LEARNING_RATE * mult
+    if current_epoch >= start_epoch:
+      learning_rate = initial_learning_rate * mult
     else:
       break
   return learning_rate
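
For illustration, here is a minimal standalone sketch of the CIFAR-10 schedule as it reads after this hunk. The constants are copied from the diff; the batch sizes (128 and 1024) and batches_per_epoch=390 below are hypothetical example inputs, not part of the commit.

  # Standalone sketch of the new CIFAR-10 schedule (example values are hypothetical).
  LR_SCHEDULE = [(0.1, 91), (0.01, 136), (0.001, 182)]  # (multiplier, epoch to start)
  BASE_LEARNING_RATE = 0.1

  def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
    # current_batch / batches_per_epoch are unused after this commit; kept for the callback signature.
    # Linear scaling rule: scale the base rate by batch_size / 128 (the CIFAR-10 reference batch).
    initial_learning_rate = BASE_LEARNING_RATE * batch_size / 128
    learning_rate = initial_learning_rate
    for mult, start_epoch in LR_SCHEDULE:
      if current_epoch >= start_epoch:
        learning_rate = initial_learning_rate * mult
      else:
        break
    return learning_rate

  for bs in (128, 1024):
    print(bs, [learning_rate_schedule(e, 0, 390, bs) for e in (0, 91, 136, 182)])
  # batch_size 128  -> approx. 0.1, 0.01, 0.001, 0.0001  (unchanged at the reference batch size)
  # batch_size 1024 -> approx. 0.8, 0.08, 0.008, 0.0008  (8x batch -> 8x rate at every stage)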

@@ -140,6 +140,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     super(LearningRateBatchScheduler, self).__init__()
     self.schedule = schedule
     self.batches_per_epoch = num_images / batch_size
+    self.batch_size = batch_size
     self.epochs = -1
     self.prev_lr = -1

@@ -149,7 +150,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     self.epochs += 1

   def on_batch_begin(self, batch, logs=None):
-    lr = self.schedule(self.epochs, batch, self.batches_per_epoch)
+    lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
     if lr != self.prev_lr:

@@ -273,7 +274,7 @@ def run_cifar_with_keras(flags_obj):
   tesorboard_callback = tf.keras.callbacks.TensorBoard(
       log_dir=flags_obj.model_dir)
       # update_freq="batch")  # Add this if want per batch logging.
   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule,

official/resnet/keras/keras_imagenet_main.py  (view file @ e932712b)

@@ -81,9 +81,9 @@ class TimeHistory(tf.keras.callbacks.Callback):
 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 ]
-BASE_LEARNING_RATE = 0.4  # 0.128
+BASE_LEARNING_RATE = 0.128

-def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
+def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
   """Handles linear scaling rule, gradual warmup, and LR decay.

   The learning rate starts at 0, then it increases linearly per step.

@@ -100,14 +100,15 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch):
   Returns:
     Adjusted learning rate.
   """
+  initial_learning_rate = BASE_LEARNING_RATE * batch_size / 256
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
     # Learning rate increases linearly per step.
-    return BASE_LEARNING_RATE * warmup_lr_multiplier * epoch / warmup_end_epoch
+    return initial_learning_rate * warmup_lr_multiplier * epoch / warmup_end_epoch
   for mult, start_epoch in LR_SCHEDULE:
     if epoch >= start_epoch:
-      learning_rate = BASE_LEARNING_RATE * mult
+      learning_rate = initial_learning_rate * mult
     else:
       break
   return learning_rate
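
In the ImageNet version the warmup phase is scaled as well: for the first warmup_end_epoch = 5 epochs the function returns initial_learning_rate * 1.0 * epoch / 5, where epoch includes fractional progress through the current epoch. As a worked example with a hypothetical global batch size of 1024, initial_learning_rate = 0.128 * 1024 / 256 = 0.512, so the rate ramps linearly from 0 at the first step up to 0.512 by the end of warmup, then steps down to 0.0512 at epoch 30, 0.00512 at epoch 60, and 0.000512 at epoch 80.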

@@ -128,6 +129,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     super(LearningRateBatchScheduler, self).__init__()
     self.schedule = schedule
     self.batches_per_epoch = num_images / batch_size
+    self.batch_size = batch_size
     self.epochs = -1
     self.prev_lr = -1

@@ -137,7 +139,7 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
     self.epochs += 1

   def on_batch_begin(self, batch, logs=None):
-    lr = self.schedule(self.epochs, batch, self.batches_per_epoch)
+    lr = self.schedule(self.epochs, batch, self.batches_per_epoch, self.batch_size)
     if not isinstance(lr, (float, np.float32, np.float64)):
       raise ValueError('The output of the "schedule" function should be float.')
     if lr != self.prev_lr:

@@ -187,6 +189,15 @@ def run_imagenet_with_keras(flags_obj):
   Raises:
     ValueError: If fp16 is passed as it is not currently supported.
   """
+  # Set all random seeds to fixed values.
+  import random
+  import numpy as np
+  seed = 87654321
+  random.seed(seed)
+  np.random.seed(seed)
+  tf.random.set_random_seed(seed)
+
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
     raise ValueError('dtype fp16 is not supported in Keras. Use the default '

@@ -239,10 +250,11 @@ def run_imagenet_with_keras(flags_obj):
   # opt = tf.train.GradientDescentOptimizer(learning_rate=0.0001)
   # I am setting an initial LR of 0.001 since this will be reset
   # at the beginning of the training loop.
-  opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
+  # opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
   # TF Optimizer:
-  # opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
+  learning_rate = BASE_LEARNING_RATE * flags_obj.batch_size / 256
+  opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)

   strategy = distribution_utils.get_distribution_strategy(
       num_gpus=flags_obj.num_gpus)
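
The optimizer construction changes as well: the Keras gradient_descent_v2.SGD is commented out in favour of tf.train.MomentumOptimizer, seeded directly with the batch-size-scaled rate (momentum stays at 0.9). A quick check of that arithmetic, using hypothetical batch sizes:

  BASE_LEARNING_RATE = 0.128
  for batch_size in (256, 512, 1024):
    print(batch_size, BASE_LEARNING_RATE * batch_size / 256)
  # -> 256: 0.128, 512: 0.256, 1024: 0.512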

@@ -264,8 +276,8 @@ def run_imagenet_with_keras(flags_obj):
   time_callback = TimeHistory(flags_obj.batch_size)
   tesorboard_callback = tf.keras.callbacks.TensorBoard(
-      log_dir=flags_obj.model_dir)
-      # update_freq="batch") # Add this if want per batch logging.
+      log_dir=flags_obj.model_dir,
+      update_freq="batch")  # Add this if want per batch logging.
   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule,

@@ -280,7 +292,7 @@ def run_imagenet_with_keras(flags_obj):
       steps_per_epoch=steps_per_epoch,
       callbacks=[
           time_callback,
-          lr_callback,
+          # lr_callback,
           tesorboard_callback
       ],
       verbose=1)
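
Note that, taken together with the optimizer hunk above, this leaves the per-batch LearningRateBatchScheduler disabled for the ImageNet run in this commit: with lr_callback commented out, training proceeds with the fixed, batch-size-scaled rate passed to tf.train.MomentumOptimizer rather than with the warmup-and-decay schedule.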