Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
38385b0a
Commit
38385b0a
authored
Sep 28, 2018
by
Toby Boyd
Committed by
Taylor Robie
Sep 28, 2018
Browse files
Update lr and default number epochs for CIFAR 10 (#5243)
parent
f505cecd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
6 deletions
+7
-6
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+7
-6
No files found.
official/resnet/cifar10_main.py
View file @
38385b0a
...
@@ -38,6 +38,7 @@ _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
...
@@ -38,6 +38,7 @@ _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
_NUM_CLASSES
=
10
_NUM_CLASSES
=
10
_NUM_DATA_FILES
=
5
_NUM_DATA_FILES
=
5
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
_NUM_IMAGES
=
{
_NUM_IMAGES
=
{
'train'
:
50000
,
'train'
:
50000
,
'validation'
:
10000
,
'validation'
:
10000
,
...
@@ -193,14 +194,14 @@ class Cifar10Model(resnet_model.Model):
...
@@ -193,14 +194,14 @@ class Cifar10Model(resnet_model.Model):
def
cifar10_model_fn
(
features
,
labels
,
mode
,
params
):
def
cifar10_model_fn
(
features
,
labels
,
mode
,
params
):
"""Model function for CIFAR-10."""
"""Model function for CIFAR-10."""
features
=
tf
.
reshape
(
features
,
[
-
1
,
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
])
features
=
tf
.
reshape
(
features
,
[
-
1
,
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
])
# Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
learning_rate_fn
=
resnet_run_loop
.
learning_rate_with_decay
(
learning_rate_fn
=
resnet_run_loop
.
learning_rate_with_decay
(
batch_size
=
params
[
'batch_size'
],
batch_denom
=
128
,
batch_size
=
params
[
'batch_size'
],
batch_denom
=
128
,
num_images
=
_NUM_IMAGES
[
'train'
],
boundary_epochs
=
[
100
,
150
,
200
],
num_images
=
_NUM_IMAGES
[
'train'
],
boundary_epochs
=
[
91
,
136
,
182
],
decay_rates
=
[
1
,
0.1
,
0.01
,
0.001
])
decay_rates
=
[
1
,
0.1
,
0.01
,
0.001
])
# We
use a we
ight decay of
0.0002, which performs bett
er
# Weight decay of
2e-4 diverges from 1e-4 decay used in the ResNet pap
er
#
than the 0.0001 that was originally suggested
.
#
and seems more stable in testing. The difference was nominal for ResNet-56
.
weight_decay
=
2e-4
weight_decay
=
2e-4
# Empirical testing showed that including batch_normalization variables
# Empirical testing showed that including batch_normalization variables
...
@@ -234,8 +235,8 @@ def define_cifar_flags():
...
@@ -234,8 +235,8 @@ def define_cifar_flags():
flags
.
adopt_module_key_flags
(
resnet_run_loop
)
flags
.
adopt_module_key_flags
(
resnet_run_loop
)
flags_core
.
set_defaults
(
data_dir
=
'/tmp/cifar10_data'
,
flags_core
.
set_defaults
(
data_dir
=
'/tmp/cifar10_data'
,
model_dir
=
'/tmp/cifar10_model'
,
model_dir
=
'/tmp/cifar10_model'
,
resnet_size
=
'
32
'
,
resnet_size
=
'
56
'
,
train_epochs
=
250
,
train_epochs
=
182
,
epochs_between_evals
=
10
,
epochs_between_evals
=
10
,
batch_size
=
128
)
batch_size
=
128
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment