Commit 7e9e15ad authored by Toby Boyd, committed by GitHub

Merge pull request #2056 from tfboyd/cifar_mkl

Added data_format flag to support MKL and other interesting tests
Parents: 3bf85a4e 90fbe70e
cifar10.py:

@@ -74,8 +74,8 @@ class Cifar10DataSet(object):
     dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()

     # Parse records.
-    dataset = dataset.map(self.parser, num_threads=batch_size,
-                          output_buffer_size=2 * batch_size)
+    dataset = dataset.map(
+        self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)

     # Potentially shuffle records.
     if self.subset == 'train':
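Note: the reformatted map call above lives inside Cifar10DataSet.make_batch. Purely as a reading aid, here is a self-contained sketch of how such a TF 1.x tf.contrib.data pipeline is typically assembled and consumed; the shuffle/batch steps and the helper name are illustrative assumptions, not part of this diff.

import tensorflow as tf

def make_batch_sketch(filenames, parser_fn, batch_size, is_training=True):
  """Illustrative contrib.data pipeline mirroring the map call shown above."""
  dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()
  # Parse records with the same arguments used in the diff.
  dataset = dataset.map(
      parser_fn, num_threads=batch_size, output_buffer_size=2 * batch_size)
  if is_training:
    dataset = dataset.shuffle(buffer_size=4 * batch_size)  # illustrative size
  dataset = dataset.batch(batch_size)
  images, labels = dataset.make_one_shot_iterator().get_next()
  return images, labels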
cifar10_main.py:

@@ -32,21 +32,21 @@ import argparse
 import functools
 import itertools
 import os
-import six
-import numpy as np
-from six.moves import xrange  # pylint: disable=redefined-builtin
-import tensorflow as tf
 import cifar10
 import cifar10_model
 import cifar10_utils
+import numpy as np
+import six
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf

 tf.logging.set_verbosity(tf.logging.INFO)


-def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
+def get_model_fn(num_gpus, variable_strategy, num_workers):
+  """Returns a function that will build the resnet model."""

   def _resnet_model_fn(features, labels, mode, params):
     """Resnet model body.
@@ -74,6 +74,16 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
     tower_gradvars = []
     tower_preds = []

+    # channels first (NCHW) is normally optimal on GPU and channels last (NHWC)
+    # on CPU. The exception is Intel MKL on CPU which is optimal with
+    # channels_last.
+    data_format = params.data_format
+    if not data_format:
+      if num_gpus == 0:
+        data_format = 'channels_last'
+      else:
+        data_format = 'channels_first'
+
     if num_gpus == 0:
       num_devices = 1
       device_type = 'cpu'
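The block added above is the core of this change: an explicit --data-format flag that, when unset, falls back to a device-dependent default. A minimal self-contained sketch of that selection logic (the helper name and the asserts are ours, purely for illustration; the behaviour mirrors the hunk):

def default_data_format(requested_format, num_gpus):
  """Pick a data format the way the new block above does."""
  if requested_format:
    return requested_format   # an explicit --data-format always wins
  if num_gpus == 0:
    return 'channels_last'    # NHWC on CPU; per the comment above this also suits Intel MKL
  return 'channels_first'     # NCHW is normally optimal on GPU

assert default_data_format(None, 0) == 'channels_last'
assert default_data_format(None, 8) == 'channels_first'
assert default_data_format('channels_last', 8) == 'channels_last'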
@@ -91,21 +101,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
             ps_device_type='gpu',
             worker_device=worker_device,
             ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
-                num_gpus,
-                tf.contrib.training.byte_size_load_fn
-            )
-        )
+                num_gpus, tf.contrib.training.byte_size_load_fn))
       with tf.variable_scope('resnet', reuse=bool(i != 0)):
         with tf.name_scope('tower_%d' % i) as name_scope:
           with tf.device(device_setter):
             loss, gradvars, preds = _tower_fn(
-                is_training,
-                weight_decay,
-                tower_features[i],
-                tower_labels[i],
-                (device_type == 'cpu'),
-                params.num_layers,
-                params.batch_norm_decay,
+                is_training, weight_decay, tower_features[i], tower_labels[i],
+                data_format, params.num_layers, params.batch_norm_decay,
                 params.batch_norm_epsilon)
             tower_losses.append(loss)
             tower_gradvars.append(gradvars)
@@ -136,7 +138,6 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
         avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
         gradvars.append((avg_grad, var))

-
     # Device that runs the ops to apply global gradient updates.
     consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
     with tf.device(consolidation_device):
@@ -159,10 +160,9 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
           learning_rate=learning_rate, momentum=momentum)

       chief_hooks = []
-      if sync:
+      if params.sync:
         optimizer = tf.train.SyncReplicasOptimizer(
-            optimizer,
-            replicas_to_aggregate=num_workers)
+            optimizer, replicas_to_aggregate=num_workers)
         sync_replicas_hook = optimizer.make_session_run_hook(True)
         chief_hooks.append(sync_replicas_hook)
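With this hunk the sync switch is read from params.sync (an HParams field) instead of a dedicated function argument. For reference, a hedged sketch of the wrapping pattern used here, pulled out into a standalone helper (the helper name is ours):

import tensorflow as tf

def maybe_wrap_for_sync(optimizer, sync, num_workers):
  """Optionally wrap an optimizer for synchronous replica updates."""
  chief_hooks = []
  if sync:
    optimizer = tf.train.SyncReplicasOptimizer(
        optimizer, replicas_to_aggregate=num_workers)
    # The chief worker needs this hook to run the extra synchronization ops.
    chief_hooks.append(optimizer.make_session_run_hook(True))
  return optimizer, chief_hooks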
@@ -182,7 +182,8 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
       }
       stacked_labels = tf.concat(labels, axis=0)
       metrics = {
-          'accuracy': tf.metrics.accuracy(stacked_labels, predictions['classes'])
+          'accuracy':
+              tf.metrics.accuracy(stacked_labels, predictions['classes'])
       }

       loss = tf.reduce_mean(tower_losses, name='loss')
@@ -193,35 +194,35 @@ def get_model_fn(num_gpus, variable_strategy, num_workers, sync):
         train_op=train_op,
         training_chief_hooks=chief_hooks,
         eval_metric_ops=metrics)

   return _resnet_model_fn


-def _tower_fn(is_training,
-              weight_decay,
-              feature,
-              label,
-              is_cpu,
-              num_layers,
-              batch_norm_decay,
-              batch_norm_epsilon):
-  """Build computation tower for each device (CPU or GPU).
+def _tower_fn(is_training, weight_decay, feature, label, data_format,
+              num_layers, batch_norm_decay, batch_norm_epsilon):
+  """Build computation tower (Resnet).

   Args:
     is_training: true if is training graph.
     weight_decay: weight regularization strength, a float.
     feature: a Tensor.
     label: a Tensor.
-    tower_losses: a list to be appended with current tower's loss.
-    tower_gradvars: a list to be appended with current tower's gradients.
-    tower_preds: a list to be appended with current tower's predictions.
-    is_cpu: true if build tower on CPU.
+    data_format: channels_last (NHWC) or channels_first (NCHW).
+    num_layers: number of layers, an int.
+    batch_norm_decay: decay for batch normalization, a float.
+    batch_norm_epsilon: epsilon for batch normalization, a float.
+
+  Returns:
+    A tuple with the loss for the tower, the gradients and parameters, and
+    predictions.
   """
-  data_format = 'channels_last' if is_cpu else 'channels_first'
   model = cifar10_model.ResNetCifar10(
       num_layers,
       batch_norm_decay=batch_norm_decay,
       batch_norm_epsilon=batch_norm_epsilon,
-      is_training=is_training, data_format=data_format)
+      is_training=is_training,
+      data_format=data_format)
   logits = model.forward_pass(feature, input_data_format='channels_last')
   tower_pred = {
       'classes': tf.argmax(input=logits, axis=1),
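The per-tower data format is now passed in explicitly instead of being derived from an is_cpu flag. Purely as an illustration of the new argument order, a call using the argparse defaults defined later in this file (the placeholder tensors and the snippet itself are ours, and it assumes the module's own imports and _tower_fn are in scope):

feature = tf.zeros([128, 32, 32, 3])     # one shard of CIFAR-10 images, NHWC
label = tf.zeros([128], dtype=tf.int32)  # matching labels
loss, gradvars, preds = _tower_fn(
    True,              # is_training
    2e-4,              # --weight-decay default
    feature,
    label,
    'channels_first',  # data_format chosen for a GPU tower
    44,                # --num-layers default
    0.997,             # --batch-norm-decay default
    1e-5)              # --batch-norm-epsilon default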
@@ -241,13 +242,20 @@ def _tower_fn(is_training,
   return tower_loss, zip(tower_grad, model_params), tower_pred


-def input_fn(data_dir, subset, num_shards, batch_size,
+def input_fn(data_dir,
+             subset,
+             num_shards,
+             batch_size,
              use_distortion_for_training=True):
   """Create input graph for model.

   Args:
+    data_dir: Directory where TFRecords representing the dataset are located.
     subset: one of 'train', 'validate' and 'eval'.
     num_shards: num of towers participating in data-parallel training.
+    batch_size: total batch size for training to be divided by the number of
+        shards.
+    use_distortion_for_training: True to use distortions.
   Returns:
     two lists of tensors for features and labels, each of num_shards length.
   """
@@ -276,10 +284,10 @@ def input_fn(data_dir, subset, num_shards, batch_size,
     return feature_shards, label_shards


-# create experiment
-def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
-                      use_distortion_for_training=True,
-                      sync=True):
+def get_experiment_fn(data_dir,
+                      num_gpus,
+                      variable_strategy,
+                      use_distortion_for_training=True):
   """Returns an Experiment function.

   Experiments perform training on several workers in parallel,
@@ -291,9 +299,9 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
   Args:
     data_dir: str. Location of the data for input_fns.
     num_gpus: int. Number of GPUs on each worker.
-    is_gpu_ps: bool. If true, average gradients on GPUs.
+    variable_strategy: String. CPU to use CPU as the parameter server
+        and GPU to use the GPUs as the parameter server.
     use_distortion_for_training: bool. See cifar10.Cifar10DataSet.
-    sync: bool. If true synchronizes variable updates across workers.
   Returns:
     A function (tf.estimator.RunConfig, tf.contrib.training.HParams) ->
       tf.contrib.learn.Experiment.
@@ -302,6 +310,7 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
   methods on Experiment (train, evaluate) based on information
   about the current runner in `run_config`.
   """
+
   def _experiment_fn(run_config, hparams):
     """Returns an Experiment."""
     # Create estimator.
@@ -311,28 +320,26 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
         subset='train',
         num_shards=num_gpus,
         batch_size=hparams.train_batch_size,
-        use_distortion_for_training=use_distortion_for_training
-    )
+        use_distortion_for_training=use_distortion_for_training)
     eval_input_fn = functools.partial(
         input_fn,
         data_dir,
         subset='eval',
         batch_size=hparams.eval_batch_size,
-        num_shards=num_gpus
-    )
+        num_shards=num_gpus)

     num_eval_examples = cifar10.Cifar10DataSet.num_examples_per_epoch('eval')
     if num_eval_examples % hparams.eval_batch_size != 0:
-      raise ValueError('validation set size must be multiple of eval_batch_size')
+      raise ValueError(
+          'validation set size must be multiple of eval_batch_size')

     train_steps = hparams.train_steps
     eval_steps = num_eval_examples // hparams.eval_batch_size

     examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
         hparams.train_batch_size, every_n_steps=10)

-    tensors_to_log = {'learning_rate': 'learning_rate',
-                      'loss': 'loss'}
+    tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'}

     logging_hook = tf.train.LoggingTensorHook(
         tensors=tensors_to_log, every_n_iter=100)
@@ -340,11 +347,10 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
     hooks = [logging_hook, examples_sec_hook]

     classifier = tf.estimator.Estimator(
-        model_fn=get_model_fn(
-            num_gpus, is_gpu_ps, run_config.num_worker_replicas or 1, sync),
+        model_fn=get_model_fn(num_gpus, variable_strategy,
+                              run_config.num_worker_replicas or 1),
         config=run_config,
-        params=hparams
-    )
+        params=hparams)

     # Create experiment.
     experiment = tf.contrib.learn.Experiment(
@@ -356,45 +362,31 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
     # Adding hooks to be used by the estimator on training modes
     experiment.extend_train_hooks(hooks)
     return experiment

   return _experiment_fn


-def main(job_dir,
-         data_dir,
-         num_gpus,
-         variable_strategy,
-         use_distortion_for_training,
-         log_device_placement,
-         num_intra_threads,
-         sync,
+def main(job_dir, data_dir, num_gpus, variable_strategy,
+         use_distortion_for_training, log_device_placement, num_intra_threads,
          **hparams):
   # The env variable is on deprecation path, default is set to off.
   os.environ['TF_SYNC_ON_FINISH'] = '0'
+  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

   # Session configuration.
   sess_config = tf.ConfigProto(
       allow_soft_placement=True,
       log_device_placement=log_device_placement,
       intra_op_parallelism_threads=num_intra_threads,
-      gpu_options=tf.GPUOptions(
-          force_gpu_compatible=True
-      )
-  )
+      gpu_options=tf.GPUOptions(force_gpu_compatible=True))

   config = cifar10_utils.RunConfig(
-      session_config=sess_config,
-      model_dir=job_dir)
+      session_config=sess_config, model_dir=job_dir)
   tf.contrib.learn.learn_runner.run(
-      get_experiment_fn(
-          data_dir,
-          num_gpus,
-          variable_strategy,
-          use_distortion_for_training,
-          sync
-      ),
+      get_experiment_fn(data_dir, num_gpus, variable_strategy,
+                        use_distortion_for_training),
       run_config=config,
-      hparams=tf.contrib.training.HParams(**hparams)
-  )
+      hparams=tf.contrib.training.HParams(**hparams))


 if __name__ == '__main__':
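main() now takes only the flags it references directly and forwards everything else through **hparams into tf.contrib.training.HParams, which is why sync and data_format are read as params.sync and params.data_format inside the model function. The very bottom of the script is outside this diff, but the wiring is presumably along these lines (hedged sketch, not shown in the commit):

args = parser.parse_args()
# argparse converts '--data-format' to args.data_format, etc., so the whole
# namespace can be forwarded as keyword arguments; anything not named in
# main()'s signature (sync, data_format, num_layers, ...) lands in **hparams.
main(**vars(args))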
@@ -403,63 +395,53 @@ if __name__ == '__main__':
       '--data-dir',
       type=str,
       required=True,
-      help='The directory where the CIFAR-10 input data is stored.'
-  )
+      help='The directory where the CIFAR-10 input data is stored.')
   parser.add_argument(
       '--job-dir',
       type=str,
       required=True,
-      help='The directory where the model will be stored.'
-  )
+      help='The directory where the model will be stored.')
   parser.add_argument(
       '--variable-strategy',
       choices=['CPU', 'GPU'],
       type=str,
       default='CPU',
-      help='Where to locate variable operations'
-  )
+      help='Where to locate variable operations')
   parser.add_argument(
       '--num-gpus',
       type=int,
       default=1,
-      help='The number of gpus used. Uses only CPU if set to 0.'
-  )
+      help='The number of gpus used. Uses only CPU if set to 0.')
   parser.add_argument(
       '--num-layers',
       type=int,
       default=44,
-      help='The number of layers of the model.'
-  )
+      help='The number of layers of the model.')
   parser.add_argument(
       '--train-steps',
       type=int,
       default=80000,
-      help='The number of steps to use for training.'
-  )
+      help='The number of steps to use for training.')
   parser.add_argument(
       '--train-batch-size',
       type=int,
       default=128,
-      help='Batch size for training.'
-  )
+      help='Batch size for training.')
   parser.add_argument(
       '--eval-batch-size',
       type=int,
       default=100,
-      help='Batch size for validation.'
-  )
+      help='Batch size for validation.')
   parser.add_argument(
       '--momentum',
       type=float,
       default=0.9,
-      help='Momentum for MomentumOptimizer.'
-  )
+      help='Momentum for MomentumOptimizer.')
   parser.add_argument(
       '--weight-decay',
       type=float,
       default=2e-4,
-      help='Weight decay for convolutions.'
-  )
+      help='Weight decay for convolutions.')
   parser.add_argument(
       '--learning-rate',
       type=float,
@@ -468,22 +450,19 @@ if __name__ == '__main__':
       This is the inital learning rate value. The learning rate will decrease
       during training. For more details check the model_fn implementation in
       this file.\
-      """
-  )
+      """)
   parser.add_argument(
       '--use-distortion-for-training',
       type=bool,
       default=True,
-      help='If doing image distortion for training.'
-  )
+      help='If doing image distortion for training.')
   parser.add_argument(
       '--sync',
       action='store_true',
       default=False,
       help="""\
      If present when running in a distributed environment will run on sync mode.\
-      """
-  )
+      """)
   parser.add_argument(
       '--num-intra-threads',
       type=int,
@@ -492,8 +471,7 @@ if __name__ == '__main__':
       Number of threads to use for intra-op parallelism. When training on CPU
       set to 0 to have the system pick the appropriate number or alternatively
       set it to the number of physical CPU cores.\
-      """
-  )
+      """)
   parser.add_argument(
       '--num-inter-threads',
       type=int,
@@ -501,34 +479,37 @@ if __name__ == '__main__':
       help="""\
       Number of threads to use for inter-op parallelism. If set to 0, the
       system will pick an appropriate number.\
-      """
-  )
+      """)
+  parser.add_argument(
+      '--data-format',
+      type=str,
+      default=None,
+      help="""\
+      If not set, the data format best for the training device is used.
+      Allowed values: channels_first (NCHW) channels_last (NHWC).\
+      """)
   parser.add_argument(
       '--log-device-placement',
       action='store_true',
       default=False,
-      help='Whether to log device placement.'
-  )
+      help='Whether to log device placement.')
   parser.add_argument(
       '--batch-norm-decay',
       type=float,
       default=0.997,
-      help='Decay for batch norm.'
-  )
+      help='Decay for batch norm.')
   parser.add_argument(
       '--batch-norm-epsilon',
       type=float,
       default=1e-5,
-      help='Epsilon for batch norm.'
-  )
+      help='Epsilon for batch norm.')
   args = parser.parse_args()

   if args.num_gpus < 0:
     raise ValueError(
         'Invalid GPU count: \"--num-gpus\" must be 0 or a positive integer.')
   if args.num_gpus == 0 and args.variable_strategy == 'GPU':
-    raise ValueError(
-        'num-gpus=0, CPU must be used as parameter server. Set'
+    raise ValueError('num-gpus=0, CPU must be used as parameter server. Set'
                      '--variable-strategy=CPU.')
   if (args.num_layers - 2) % 6 != 0:
     raise ValueError('Invalid --num-layers parameter.')
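Taken together, the new flag can be exercised directly against the parser defined above. A hedged usage sketch for a CPU/MKL-style run (the paths are placeholders; the flag names and defaults come from this diff):

args = parser.parse_args([
    '--data-dir', '/tmp/cifar-10-data',    # placeholder path
    '--job-dir', '/tmp/cifar-10-model',    # placeholder path
    '--num-gpus', '0',
    '--variable-strategy', 'CPU',
    '--data-format', 'channels_last',      # per the Intel MKL note earlier in the diff
])
assert args.data_format == 'channels_last'
assert (args.num_layers - 2) % 6 == 0      # the default of 44 layers passes the check above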