"vscode:/vscode.git/clone" did not exist on "aa645e714138185daecebfc3f441fea716833e41"
Commit add2845a authored by Toby Boyd

Style cleanup

parent a7531875
@@ -74,8 +74,8 @@ class Cifar10DataSet(object):
     dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()

     # Parse records.
-    dataset = dataset.map(self.parser, num_threads=batch_size,
-                          output_buffer_size=2 * batch_size)
+    dataset = dataset.map(
+        self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)

     # Potentially shuffle records.
     if self.subset == 'train':
...
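The hunk above still uses the old tf.contrib.data interface. For reference only (not part of this commit), the same pipeline written against the core tf.data API that later replaced it would look roughly like the sketch below; `num_parallel_calls` and an explicit `prefetch` take over the roles of `num_threads` and `output_buffer_size`, and the surrounding names (`filenames`, `self.parser`, `batch_size`) are the ones already used in the method:

    # Sketch only, assuming the core tf.data API (TF 1.4+); not part of this change.
    dataset = tf.data.TFRecordDataset(filenames).repeat()
    # Parse records in parallel, then keep a buffer of parsed records ready.
    dataset = dataset.map(self.parser, num_parallel_calls=batch_size)
    dataset = dataset.prefetch(2 * batch_size)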
@@ -29,25 +29,23 @@ from __future__ import division
 from __future__ import print_function

 import argparse
-import collections
 import functools
 import itertools
 import os

-import six
-import numpy as np
-from six.moves import xrange  # pylint: disable=redefined-builtin
-import tensorflow as tf
 import cifar10
 import cifar10_model
 import cifar10_utils
+import numpy as np
+import six
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf

 tf.logging.set_verbosity(tf.logging.INFO)


-def get_model_fn(num_gpus, variable_strategy, num_workers):
+def get_model_fn(num_gpus, variable_strategy, data_format, num_workers):
+  """Returns a function that will build the resnet model."""

   def _resnet_model_fn(features, labels, mode, params):
     """Resnet model body.
@@ -92,21 +90,13 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
             ps_device_type='gpu',
             worker_device=worker_device,
             ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
-                num_gpus,
-                tf.contrib.training.byte_size_load_fn
-            )
-        )
+                num_gpus, tf.contrib.training.byte_size_load_fn))
       with tf.variable_scope('resnet', reuse=bool(i != 0)):
         with tf.name_scope('tower_%d' % i) as name_scope:
           with tf.device(device_setter):
             loss, gradvars, preds = _tower_fn(
-                is_training,
-                weight_decay,
-                tower_features[i],
-                tower_labels[i],
-                (device_type == 'cpu'),
-                params['num_layers'],
-                params['batch_norm_decay'],
+                is_training, weight_decay, tower_features[i], tower_labels[i],
+                data_format, params['num_layers'], params['batch_norm_decay'],
                 params['batch_norm_epsilon'])
             tower_losses.append(loss)
             tower_gradvars.append(gradvars)
@@ -137,7 +127,6 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
            avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
         gradvars.append((avg_grad, var))

     # Device that runs the ops to apply global gradient updates.
     consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
     with tf.device(consolidation_device):
@@ -163,8 +152,7 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
       chief_hooks = []
       if params['sync']:
         optimizer = tf.train.SyncReplicasOptimizer(
-            optimizer,
-            replicas_to_aggregate=num_workers)
+            optimizer, replicas_to_aggregate=num_workers)
         sync_replicas_hook = optimizer.make_session_run_hook(True)
         chief_hooks.append(sync_replicas_hook)
@@ -184,7 +172,8 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
       }
       stacked_labels = tf.concat(labels, axis=0)
       metrics = {
-          'accuracy': tf.metrics.accuracy(stacked_labels, predictions['classes'])
+          'accuracy':
+              tf.metrics.accuracy(stacked_labels, predictions['classes'])
       }

       loss = tf.reduce_mean(tower_losses, name='loss')
@@ -195,35 +184,35 @@ def get_model_fn(num_gpus, variable_strategy, num_workers):
         train_op=train_op,
         training_chief_hooks=chief_hooks,
         eval_metric_ops=metrics)

   return _resnet_model_fn


-def _tower_fn(is_training,
-              weight_decay,
-              feature,
-              label,
-              is_cpu,
-              num_layers,
-              batch_norm_decay,
-              batch_norm_epsilon):
-  """Build computation tower for each device (CPU or GPU).
+def _tower_fn(is_training, weight_decay, feature, label, data_format,
+              num_layers, batch_norm_decay, batch_norm_epsilon):
+  """Build computation tower (Resnet).

   Args:
     is_training: true if is training graph.
     weight_decay: weight regularization strength, a float.
     feature: a Tensor.
     label: a Tensor.
-    tower_losses: a list to be appended with current tower's loss.
-    tower_gradvars: a list to be appended with current tower's gradients.
-    tower_preds: a list to be appended with current tower's predictions.
-    is_cpu: true if build tower on CPU.
+    data_format: channels_last (NHWC) or channels_first (NCHW).
+    num_layers: number of layers, an int.
+    batch_norm_decay: decay for batch normalization, a float.
+    batch_norm_epsilon: epsilon for batch normalization, a float.
+
+  Returns:
+    A tuple with the loss for the tower, the gradients and parameters, and
+      predictions.
   """
-  data_format = 'channels_last' if is_cpu else 'channels_first'
   model = cifar10_model.ResNetCifar10(
       num_layers,
       batch_norm_decay=batch_norm_decay,
       batch_norm_epsilon=batch_norm_epsilon,
-      is_training=is_training, data_format=data_format)
+      is_training=is_training,
+      data_format=data_format)
   logits = model.forward_pass(feature, input_data_format='channels_last')
   tower_pred = {
       'classes': tf.argmax(input=logits, axis=1),
@@ -243,13 +232,20 @@ def _tower_fn(is_training,
   return tower_loss, zip(tower_grad, model_params), tower_pred


-def input_fn(data_dir, subset, num_shards, batch_size,
+def input_fn(data_dir,
+             subset,
+             num_shards,
+             batch_size,
              use_distortion_for_training=True):
   """Create input graph for model.

   Args:
+    data_dir: Directory where TFRecords representing the dataset are located.
     subset: one of 'train', 'validate' and 'eval'.
     num_shards: num of towers participating in data-parallel training.
+    batch_size: total batch size for training to be divided by the number of
+        shards.
+    use_distortion_for_training: True to use distortions.
   Returns:
     two lists of tensors for features and labels, each of num_shards length.
   """
@@ -279,7 +275,10 @@ def input_fn(data_dir, subset, num_shards, batch_size,
 # create experiment
-def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
+def get_experiment_fn(data_dir,
+                      num_gpus,
+                      variable_strategy,
+                      data_format,
                       use_distortion_for_training=True):
   """Returns an Experiment function.
@@ -292,7 +291,9 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
   Args:
     data_dir: str. Location of the data for input_fns.
     num_gpus: int. Number of GPUs on each worker.
-    is_gpu_ps: bool. If true, average gradients on GPUs.
+    variable_strategy: String. CPU to use CPU as the parameter server
+        and GPU to use the GPUs as the parameter server.
+    data_format: String. channels_first or channels_last.
     use_distortion_for_training: bool. See cifar10.Cifar10DataSet.
   Returns:
     A function (tf.estimator.RunConfig, tf.contrib.training.HParams) ->
@@ -302,6 +303,7 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
       methods on Experiment (train, evaluate) based on information
       about the current runner in `run_config`.
   """
+
   def _experiment_fn(run_config, hparams):
     """Returns an Experiment."""
     # Create estimator.
@@ -311,28 +313,26 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
         subset='train',
         num_shards=num_gpus,
         batch_size=hparams.train_batch_size,
-        use_distortion_for_training=use_distortion_for_training
-    )
+        use_distortion_for_training=use_distortion_for_training)
     eval_input_fn = functools.partial(
         input_fn,
         data_dir,
         subset='eval',
         batch_size=hparams.eval_batch_size,
-        num_shards=num_gpus
-    )
+        num_shards=num_gpus)

     num_eval_examples = cifar10.Cifar10DataSet.num_examples_per_epoch('eval')
     if num_eval_examples % hparams.eval_batch_size != 0:
-      raise ValueError('validation set size must be multiple of eval_batch_size')
+      raise ValueError(
+          'validation set size must be multiple of eval_batch_size')
     train_steps = hparams.train_steps
     eval_steps = num_eval_examples // hparams.eval_batch_size

     examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
         hparams.train_batch_size, every_n_steps=10)

-    tensors_to_log = {'learning_rate': 'learning_rate',
-                      'loss': 'loss'}
+    tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'}

     logging_hook = tf.train.LoggingTensorHook(
         tensors=tensors_to_log, every_n_iter=100)
@@ -340,11 +340,10 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
     hooks = [logging_hook, examples_sec_hook]

     classifier = tf.estimator.Estimator(
-        model_fn=get_model_fn(
-            num_gpus, is_gpu_ps, run_config.num_worker_replicas or 1),
+        model_fn=get_model_fn(num_gpus, variable_strategy, data_format,
+                              run_config.num_worker_replicas or 1),
         config=run_config,
-        params=vars(hparams)
-    )
+        params=vars(hparams))

     # Create experiment.
     experiment = tf.contrib.learn.Experiment(
@@ -356,43 +355,40 @@ def get_experiment_fn(data_dir, num_gpus, is_gpu_ps,
     # Adding hooks to be used by the estimator on training modes
     experiment.extend_train_hooks(hooks)

     return experiment

   return _experiment_fn


-def main(job_dir,
-         data_dir,
-         num_gpus,
-         variable_strategy,
-         use_distortion_for_training,
-         log_device_placement,
-         num_intra_threads,
+def main(job_dir, data_dir, num_gpus, variable_strategy, data_format,
+         use_distortion_for_training, log_device_placement, num_intra_threads,
          **hparams):
   # The env variable is on deprecation path, default is set to off.
   os.environ['TF_SYNC_ON_FINISH'] = '0'
+  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
+
+  # channels first (NCHW) is normally optimal on GPU and channels last (NHWC)
+  # on CPU. The exception is Intel MKL on CPU which is optimal with
+  # channels_last.
+  if not data_format:
+    if num_gpus == 0:
+      data_format = 'channels_last'
+    else:
+      data_format = 'channels_first'

   # Session configuration.
   sess_config = tf.ConfigProto(
       allow_soft_placement=True,
       log_device_placement=log_device_placement,
       intra_op_parallelism_threads=num_intra_threads,
-      gpu_options=tf.GPUOptions(
-          force_gpu_compatible=True
-      )
-  )
+      gpu_options=tf.GPUOptions(force_gpu_compatible=True))

   config = cifar10_utils.RunConfig(
-      session_config=sess_config,
-      model_dir=job_dir)
+      session_config=sess_config, model_dir=job_dir)
   tf.contrib.learn.learn_runner.run(
-      get_experiment_fn(
-          data_dir,
-          num_gpus,
-          variable_strategy,
-          use_distortion_for_training
-      ),
+      get_experiment_fn(data_dir, num_gpus, variable_strategy, data_format,
+                        use_distortion_for_training),
       run_config=config,
-      hparams=tf.contrib.training.HParams(**hparams)
-  )
+      hparams=tf.contrib.training.HParams(**hparams))


 if __name__ == '__main__':
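The data_format default picked in main() above only changes the in-memory layout of the image batches: channels_last (NHWC) keeps the three color channels last, while channels_first (NCHW) moves them ahead of the spatial dimensions, which is usually faster on GPUs. A minimal illustration (not part of this commit; variable names are illustrative), using the default train batch size of 128 CIFAR-10 images:

    # Sketch only: the same batch of 128 CIFAR-10 images in both layouts.
    images_nhwc = tf.zeros([128, 32, 32, 3])                # channels_last (NHWC)
    images_nchw = tf.transpose(images_nhwc, [0, 3, 1, 2])   # channels_first (NCHW)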
@@ -401,63 +397,53 @@ if __name__ == '__main__':
   parser.add_argument(
       '--data-dir',
       type=str,
       required=True,
-      help='The directory where the CIFAR-10 input data is stored.'
-  )
+      help='The directory where the CIFAR-10 input data is stored.')
   parser.add_argument(
       '--job-dir',
       type=str,
       required=True,
-      help='The directory where the model will be stored.'
-  )
+      help='The directory where the model will be stored.')
   parser.add_argument(
       '--variable-strategy',
       choices=['CPU', 'GPU'],
       type=str,
       default='CPU',
-      help='Where to locate variable operations'
-  )
+      help='Where to locate variable operations')
   parser.add_argument(
       '--num-gpus',
       type=int,
       default=1,
-      help='The number of gpus used. Uses only CPU if set to 0.'
-  )
+      help='The number of gpus used. Uses only CPU if set to 0.')
   parser.add_argument(
       '--num-layers',
       type=int,
       default=44,
-      help='The number of layers of the model.'
-  )
+      help='The number of layers of the model.')
   parser.add_argument(
       '--train-steps',
       type=int,
       default=80000,
-      help='The number of steps to use for training.'
-  )
+      help='The number of steps to use for training.')
   parser.add_argument(
       '--train-batch-size',
       type=int,
       default=128,
-      help='Batch size for training.'
-  )
+      help='Batch size for training.')
   parser.add_argument(
       '--eval-batch-size',
       type=int,
       default=100,
-      help='Batch size for validation.'
-  )
+      help='Batch size for validation.')
   parser.add_argument(
       '--momentum',
       type=float,
       default=0.9,
-      help='Momentum for MomentumOptimizer.'
-  )
+      help='Momentum for MomentumOptimizer.')
   parser.add_argument(
       '--weight-decay',
       type=float,
       default=2e-4,
-      help='Weight decay for convolutions.'
-  )
+      help='Weight decay for convolutions.')
   parser.add_argument(
       '--learning-rate',
       type=float,
@@ -466,22 +452,19 @@ if __name__ == '__main__':
       This is the inital learning rate value. The learning rate will decrease
       during training. For more details check the model_fn implementation in
       this file.\
-      """
-  )
+      """)
   parser.add_argument(
       '--use-distortion-for-training',
       type=bool,
       default=True,
-      help='If doing image distortion for training.'
-  )
+      help='If doing image distortion for training.')
   parser.add_argument(
       '--sync',
       action='store_true',
       default=False,
       help="""\
       If present when running in a distributed environment will run on sync mode.\
-      """
-  )
+      """)
   parser.add_argument(
       '--num-intra-threads',
       type=int,
@@ -492,8 +475,7 @@ if __name__ == '__main__':
       example CPU only handles the input pipeline and gradient aggregation
       (when --is-cpu-ps). Ops that could potentially benefit from intra-op
       parallelism are scheduled to run on GPUs.\
-      """
-  )
+      """)
   parser.add_argument(
       '--num-inter-threads',
       type=int,
@@ -501,26 +483,30 @@ if __name__ == '__main__':
       help="""\
       Number of threads to use for inter-op parallelism. If set to 0, the
       system will pick an appropriate number.\
-      """
-  )
+      """)
+  parser.add_argument(
+      '--data-format',
+      type=str,
+      default=None,
+      help="""\
+      If not set, the data format best for the training device is used.
+      Allowed values: channels_first (NCHW) channels_last (NHWC).\
+      """)
   parser.add_argument(
       '--log-device-placement',
       action='store_true',
       default=False,
-      help='Whether to log device placement.'
-  )
+      help='Whether to log device placement.')
   parser.add_argument(
       '--batch-norm-decay',
       type=float,
       default=0.997,
-      help='Decay for batch norm.'
-  )
+      help='Decay for batch norm.')
   parser.add_argument(
       '--batch-norm-epsilon',
       type=float,
       default=1e-5,
-      help='Epsilon for batch norm.'
-  )
+      help='Epsilon for batch norm.')
   args = parser.parse_args()

   if args.num_gpus < 0:
...
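With the new --data-format flag wired through get_experiment_fn, get_model_fn, and _tower_fn, a single-GPU run can now force the NCHW layout explicitly. A sketch of such an invocation, assuming this file is saved as cifar10_main.py (the filename and the data/job paths below are placeholders, not part of the commit):

    python cifar10_main.py \
        --data-dir=/tmp/cifar10_data \
        --job-dir=/tmp/cifar10_job \
        --num-gpus=1 \
        --variable-strategy=CPU \
        --data-format=channels_first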