Commit 74ecc048 authored by Marianne Linhares Monteiro's avatar Marianne Linhares Monteiro Committed by GitHub
Browse files

Refactoring and adding sync mode

parent 28328ae3
......@@ -36,6 +36,7 @@ import os
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import learn_runner # run the experiment
import cifar10
import cifar10_model
......@@ -44,13 +45,13 @@ tf.logging.set_verbosity(tf.logging.INFO)
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('data_dir', '',
tf.flags.DEFINE_string('data_dir', 'cifar10',
'The directory where the CIFAR-10 input data is stored.')
tf.flags.DEFINE_string('model_dir', '',
tf.flags.DEFINE_string('model_dir', 'output2_2',
'The directory where the model will be stored.')
tf.flags.DEFINE_boolean('is_cpu_ps', False,
tf.flags.DEFINE_boolean('is_cpu_ps', True,
'If using CPU as the parameter server.')
tf.flags.DEFINE_integer('num_gpus', 1,
......@@ -58,12 +59,12 @@ tf.flags.DEFINE_integer('num_gpus', 1,
tf.flags.DEFINE_integer('num_layers', 44, 'The number of layers of the model.')
tf.flags.DEFINE_integer('train_steps', 10000,
'The number of steps to use for training.')
tf.flags.DEFINE_integer('train_batch_size', 1024, 'Batch size for training.')
tf.flags.DEFINE_integer('train_batch_size', 128, 'Batch size for training.')
tf.flags.DEFINE_integer('train_steps', (50000.0/FLAGS.train_batch_size) * 40,
'The number of steps to use for training.') # 40 epochs
tf.flags.DEFINE_integer('eval_batch_size', 100, 'Batch size for validation.')
tf.flags.DEFINE_integer('eval_batch_size', 200, 'Batch size for validation.')
tf.flags.DEFINE_float('momentum', 0.9, 'Momentum for MomentumOptimizer.')
......@@ -71,10 +72,6 @@ tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
tf.flags.DEFINE_boolean('use_distortion_for_training', True,
'If doing image distortion for training.')
tf.flags.DEFINE_boolean('run_experiment', False,
'If True will run an experiment,'
'otherwise will run training and evaluation'
'using the estimator interface')
# Perf flags
tf.flags.DEFINE_integer('num_intra_threads', 1,
......@@ -141,7 +138,6 @@ def _create_device_setter(is_cpu_ps, worker):
gpus = ['/gpu:%d' % i for i in range(FLAGS.num_gpus)]
return ParamServerDeviceSetter(worker, gpus)
def _resnet_model_fn(features, labels, mode):
"""Resnet model body.
......@@ -175,24 +171,24 @@ def _resnet_model_fn(features, labels, mode):
worker = '/gpu:%d' % i
device_setter = _create_device_setter(is_cpu_ps, worker)
with tf.variable_scope('resnet', reuse=bool(i != 0)):
with tf.name_scope('tower_%d' % i) as name_scope:
with tf.device(device_setter):
_tower_fn(is_training, weight_decay, tower_features[i],
tower_labels[i], tower_losses, tower_gradvars,
tower_preds, False)
if i == 0:
# Only trigger batch_norm moving mean and variance update from the
# 1st tower. Ideally, we should grab the updates from all towers
# but these stats accumulate extremely fast so we can ignore the
# other stats from the other towers without significant detriment.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
name_scope)
with tf.name_scope('tower_%d' % i) as name_scope:
with tf.device(device_setter):
_tower_fn(is_training, weight_decay, tower_features[i],
tower_labels[i], tower_losses, tower_gradvars,
tower_preds, False)
if i == 0:
# Only trigger batch_norm moving mean and variance update from the
# 1st tower. Ideally, we should grab the updates from all towers
# but these stats accumulate extremely fast so we can ignore the
# other stats from the other towers without significant detriment.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
name_scope)
else:
with tf.variable_scope('resnet'), tf.device('/cpu:0'):
with tf.name_scope('tower_cpu') as name_scope:
_tower_fn(is_training, weight_decay, tower_features[0], tower_labels[0],
tower_losses, tower_gradvars, tower_preds, True)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, name_scope)
_tower_fn(is_training, weight_decay, tower_features[0], tower_labels[0],
tower_losses, tower_gradvars, tower_preds, True)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, name_scope)
# Now compute global loss and gradients.
gradvars = []
......@@ -204,51 +200,52 @@ def _resnet_model_fn(features, labels, mode):
with tf.name_scope('gradient_averaging'):
loss = tf.reduce_mean(tower_losses)
for zipped_gradvars in zip(*tower_gradvars):
# Averaging one var's gradients computed from multiple towers
var = zipped_gradvars[0][1]
grads = [gv[0] for gv in zipped_gradvars]
with tf.device(var.device):
if len(grads) == 1:
avg_grad = grads[0]
else:
avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
gradvars.append((avg_grad, var))
# Averaging one var's gradients computed from multiple towers
var = zipped_gradvars[0][1]
grads = [gv[0] for gv in zipped_gradvars]
with tf.device(var.device):
if len(grads) == 1:
avg_grad = grads[0]
else:
avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
gradvars.append((avg_grad, var))
# Suggested learning rate scheduling from
# https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py#L155
# users could apply other scheduling.
num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
'train') // FLAGS.train_batch_size
'train') // FLAGS.train_batch_size
boundaries = [
num_batches_per_epoch * x
for x in np.array([82, 123, 300], dtype=np.int64)
num_batches_per_epoch * x
for x in np.array([82, 123, 300], dtype=np.int64)
]
staged_lr = [0.1, 0.01, 0.001, 0.0002]
learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
boundaries, staged_lr)
global_step = tf.train.get_global_step()
learning_rate = tf.train.piecewise_constant(global_step,
boundaries, staged_lr)
# Create a nicely-named tensor for logging
learning_rate = tf.identity(learning_rate, name='learning_rate')
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum)
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
momentum=momentum)
# Create single grouped train op
train_op = [
optimizer.apply_gradients(
gradvars, global_step=tf.train.get_global_step())
optimizer.apply_gradients(
gradvars, global_step=global_step)
]
train_op.extend(update_ops)
train_op = tf.group(*train_op)
predictions = {
'classes':
tf.concat([p['classes'] for p in tower_preds], axis=0),
'probabilities':
tf.concat([p['probabilities'] for p in tower_preds], axis=0)
'classes':
tf.concat([p['classes'] for p in tower_preds], axis=0),
'probabilities':
tf.concat([p['probabilities'] for p in tower_preds], axis=0)
}
stacked_labels = tf.concat(labels, axis=0)
metrics = {
'accuracy': tf.metrics.accuracy(stacked_labels, predictions['classes'])
'accuracy': tf.metrics.accuracy(stacked_labels, predictions['classes'])
}
return tf.estimator.EstimatorSpec(
......@@ -363,23 +360,21 @@ def input_fn(subset, num_shards):
label_shards = [tf.parallel_stack(x) for x in label_shards]
return feature_shards, label_shards
def create_experiment_fn(train_input, test_input, hooks):
def _experiment_fn(run_config, hparams):
estimator = tf.estimator.Estimator(model_fn=_resnet_model_fn,
config=run_config,
model_dir=FLAGS.model_dir)
experiment = tf.contrib.learn.Experiment(
estimator,
train_input_fn=train_input,
eval_input_fn=test_input,
train_steps=FLAGS.train_steps)
# create experiment
def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps):
def _experiment_fn(run_config, hparams):
del hparams # unused arg
# create estimator
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
config=run_config)
return tf.contrib.learn.Experiment(
classifier,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
train_steps=train_steps,
eval_steps=eval_steps
)
return _experiment_fn
experiment.extend_train_hooks(hooks)
return experiment
return _experiment_fn
def main(unused_argv):
# The env variable is on deprecation path, default is set to off.
......@@ -411,38 +406,17 @@ def main(unused_argv):
sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
config = config.replace(session_config=sess_config)
train_input_fn = functools.partial(input_fn, subset='train',
num_shards=FLAGS.num_gpus)
eval_input_fn = functools.partial(input_fn, subset='eval',
num_shards=FLAGS.num_gpus)
train_steps = FLAGS.train_steps
eval_steps = num_eval_examples // FLAGS.eval_batch_size
if FLAGS.run_experiment:
tf.contrib.learn.learn_runner.run(
get_experiment_fn(train_input_fn, eval_input_fn,
train_steps, eval_steps), run_config=config)
else:
classifier = tf.estimator.Estimator(
model_fn=_resnet_model_fn, config=config)
tensors_to_log = {'learning_rate': 'learning_rate'}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
print('Starting to train...')
classifier.train(input_fn=train_input_fn,
steps=train_steps,
hooks=[logging_hook])
print('Starting to evaluate...')
eval_results = classifier.evaluate(
input_fn=eval_input_fn,
steps=eval_steps)
print(eval_results)
train_input = functools.partial(input_fn, subset='train', num_shards=FLAGS.num_gpus)
test_input = functools.partial(input_fn, subset='eval', num_shards=FLAGS.num_gpus)
tensors_to_log = {'learning_rate': 'learning_rate'}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
hooks = [logging_hook]
# run experiment
learn_runner.run(create_experiment_fn(train_input, test_input, hooks), run_config=config)
if __name__ == '__main__':
tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment