Commit 38385b0a authored by Toby Boyd's avatar Toby Boyd Committed by Taylor Robie
Browse files

Update lr and default number epochs for CIFAR 10 (#5243)

parent f505cecd
...@@ -38,6 +38,7 @@ _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1 ...@@ -38,6 +38,7 @@ _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
_NUM_CLASSES = 10 _NUM_CLASSES = 10
_NUM_DATA_FILES = 5 _NUM_DATA_FILES = 5
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
_NUM_IMAGES = { _NUM_IMAGES = {
'train': 50000, 'train': 50000,
'validation': 10000, 'validation': 10000,
...@@ -193,14 +194,14 @@ class Cifar10Model(resnet_model.Model): ...@@ -193,14 +194,14 @@ class Cifar10Model(resnet_model.Model):
def cifar10_model_fn(features, labels, mode, params): def cifar10_model_fn(features, labels, mode, params):
"""Model function for CIFAR-10.""" """Model function for CIFAR-10."""
features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]) features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])
# Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
learning_rate_fn = resnet_run_loop.learning_rate_with_decay( learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
batch_size=params['batch_size'], batch_denom=128, batch_size=params['batch_size'], batch_denom=128,
num_images=_NUM_IMAGES['train'], boundary_epochs=[100, 150, 200], num_images=_NUM_IMAGES['train'], boundary_epochs=[91, 136, 182],
decay_rates=[1, 0.1, 0.01, 0.001]) decay_rates=[1, 0.1, 0.01, 0.001])
# We use a weight decay of 0.0002, which performs better # Weight decay of 2e-4 diverges from 1e-4 decay used in the ResNet paper
# than the 0.0001 that was originally suggested. # and seems more stable in testing. The difference was nominal for ResNet-56.
weight_decay = 2e-4 weight_decay = 2e-4
# Empirical testing showed that including batch_normalization variables # Empirical testing showed that including batch_normalization variables
...@@ -234,8 +235,8 @@ def define_cifar_flags(): ...@@ -234,8 +235,8 @@ def define_cifar_flags():
flags.adopt_module_key_flags(resnet_run_loop) flags.adopt_module_key_flags(resnet_run_loop)
flags_core.set_defaults(data_dir='/tmp/cifar10_data', flags_core.set_defaults(data_dir='/tmp/cifar10_data',
model_dir='/tmp/cifar10_model', model_dir='/tmp/cifar10_model',
resnet_size='32', resnet_size='56',
train_epochs=250, train_epochs=182,
epochs_between_evals=10, epochs_between_evals=10,
batch_size=128) batch_size=128)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment