Commit afa7954d authored by Marianne Linhares Monteiro, committed by GitHub

Refactoring and adding sync mode

parent 74ecc048
@@ -36,7 +36,6 @@ import os
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner # run the experiment
import cifar10
import cifar10_model
@@ -45,13 +44,13 @@ tf.logging.set_verbosity(tf.logging.INFO)
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('data_dir', 'cifar10',
tf.flags.DEFINE_string('data_dir', '',
'The directory where the CIFAR-10 input data is stored.')
tf.flags.DEFINE_string('model_dir', 'output2_2',
tf.flags.DEFINE_string('model_dir', '',
'The directory where the model will be stored.')
tf.flags.DEFINE_boolean('is_cpu_ps', True,
tf.flags.DEFINE_boolean('is_cpu_ps', False,
'If using CPU as the parameter server.')
tf.flags.DEFINE_integer('num_gpus', 1,
@@ -59,12 +58,12 @@ tf.flags.DEFINE_integer('num_gpus', 1,
tf.flags.DEFINE_integer('num_layers', 44, 'The number of layers of the model.')
tf.flags.DEFINE_integer('train_batch_size', 1024, 'Batch size for training.')
tf.flags.DEFINE_integer('train_steps', 10000,
'The number of steps to use for training.')
tf.flags.DEFINE_integer('train_steps', int((50000.0 / FLAGS.train_batch_size) * 40),
'The number of steps to use for training.') # 40 epochs
tf.flags.DEFINE_integer('train_batch_size', 128, 'Batch size for training.')
tf.flags.DEFINE_integer('eval_batch_size', 200, 'Batch size for validation.')
tf.flags.DEFINE_integer('eval_batch_size', 100, 'Batch size for validation.')
tf.flags.DEFINE_float('momentum', 0.9, 'Momentum for MomentumOptimizer.')
@@ -73,6 +72,18 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
tf.flags.DEFINE_boolean('use_distortion_for_training', True,
'If doing image distortion for training.')
tf.flags.DEFINE_boolean('run_experiment', False,
'If True, will run an Experiment; '
'otherwise will run training and evaluation '
'using the Estimator interface.')
tf.flags.DEFINE_boolean('sync', False,
'If True, when running in a distributed environment, '
'training will run in sync mode.')
tf.flags.DEFINE_integer('num_workers', 1, 'Number of workers.')
# Perf flags
tf.flags.DEFINE_integer('num_intra_threads', 1,
"""Number of threads to use for intra-op parallelism.
@@ -138,6 +149,7 @@ def _create_device_setter(is_cpu_ps, worker):
gpus = ['/gpu:%d' % i for i in range(FLAGS.num_gpus)]
return ParamServerDeviceSetter(worker, gpus)
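For context, ParamServerDeviceSetter is defined earlier in this file and is not part of this hunk; the sketch below only illustrates the general idea of spreading variables across GPU parameter devices while keeping the remaining ops on the worker device. The class name and the simple round-robin policy are illustrative assumptions, not the actual implementation.
class RoundRobinDeviceSetter(object):
  """Illustrative only: place variables on ps_devices in round-robin order."""

  def __init__(self, worker_device, ps_devices):
    self.worker_device = worker_device
    self.ps_devices = ps_devices
    self._next = 0

  def __call__(self, op):
    # Variable ops go to the parameter devices in turn; other ops stay on the worker.
    if op.type in ('Variable', 'VariableV2'):
      device = self.ps_devices[self._next % len(self.ps_devices)]
      self._next += 1
      return device
    return self.worker_device

# e.g. with tf.device(RoundRobinDeviceSetter('/gpu:0', ['/gpu:%d' % i for i in range(4)])): ...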
def _resnet_model_fn(features, labels, mode):
"""Resnet model body.
@@ -171,24 +183,24 @@ def _resnet_model_fn(features, labels, mode):
worker = '/gpu:%d' % i
device_setter = _create_device_setter(is_cpu_ps, worker)
with tf.variable_scope('resnet', reuse=bool(i != 0)):
with tf.name_scope('tower_%d' % i) as name_scope:
with tf.device(device_setter):
_tower_fn(is_training, weight_decay, tower_features[i],
tower_labels[i], tower_losses, tower_gradvars,
tower_preds, False)
if i == 0:
# Only trigger batch_norm moving mean and variance update from the
# 1st tower. Ideally, we should grab the updates from all towers
# but these stats accumulate extremely fast so we can ignore the
# other stats from the other towers without significant detriment.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
name_scope)
with tf.name_scope('tower_%d' % i) as name_scope:
with tf.device(device_setter):
_tower_fn(is_training, weight_decay, tower_features[i],
tower_labels[i], tower_losses, tower_gradvars,
tower_preds, False)
if i == 0:
# Only trigger batch_norm moving mean and variance update from the
# 1st tower. Ideally, we should grab the updates from all towers
# but these stats accumulate extremely fast so we can ignore the
# other stats from the other towers without significant detriment.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
name_scope)
else:
with tf.variable_scope('resnet'), tf.device('/cpu:0'):
with tf.name_scope('tower_cpu') as name_scope:
_tower_fn(is_training, weight_decay, tower_features[0], tower_labels[0],
tower_losses, tower_gradvars, tower_preds, True)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, name_scope)
_tower_fn(is_training, weight_decay, tower_features[0], tower_labels[0],
tower_losses, tower_gradvars, tower_preds, True)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, name_scope)
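The loop above builds one tower per GPU and shares weights between towers through the 'resnet' variable scope, reusing variables for every tower after the first. A minimal, self-contained sketch of that sharing pattern (tiny_tower and tower_inputs are illustrative placeholders, not names from this file):
def tiny_tower(x):
  # tf.get_variable returns the same 'resnet/w' variable inside every reusing scope.
  w = tf.get_variable('w', shape=[int(x.shape[-1]), 10])
  return tf.matmul(x, w)

tower_outputs = []
for i, tower_x in enumerate(tower_inputs):  # one input tensor per GPU (assumed)
  with tf.variable_scope('resnet', reuse=bool(i != 0)):
    with tf.name_scope('tower_%d' % i), tf.device('/gpu:%d' % i):
      tower_outputs.append(tiny_tower(tower_x))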
# Now compute global loss and gradients.
gradvars = []
@@ -200,52 +212,59 @@ def _resnet_model_fn(features, labels, mode):
with tf.name_scope('gradient_averaging'):
loss = tf.reduce_mean(tower_losses)
for zipped_gradvars in zip(*tower_gradvars):
# Averaging one var's gradients computed from multiple towers
var = zipped_gradvars[0][1]
grads = [gv[0] for gv in zipped_gradvars]
with tf.device(var.device):
if len(grads) == 1:
avg_grad = grads[0]
else:
avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
gradvars.append((avg_grad, var))
# Averaging one var's gradients computed from multiple towers
var = zipped_gradvars[0][1]
grads = [gv[0] for gv in zipped_gradvars]
with tf.device(var.device):
if len(grads) == 1:
avg_grad = grads[0]
else:
avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
gradvars.append((avg_grad, var))
# Suggested learning rate scheduling from
# https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py#L155
# users could apply other scheduling.
num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
'train') // FLAGS.train_batch_size
'train') // FLAGS.train_batch_size
boundaries = [
num_batches_per_epoch * x
for x in np.array([82, 123, 300], dtype=np.int64)
num_batches_per_epoch * x
for x in np.array([82, 123, 300], dtype=np.int64)
]
staged_lr = [0.1, 0.01, 0.001, 0.0002]
global_step = tf.train.get_global_step()
learning_rate = tf.train.piecewise_constant(global_step,
boundaries, staged_lr)
learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
boundaries, staged_lr)
# Create a nicely-named tensor for logging
learning_rate = tf.identity(learning_rate, name='learning_rate')
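As a concrete illustration of the piecewise schedule above: with the example numbers below (the real training-set size comes from Cifar10DataSet.num_examples_per_epoch and may differ), the learning rate drops at roughly epochs 82, 123 and 300, expressed in global steps:
# Illustrative numbers only: 45000 training examples and train_batch_size=128.
num_batches_per_epoch = 45000 // 128                # 351 steps per epoch
boundaries = [num_batches_per_epoch * x             # [28782, 43173, 105300]
              for x in np.array([82, 123, 300], dtype=np.int64)]
staged_lr = [0.1, 0.01, 0.001, 0.0002]
lr = tf.train.piecewise_constant(tf.train.get_global_step(),
                                 boundaries, staged_lr)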
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
momentum=momentum)
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum)
chief_hooks = []
if FLAGS.sync:
optimizer = tf.train.SyncReplicasOptimizer(
optimizer,
replicas_to_aggregate=FLAGS.num_workers)
sync_replicas_hook = optimizer.make_session_run_hook(True)
chief_hooks.append(sync_replicas_hook)
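For clarity, make_session_run_hook is the standard way to plug SyncReplicasOptimizer into Estimator-style training: the chief passes True so the hook also initializes the shared token queue, while non-chief workers would pass False. A minimal sketch of the same wiring, reusing FLAGS.num_workers and gradvars from the surrounding code (other names are illustrative):
base_optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
sync_optimizer = tf.train.SyncReplicasOptimizer(
    base_optimizer, replicas_to_aggregate=FLAGS.num_workers)
train_op = sync_optimizer.apply_gradients(gradvars,
                                          global_step=tf.train.get_global_step())
chief_hook = sync_optimizer.make_session_run_hook(is_chief=True)
# A non-chief worker would create its hook with is_chief=False instead.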
# Create single grouped train op
train_op = [
optimizer.apply_gradients(
gradvars, global_step=global_step)
optimizer.apply_gradients(
gradvars, global_step=tf.train.get_global_step())
]
train_op.extend(update_ops)
train_op = tf.group(*train_op)
predictions = {
'classes':
tf.concat([p['classes'] for p in tower_preds], axis=0),
'probabilities':
tf.concat([p['probabilities'] for p in tower_preds], axis=0)
'classes':
tf.concat([p['classes'] for p in tower_preds], axis=0),
'probabilities':
tf.concat([p['probabilities'] for p in tower_preds], axis=0)
}
stacked_labels = tf.concat(labels, axis=0)
metrics = {
'accuracy': tf.metrics.accuracy(stacked_labels, predictions['classes'])
'accuracy': tf.metrics.accuracy(stacked_labels, predictions['classes'])
}
return tf.estimator.EstimatorSpec(
@@ -253,6 +272,7 @@ def _resnet_model_fn(features, labels, mode):
predictions=predictions,
loss=loss,
train_op=train_op,
training_chief_hooks=chief_hooks,
eval_metric_ops=metrics)
@@ -283,7 +303,6 @@ def _tower_fn(is_training, weight_decay, feature, label, tower_losses,
tower_loss = tf.losses.sparse_softmax_cross_entropy(
logits=logits, labels=label)
tower_loss = tf.reduce_mean(tower_loss)
tower_losses.append(tower_loss)
model_params = tf.trainable_variables()
tower_loss += weight_decay * tf.add_n(
@@ -303,43 +322,16 @@ def input_fn(subset, num_shards):
Returns:
two lists of tensors for features and labels, each of num_shards length.
"""
dataset = cifar10.Cifar10DataSet(FLAGS.data_dir)
is_training = (subset == 'train')
if is_training:
if subset == 'train':
batch_size = FLAGS.train_batch_size
else:
elif subset == 'validate' or subset == 'eval':
batch_size = FLAGS.eval_batch_size
with tf.device('/cpu:0'), tf.name_scope('batching'):
# CPU loads all data from disk since there're only 60k 32*32 RGB images.
all_images, all_labels = dataset.read_all_data(subset)
dataset = tf.contrib.data.Dataset.from_tensor_slices(
(all_images, all_labels))
dataset = dataset.map(
lambda x, y: (tf.cast(x, tf.float32), tf.cast(y, tf.int32)),
num_threads=2,
output_buffer_size=batch_size)
# Image preprocessing.
def _preprocess(image, label):
# If GPU is available, NHWC to NCHW transpose is done in ResNetCifar10
# class, not included in preprocessing.
return cifar10.Cifar10DataSet.preprocess(
image, is_training, FLAGS.use_distortion_for_training), label
dataset = dataset.map(
_preprocess, num_threads=batch_size, output_buffer_size=2 * batch_size)
# Repeat infinitely.
dataset = dataset.repeat()
if is_training:
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(
cifar10.Cifar10DataSet.num_examples_per_epoch(subset) *
min_fraction_of_examples_in_queue)
# Ensure that the capacity is sufficiently large to provide good random
# shuffling
dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()
else:
raise ValueError('Subset must be one of \'train\', \'validate\' and \'eval\'')
with tf.device('/cpu:0'):
use_distortion = subset == 'train' and FLAGS.use_distortion_for_training
dataset = cifar10.Cifar10DataSet(FLAGS.data_dir, subset, use_distortion)
image_batch, label_batch = dataset.make_batch(batch_size)
if num_shards <= 1:
# No GPU available or only 1 GPU.
return [image_batch], [label_batch]
@@ -360,21 +352,26 @@ def input_fn(subset, num_shards):
label_shards = [tf.parallel_stack(x) for x in label_shards]
return feature_shards, label_shards
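The refactored input_fn delegates reading, preprocessing and batching to cifar10.Cifar10DataSet.make_batch, which lives in the new cifar10.py and is not part of this file's diff. The sketch below is only a rough guess at the shape of such a method using the contrib Dataset API of this TensorFlow version; the class layout and data-loading details are assumptions.
class Cifar10DataSetSketch(object):
  """Hypothetical stand-in for cifar10.Cifar10DataSet, for illustration only."""

  def __init__(self, images, labels, subset='train', use_distortion=True):
    # Assume the images/labels for the requested subset are already in memory.
    self.images, self.labels = images, labels
    self.subset = subset
    self.use_distortion = use_distortion

  def make_batch(self, batch_size):
    dataset = tf.contrib.data.Dataset.from_tensor_slices((self.images, self.labels))
    if self.subset == 'train':
      dataset = dataset.shuffle(buffer_size=4 * batch_size)
    dataset = dataset.repeat().batch(batch_size)
    image_batch, label_batch = dataset.make_one_shot_iterator().get_next()
    return image_batch, label_batch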
def create_experiment_fn(train_input, test_input, hooks):
def _experiment_fn(run_config, hparams):
estimator = tf.estimator.Estimator(model_fn=_resnet_model_fn,
config=run_config,
model_dir=FLAGS.model_dir)
experiment = tf.contrib.learn.Experiment(
estimator,
train_input_fn=train_input,
eval_input_fn=test_input,
train_steps=FLAGS.train_steps)
experiment.extend_train_hooks(hooks)
return experiment
# create experiment
def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps,
train_hooks):
def _experiment_fn(run_config, hparams):
del hparams # unused arg
# create estimator
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
config=run_config)
experiment = tf.contrib.learn.Experiment(
classifier,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=train_steps,
eval_steps=eval_steps)
# add training hooks to the experiment
experiment.extend_train_hooks(train_hooks)
return experiment
return _experiment_fn
return _experiment_fn
def main(unused_argv):
# The env variable is on deprecation path, default is set to off.
@@ -397,26 +394,52 @@ def main(unused_argv):
if num_eval_examples % FLAGS.eval_batch_size != 0:
raise ValueError('validation set size must be a multiple of eval_batch_size')
config = tf.contrib.learn.RunConfig(model_dir=FLAGS.model_dir)
train_input_fn = functools.partial(input_fn, subset='train',
num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
num_shards=FLAGS.num_gpus)
train_steps = FLAGS.train_steps
eval_steps = num_eval_examples // FLAGS.eval_batch_size
# session configuration
sess_config = tf.ConfigProto()
sess_config.allow_soft_placement = True
sess_config.log_device_placement = FLAGS.log_device_placement
sess_config.intra_op_parallelism_threads = FLAGS.num_intra_threads
sess_config.inter_op_parallelism_threads = FLAGS.num_inter_threads
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
config = config.replace(session_config=sess_config)
train_input = functools.partial(input_fn, subset='train', num_shards=FLAGS.num_gpus)
test_input = functools.partial(input_fn, subset='eval', num_shards=FLAGS.num_gpus)
tensors_to_log = {'learning_rate': 'learning_rate'}
# log learning_rate
tensors_to_log = {'learning_rate': 'learning_rate'}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
hooks = [logging_hook]
# run experiment
learn_runner.run(create_experiment_fn(train_input, test_input, hooks), run_config=config)
if FLAGS.run_experiment:
config = tf.contrib.learn.RunConfig(model_dir=FLAGS.model_dir)
config = config.replace(session_config=sess_config)
tf.contrib.learn.learn_runner.run(
get_experiment_fn(train_input_fn, eval_input_fn,
train_steps, eval_steps,
[logging_hook]), run_config=config)
else:
config = tf.estimator.RunConfig()
config = config.replace(session_config=sess_config)
classifier = tf.estimator.Estimator(
model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)
print('Starting to train...')
classifier.train(input_fn=train_input_fn,
steps=train_steps,
hooks=[logging_hook])
print('Starting to evaluate...')
eval_results = classifier.evaluate(
input_fn=eval_input_fn,
steps=eval_steps)
print(eval_results)
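When --sync and --num_workers are used, the cluster membership is picked up by RunConfig from the TF_CONFIG environment variable, as is standard for tf.contrib.learn at this TensorFlow version. A hedged example of what one worker's environment could look like; the host names, ports and job layout are placeholders:
import json
import os

# Hypothetical cluster: one parameter server, one master (chief) and one worker.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'ps': ['ps0.example.com:2222'],
        'master': ['master0.example.com:2222'],
        'worker': ['worker0.example.com:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
    'environment': 'cloud',
})
# Each process then starts the script with, e.g.:
#   python cifar10_main.py --run_experiment=True --sync=True --num_workers=2 \
#       --data_dir=... --model_dir=... --num_gpus=1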
if __name__ == '__main__':
tf.app.run()