Unverified commit a6758929, authored by Karmel Allison, committed by GitHub

Merge Resnet files (#3301)

parent 6c874e17
@@ -23,8 +23,7 @@ import sys

 import tensorflow as tf

-import resnet_model
-import resnet_shared
+import resnet

 _HEIGHT = 32
 _WIDTH = 32
@@ -152,8 +151,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
 ###############################################################################
 # Running the model
 ###############################################################################
-class Cifar10Model(resnet_model.Model):
+class Cifar10Model(resnet.Model):

   def __init__(self, resnet_size, data_format=None):
     """These are the parameters that work for CIFAR-10 data.
     """
@@ -172,7 +170,7 @@ class Cifar10Model(resnet_model.Model):
         first_pool_stride=None,
         second_pool_size=8,
         second_pool_stride=1,
-        block_fn=resnet_model.building_block,
+        block_fn=resnet.building_block,
         block_sizes=[num_blocks] * 3,
         block_strides=[1, 2, 2],
         final_size=64,
@@ -183,7 +181,7 @@ def cifar10_model_fn(features, labels, mode, params):
   """Model function for CIFAR-10."""
   features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS])

-  learning_rate_fn = resnet_shared.learning_rate_with_decay(
+  learning_rate_fn = resnet.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=128,
       num_images=_NUM_IMAGES['train'], boundary_epochs=[100, 150, 200],
       decay_rates=[1, 0.1, 0.01, 0.001])
@@ -200,23 +198,23 @@ def cifar10_model_fn(features, labels, mode, params):
   def loss_filter_fn(name):
     return True

-  return resnet_shared.resnet_model_fn(features, labels, mode, Cifar10Model,
-                                       resnet_size=params['resnet_size'],
-                                       weight_decay=weight_decay,
-                                       learning_rate_fn=learning_rate_fn,
-                                       momentum=0.9,
-                                       data_format=params['data_format'],
-                                       loss_filter_fn=loss_filter_fn)
+  return resnet.resnet_model_fn(features, labels, mode, Cifar10Model,
+                                resnet_size=params['resnet_size'],
+                                weight_decay=weight_decay,
+                                learning_rate_fn=learning_rate_fn,
+                                momentum=0.9,
+                                data_format=params['data_format'],
+                                loss_filter_fn=loss_filter_fn)


 def main(unused_argv):
-  resnet_shared.resnet_main(FLAGS, cifar10_model_fn, input_fn)
+  resnet.resnet_main(FLAGS, cifar10_model_fn, input_fn)


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)

-  parser = resnet_shared.ResnetArgParser()
+  parser = resnet.ResnetArgParser()

   # Set defaults that are reasonable for this model.
   parser.set_defaults(data_dir='/tmp/cifar10_data',
                       model_dir='/tmp/cifar10_model',
...
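As a quick sanity check on the CIFAR-10 hunks above: a rough sketch (not part of the commit) of what `resnet.learning_rate_with_decay` works out to for these settings, assuming for illustration a batch size of 128 (equal to `batch_denom`) and 50,000 training images:

```python
# Rough sketch of the piecewise schedule implied by the diff above.
# Assumed for illustration: batch_size=128, 50,000 training images.
batch_size, batch_denom, num_images = 128, 128, 50000
initial_lr = 0.1 * batch_size / batch_denom    # 0.1 when batch_size == batch_denom
batches_per_epoch = num_images / batch_size    # 390.625
boundaries = [int(batches_per_epoch * e) for e in [100, 150, 200]]
vals = [initial_lr * d for d in [1, 0.1, 0.01, 0.001]]
print(boundaries)  # [39062, 58593, 78125]
print(vals)        # [0.1, 0.01, 0.001, 0.0001] (up to float rounding)
```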
@@ -23,8 +23,7 @@ import sys

 import tensorflow as tf

-import resnet_model
-import resnet_shared
+import resnet
 import vgg_preprocessing

 _DEFAULT_IMAGE_SIZE = 224
@@ -129,17 +128,17 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
 ###############################################################################
 # Running the model
 ###############################################################################
-class ImagenetModel(resnet_model.Model):
+class ImagenetModel(resnet.Model):

   def __init__(self, resnet_size, data_format=None):
     """These are the parameters that work for Imagenet data.
     """
     # For bigger models, we want to use "bottleneck" layers
     if resnet_size < 50:
-      block_fn = resnet_model.building_block
+      block_fn = resnet.building_block
       final_size = 512
     else:
-      block_fn = resnet_model.bottleneck_block
+      block_fn = resnet.bottleneck_block
       final_size = 2048

     super(ImagenetModel, self).__init__(
@@ -184,28 +183,28 @@ def _get_block_sizes(resnet_size):
 def imagenet_model_fn(features, labels, mode, params):
   """Our model_fn for ResNet to be used with our Estimator."""

-  learning_rate_fn = resnet_shared.learning_rate_with_decay(
+  learning_rate_fn = resnet.learning_rate_with_decay(
       batch_size=params['batch_size'], batch_denom=256,
       num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
       decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

-  return resnet_shared.resnet_model_fn(features, labels, mode, ImagenetModel,
-                                       resnet_size=params['resnet_size'],
-                                       weight_decay=1e-4,
-                                       learning_rate_fn=learning_rate_fn,
-                                       momentum=0.9,
-                                       data_format=params['data_format'],
-                                       loss_filter_fn=None)
+  return resnet.resnet_model_fn(features, labels, mode, ImagenetModel,
+                                resnet_size=params['resnet_size'],
+                                weight_decay=1e-4,
+                                learning_rate_fn=learning_rate_fn,
+                                momentum=0.9,
+                                data_format=params['data_format'],
+                                loss_filter_fn=None)


 def main(unused_argv):
-  resnet_shared.resnet_main(FLAGS, imagenet_model_fn, input_fn)
+  resnet.resnet_main(FLAGS, imagenet_model_fn, input_fn)


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)

-  parser = resnet_shared.ResnetArgParser(
+  parser = resnet.ResnetArgParser(
       resnet_size_choices=[18, 34, 50, 101, 152, 200])

   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(argv=[sys.argv[0]] + unparsed)
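The `resnet_size < 50` branch above encodes the usual ResNet split: sizes 18 and 34 use the two-convolution building block and end at 512 channels, while 50 and up use the bottleneck block, whose final 1x1 convolution expands channels 4x, ending at 2048. A small illustrative sketch (`pick_block` is a hypothetical helper of ours, not part of the resnet module):

```python
# Illustrative only: mirrors the branch in ImagenetModel.__init__ above.
def pick_block(resnet_size):
  if resnet_size < 50:
    return 'building_block', 512     # two 3x3 convs per block
  return 'bottleneck_block', 2048    # 1x1 -> 3x3 -> 1x1, 4x channel expansion

for size in (18, 34, 50, 101, 152, 200):
  print(size, pick_block(size))
```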
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Contains definitions for the preactivation form of Residual Networks.
+"""Contains definitions for the preactivation form of Residual Networks
+(also known as ResNet v2).

 Residual networks (ResNets) were originally proposed in:
 [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
@@ -32,12 +33,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import argparse
+import os
+
 import tensorflow as tf

 _BATCH_NORM_DECAY = 0.997
 _BATCH_NORM_EPSILON = 1e-5

+################################################################################
+# Functions building the ResNet model.
+################################################################################
 def batch_norm_relu(inputs, training, data_format):
   """Performs a batch normalization followed by a ReLU."""
   # We set fused=True for a significant performance boost. See
@@ -318,3 +325,223 @@ class Model(object):
     inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
     inputs = tf.identity(inputs, 'final_dense')
     return inputs
+
+
+################################################################################
+# Functions for running training/eval/validation loops for the model.
+################################################################################
+def learning_rate_with_decay(
+    batch_size, batch_denom, num_images, boundary_epochs, decay_rates):
+  """Get a learning rate that decays step-wise as training progresses.
+
+  Args:
+    batch_size: the number of examples processed in each training batch.
+    batch_denom: this value will be used to scale the base learning rate.
+      `0.1 * batch size` is divided by this number, such that when
+      batch_denom == batch_size, the initial learning rate will be 0.1.
+    num_images: total number of images that will be used for training.
+    boundary_epochs: list of ints representing the epochs at which we
+      decay the learning rate.
+    decay_rates: list of floats representing the decay rates to be used
+      for scaling the learning rate. Should be the same length as
+      boundary_epochs.
+
+  Returns:
+    Returns a function that takes a single argument - the number of batches
+    trained so far (global_step) - and returns the learning rate to be used
+    for training the next batch.
+  """
+  initial_learning_rate = 0.1 * batch_size / batch_denom
+  batches_per_epoch = num_images / batch_size
+
+  # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
+  boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
+  vals = [initial_learning_rate * decay for decay in decay_rates]
+
+  def learning_rate_fn(global_step):
+    global_step = tf.cast(global_step, tf.int32)
+    return tf.train.piecewise_constant(global_step, boundaries, vals)
+
+  return learning_rate_fn
+
+
+def resnet_model_fn(features, labels, mode, model_class,
+                    resnet_size, weight_decay, learning_rate_fn, momentum,
+                    data_format, loss_filter_fn=None):
+  """Shared functionality for different resnet model_fns.
+
+  Initializes the ResnetModel representing the model layers
+  and uses that model to build the necessary EstimatorSpecs for
+  the `mode` in question. For training, this means building losses,
+  the optimizer, and the train op that get passed into the EstimatorSpec.
+  For evaluation and prediction, the EstimatorSpec is returned without
+  a train op, but with the necessary parameters for the given mode.
+
+  Args:
+    features: tensor representing input images
+    labels: tensor representing class labels for all input images
+    mode: current estimator mode; should be one of
+      `tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT`
+    model_class: a class representing a TensorFlow model that has a __call__
+      function. We assume here that this is a subclass of ResnetModel.
+    resnet_size: A single integer for the size of the ResNet model.
+    weight_decay: weight decay loss rate used to regularize learned variables.
+    learning_rate_fn: function that returns the current learning rate given
+      the current global_step
+    momentum: momentum term used for optimization
+    data_format: Input format ('channels_last', 'channels_first', or None).
+      If set to None, the format is dependent on whether a GPU is available.
+    loss_filter_fn: function that takes a string variable name and returns
+      True if the var should be included in loss calculation, and False
+      otherwise. If None, batch_normalization variables will be excluded
+      from the loss.
+
+  Returns:
+    EstimatorSpec parameterized according to the input params and the
+    current mode.
+  """
+  # Generate a summary node for the images
+  tf.summary.image('images', features, max_outputs=6)
+
+  model = model_class(resnet_size, data_format)
+  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
+
+  predictions = {
+      'classes': tf.argmax(logits, axis=1),
+      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
+  }
+
+  if mode == tf.estimator.ModeKeys.PREDICT:
+    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+  # Calculate loss, which includes softmax cross entropy and L2 regularization.
+  cross_entropy = tf.losses.softmax_cross_entropy(
+      logits=logits, onehot_labels=labels)
+
+  # Create a tensor named cross_entropy for logging purposes.
+  tf.identity(cross_entropy, name='cross_entropy')
+  tf.summary.scalar('cross_entropy', cross_entropy)
+
+  # If no loss_filter_fn is passed, assume we want the default behavior,
+  # which is that batch_normalization variables are excluded from loss.
+  if not loss_filter_fn:
+    def loss_filter_fn(name):
+      return 'batch_normalization' not in name
+
+  # Add weight decay to the loss.
+  loss = cross_entropy + weight_decay * tf.add_n(
+      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
+       if loss_filter_fn(v.name)])
+
+  if mode == tf.estimator.ModeKeys.TRAIN:
+    global_step = tf.train.get_or_create_global_step()
+
+    learning_rate = learning_rate_fn(global_step)
+
+    # Create a tensor named learning_rate for logging purposes
+    tf.identity(learning_rate, name='learning_rate')
+    tf.summary.scalar('learning_rate', learning_rate)
+
+    optimizer = tf.train.MomentumOptimizer(
+        learning_rate=learning_rate,
+        momentum=momentum)
+
+    # Batch norm requires update ops to be added as a dependency to train_op
+    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+    with tf.control_dependencies(update_ops):
+      train_op = optimizer.minimize(loss, global_step)
+  else:
+    train_op = None
+
+  accuracy = tf.metrics.accuracy(
+      tf.argmax(labels, axis=1), predictions['classes'])
+  metrics = {'accuracy': accuracy}
+
+  # Create a tensor named train_accuracy for logging purposes
+  tf.identity(accuracy[1], name='train_accuracy')
+  tf.summary.scalar('train_accuracy', accuracy[1])
+
+  return tf.estimator.EstimatorSpec(
+      mode=mode,
+      predictions=predictions,
+      loss=loss,
+      train_op=train_op,
+      eval_metric_ops=metrics)
+
+
+def resnet_main(flags, model_function, input_function):
+  # Using the Winograd non-fused algorithms provides a small performance boost.
+  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
+
+  # Set up a RunConfig to only save checkpoints once per training cycle.
+  run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
+  classifier = tf.estimator.Estimator(
+      model_fn=model_function, model_dir=flags.model_dir, config=run_config,
+      params={
+          'resnet_size': flags.resnet_size,
+          'data_format': flags.data_format,
+          'batch_size': flags.batch_size,
+      })
+
+  for _ in range(flags.train_epochs // flags.epochs_per_eval):
+    tensors_to_log = {
+        'learning_rate': 'learning_rate',
+        'cross_entropy': 'cross_entropy',
+        'train_accuracy': 'train_accuracy'
+    }
+
+    logging_hook = tf.train.LoggingTensorHook(
+        tensors=tensors_to_log, every_n_iter=100)
+
+    print('Starting a training cycle.')
+    classifier.train(
+        input_fn=lambda: input_function(
+            True, flags.data_dir, flags.batch_size, flags.epochs_per_eval),
+        hooks=[logging_hook])
+
+    print('Starting to evaluate.')
+    # Evaluate the model and print results
+    eval_results = classifier.evaluate(input_fn=lambda: input_function(
+        False, flags.data_dir, flags.batch_size))
+    print(eval_results)
+
+
+class ResnetArgParser(argparse.ArgumentParser):
+  """Arguments for configuring and running a Resnet Model.
+  """
+
+  def __init__(self, resnet_size_choices=None):
+    super(ResnetArgParser, self).__init__()
+    self.add_argument(
+        '--data_dir', type=str, default='/tmp/resnet_data',
+        help='The directory where the input data is stored.')
+
+    self.add_argument(
+        '--model_dir', type=str, default='/tmp/resnet_model',
+        help='The directory where the model will be stored.')
+
+    self.add_argument(
+        '--resnet_size', type=int, default=50,
+        choices=resnet_size_choices,
+        help='The size of the ResNet model to use.')
+
+    self.add_argument(
+        '--train_epochs', type=int, default=100,
+        help='The number of epochs to use for training.')
+
+    self.add_argument(
+        '--epochs_per_eval', type=int, default=1,
+        help='The number of training epochs to run between evaluations.')
+
+    self.add_argument(
+        '--batch_size', type=int, default=32,
+        help='Batch size for training and evaluation.')
+
+    self.add_argument(
+        '--data_format', type=str, default=None,
+        choices=['channels_first', 'channels_last'],
+        help='A flag to override the data format used in the model. '
+             'channels_first provides a performance boost on GPU but '
+             'is not always compatible with CPU. If left unspecified, '
+             'the data format will be chosen automatically based on '
+             'whether TensorFlow was built for CPU or GPU.')
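Taken together, these additions mean a dataset-specific driver now needs only the single `resnet` module; the hunks that follow show the old `resnet_shared.py` being deleted, its contents having moved into `resnet.py` above. A minimal sketch of such a driver, modeled on the CIFAR-10 and ImageNet mains in this commit (`my_model_fn` and `my_input_fn` are placeholders for the dataset-specific functions, not names from the commit):

```python
# Hypothetical minimal driver against the merged resnet module.
# my_model_fn / my_input_fn stand in for dataset-specific code.
import sys

import tensorflow as tf

import resnet


def main(unused_argv):
  resnet.resnet_main(FLAGS, my_model_fn, my_input_fn)


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  parser = resnet.ResnetArgParser(resnet_size_choices=[18, 34, 50])
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(argv=[sys.argv[0]] + unparsed)
```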
...
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Functions for running Resnet that are shared across datasets."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import os
-
-import tensorflow as tf
-
-
-def learning_rate_with_decay(
-    batch_size, batch_denom, num_images, boundary_epochs, decay_rates):
-  """Get a learning rate that decays step-wise as training progresses.
-
-  Args:
-    batch_size: the number of examples processed in each training batch.
-    batch_denom: this value will be used to scale the base learning rate.
-      `0.1 * batch size` is divided by this number, such that when
-      batch_denom == batch_size, the initial learning rate will be 0.1.
-    num_images: total number of images that will be used for training.
-    boundary_epochs: list of ints representing the epochs at which we
-      decay the learning rate.
-    decay_rates: list of floats representing the decay rates to be used
-      for scaling the learning rate. Should be the same length as
-      boundary_epochs.
-
-  Returns:
-    Returns a function that takes a single argument - the number of batches
-    trained so far (global_step) - and returns the learning rate to be used
-    for training the next batch.
-  """
-  initial_learning_rate = 0.1 * batch_size / batch_denom
-  batches_per_epoch = num_images / batch_size
-
-  # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
-  boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
-  vals = [initial_learning_rate * decay for decay in decay_rates]
-
-  def learning_rate_fn(global_step):
-    global_step = tf.cast(global_step, tf.int32)
-    return tf.train.piecewise_constant(global_step, boundaries, vals)
-
-  return learning_rate_fn
-
-
-def resnet_model_fn(features, labels, mode, model_class,
-                    resnet_size, weight_decay, learning_rate_fn, momentum,
-                    data_format, loss_filter_fn=None):
-  """Shared functionality for different resnet model_fns.
-
-  Initializes the ResnetModel representing the model layers
-  and uses that model to build the necessary EstimatorSpecs for
-  the `mode` in question. For training, this means building losses,
-  the optimizer, and the train op that get passed into the EstimatorSpec.
-  For evaluation and prediction, the EstimatorSpec is returned without
-  a train op, but with the necessary parameters for the given mode.
-
-  Args:
-    features: tensor representing input images
-    labels: tensor representing class labels for all input images
-    mode: current estimator mode; should be one of
-      `tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT`
-    model_class: a class representing a TensorFlow model that has a __call__
-      function. We assume here that this is a subclass of ResnetModel.
-    resnet_size: A single integer for the size of the ResNet model.
-    weight_decay: weight decay loss rate used to regularize learned variables.
-    learning_rate_fn: function that returns the current learning rate given
-      the current global_step
-    momentum: momentum term used for optimization
-    data_format: Input format ('channels_last', 'channels_first', or None).
-      If set to None, the format is dependent on whether a GPU is available.
-    loss_filter_fn: function that takes a string variable name and returns
-      True if the var should be included in loss calculation, and False
-      otherwise. If None, batch_normalization variables will be excluded
-      from the loss.
-
-  Returns:
-    EstimatorSpec parameterized according to the input params and the
-    current mode.
-  """
-  # Generate a summary node for the images
-  tf.summary.image('images', features, max_outputs=6)
-
-  model = model_class(resnet_size, data_format)
-  logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
-
-  predictions = {
-      'classes': tf.argmax(logits, axis=1),
-      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
-  }
-
-  if mode == tf.estimator.ModeKeys.PREDICT:
-    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
-
-  # Calculate loss, which includes softmax cross entropy and L2 regularization.
-  cross_entropy = tf.losses.softmax_cross_entropy(
-      logits=logits, onehot_labels=labels)
-
-  # Create a tensor named cross_entropy for logging purposes.
-  tf.identity(cross_entropy, name='cross_entropy')
-  tf.summary.scalar('cross_entropy', cross_entropy)
-
-  # If no loss_filter_fn is passed, assume we want the default behavior,
-  # which is that batch_normalization variables are excluded from loss.
-  if not loss_filter_fn:
-    def loss_filter_fn(name):
-      return 'batch_normalization' not in name
-
-  # Add weight decay to the loss.
-  loss = cross_entropy + weight_decay * tf.add_n(
-      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
-       if loss_filter_fn(v.name)])
-
-  if mode == tf.estimator.ModeKeys.TRAIN:
-    global_step = tf.train.get_or_create_global_step()
-
-    learning_rate = learning_rate_fn(global_step)
-
-    # Create a tensor named learning_rate for logging purposes
-    tf.identity(learning_rate, name='learning_rate')
-    tf.summary.scalar('learning_rate', learning_rate)
-
-    optimizer = tf.train.MomentumOptimizer(
-        learning_rate=learning_rate,
-        momentum=momentum)
-
-    # Batch norm requires update ops to be added as a dependency to train_op
-    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-    with tf.control_dependencies(update_ops):
-      train_op = optimizer.minimize(loss, global_step)
-  else:
-    train_op = None
-
-  accuracy = tf.metrics.accuracy(
-      tf.argmax(labels, axis=1), predictions['classes'])
-  metrics = {'accuracy': accuracy}
-
-  # Create a tensor named train_accuracy for logging purposes
-  tf.identity(accuracy[1], name='train_accuracy')
-  tf.summary.scalar('train_accuracy', accuracy[1])
-
-  return tf.estimator.EstimatorSpec(
-      mode=mode,
-      predictions=predictions,
-      loss=loss,
-      train_op=train_op,
-      eval_metric_ops=metrics)
-
-
-def resnet_main(flags, model_function, input_function):
-  # Using the Winograd non-fused algorithms provides a small performance boost.
-  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
-
-  # Set up a RunConfig to only save checkpoints once per training cycle.
-  run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
-  classifier = tf.estimator.Estimator(
-      model_fn=model_function, model_dir=flags.model_dir, config=run_config,
-      params={
-          'resnet_size': flags.resnet_size,
-          'data_format': flags.data_format,
-          'batch_size': flags.batch_size,
-      })
-
-  for _ in range(flags.train_epochs // flags.epochs_per_eval):
-    tensors_to_log = {
-        'learning_rate': 'learning_rate',
-        'cross_entropy': 'cross_entropy',
-        'train_accuracy': 'train_accuracy'
-    }
-
-    logging_hook = tf.train.LoggingTensorHook(
-        tensors=tensors_to_log, every_n_iter=100)
-
-    print('Starting a training cycle.')
-    classifier.train(
-        input_fn=lambda: input_function(
-            True, flags.data_dir, flags.batch_size, flags.epochs_per_eval),
-        hooks=[logging_hook])
-
-    print('Starting to evaluate.')
-    # Evaluate the model and print results
-    eval_results = classifier.evaluate(input_fn=lambda: input_function(
-        False, flags.data_dir, flags.batch_size))
-    print(eval_results)
-
-
-class ResnetArgParser(argparse.ArgumentParser):
-  """Arguments for configuring and running a Resnet Model.
-  """
-
-  def __init__(self, resnet_size_choices=None):
-    super(ResnetArgParser, self).__init__()
-    self.add_argument(
-        '--data_dir', type=str, default='/tmp/resnet_data',
-        help='The directory where the input data is stored.')
-
-    self.add_argument(
-        '--model_dir', type=str, default='/tmp/resnet_model',
-        help='The directory where the model will be stored.')
-
-    self.add_argument(
-        '--resnet_size', type=int, default=50,
-        choices=resnet_size_choices,
-        help='The size of the ResNet model to use.')
-
-    self.add_argument(
-        '--train_epochs', type=int, default=100,
-        help='The number of epochs to use for training.')
-
-    self.add_argument(
-        '--epochs_per_eval', type=int, default=1,
-        help='The number of training epochs to run between evaluations.')
-
-    self.add_argument(
-        '--batch_size', type=int, default=32,
-        help='Batch size for training and evaluation.')
-
-    self.add_argument(
-        '--data_format', type=str, default=None,
-        choices=['channels_first', 'channels_last'],
-        help='A flag to override the data format used in the model. '
-             'channels_first provides a performance boost on GPU but '
-             'is not always compatible with CPU. If left unspecified, '
-             'the data format will be chosen automatically based on '
-             'whether TensorFlow was built for CPU or GPU.')
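One behavior worth calling out in `resnet_model_fn` (present in both the new `resnet.py` and the deleted `resnet_shared.py`): when `loss_filter_fn` is None, batch-normalization variables are excluded from the L2 weight-decay term. A standalone sketch of that default filter (the variable names below are illustrative, not taken from a real graph):

```python
# Default filter from resnet_model_fn: decay everything except batch norm.
def loss_filter_fn(name):
  return 'batch_normalization' not in name

# Illustrative variable names only.
names = ['conv2d/kernel:0', 'batch_normalization/gamma:0',
         'batch_normalization/beta:0', 'dense/kernel:0']
print([n for n in names if loss_filter_fn(n)])
# ['conv2d/kernel:0', 'dense/kernel:0']
```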