Commit 7c460c90 authored by Toby Boyd

Merge branch 'cmlesupport' of https://github.com/elibixby/models

parents b14765c3 2164c8db
cifar10_estimator/README.md
@@ -34,8 +34,8 @@ data_batch_4 data_batch_5 readme.html test_batch
```shell
# This will generate TFRecord files for the training, validation and eval data available at the input_dir.
# You can see more details in generate_cifar10_tfrecords.py
$ python generate_cifar10_tfrecords.py --input-dir=/prefix/to/downloaded/data/cifar-10-batches-py \
    --output-dir=/prefix/to/downloaded/data/cifar-10-batches-py
```
After running the command above, you should see the following new files in the output_dir.
@@ -51,36 +51,60 @@ train.tfrecords validation.tfrecords eval.tfrecords
```
# Run the model on CPU only. After training, it runs the evaluation.
$ python cifar10_main.py --data-dir=/prefix/to/downloaded/data/cifar-10-batches-py \
    --job-dir=/tmp/cifar10 \
    --num-gpus=0 \
    --train-steps=1000

# Run the model on 2 GPUs using CPU as parameter server. After training, it runs the evaluation.
$ python cifar10_main.py --data-dir=/prefix/to/downloaded/data/cifar-10-batches-py \
    --job-dir=/tmp/cifar10 \
    --num-gpus=2 \
    --train-steps=1000

# Run the model on 2 GPUs using GPU as parameter server.
# It will run an experiment, which for a local setting basically means it will
# stop training a couple of times to perform evaluation.
$ python cifar10_main.py --data-dir=/prefix/to/downloaded/data/cifar-10-batches-py \
    --job-dir=/tmp/cifar10 \
    --variable-strategy GPU \
    --num-gpus=2 \
    --train-steps=1000

# There are more command line flags to play with; check cifar10_main.py for details.
```
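A quick way to decide what to pass for `--num-gpus` is to ask TensorFlow which devices it can actually see. The short check below is illustrative only and not part of the tutorial scripts:

```python
# List the GPUs visible to TensorFlow before choosing --num-gpus.
from tensorflow.python.client import device_lib

gpu_names = [d.name for d in device_lib.list_local_devices()
             if d.device_type == 'GPU']
print('%d visible GPU(s): %s' % (len(gpu_names), gpu_names))
```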
## How to run in distributed mode
### (Optional) Running on Google Cloud Machine Learning Engine
This example can be run on Google Cloud Machine Learning Engine (ML Engine), which will configure the environment and take care of running workers, parameter servers, and masters in a fault-tolerant way.
To install the command-line tool and set up a project and billing, see the quickstart [here](https://cloud.google.com/ml-engine/docs/quickstarts/command-line).
You'll also need a Google Cloud Storage bucket for the data. If you followed the instructions above, you can just run:
```
MY_BUCKET=gs://<my-bucket-name>
gsutil cp -r cifar-10-batches-py $MY_BUCKET/
```
Then run the following command from the `tutorials/image` directory of this repository (the parent directory of this README):
```
gcloud ml-engine jobs submit training cifarmultigpu \
--runtime-version 1.2 \
--job-dir=$MY_BUCKET/model_dirs/cifarmultigpu \
--config cifar10_estimator/cmle_config.yaml \
--package-path cifar10_estimator/ \
--module-name cifar10_estimator.cifar10_main \
-- \
--data-dir=$MY_BUCKET/cifar-10-batches-py \
--num-gpus=4 \
--train-steps=1000
```
### Set TF_CONFIG
Considering that you already have multiple hosts configured, all you need is a `TF_CONFIG` environment variable on each host.
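The full README (not shown in this diff) walks through complete `TF_CONFIG` examples for every host. As a rough illustration only, with placeholder host names and ports, the variable is a JSON description of the cluster plus the current host's role:

```python
# Illustrative TF_CONFIG for one worker; hosts, ports and the task entry are placeholders.
import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'master': ['master-host:2222'],
        'worker': ['worker0-host:2222', 'worker1-host:2222'],
        'ps': ['ps-host:2222']
    },
    'task': {'type': 'worker', 'index': 0},
    'environment': 'cloud'
})
```

Each host gets the same `cluster` block but its own `task` entry.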
@@ -154,15 +178,12 @@ Once you have a `TF_CONFIG` configured properly on each host you're ready to run
```
# It will run evaluation a couple of times during training.
# The num-workers argument is used only to update the learning rate correctly.
# Make sure the model_dir is the same as defined in the TF_CONFIG.
$ python cifar10_main.py --data-dir=gs://path/cifar-10-batches-py \
    --job-dir=gs://path/model_dir/ \
    --num-gpus=4 \
    --train-steps=40000 \
    --sync \
    --num-workers=2
```
*Output:*
@@ -297,14 +318,11 @@ INFO:tensorflow:Saving dict for global step 1: accuracy = 0.0994, global_step =
```
# Runs an Experiment in sync mode on 4 GPUs using CPU as parameter server for 40000 steps.
# It will run evaluation a couple of times during training.
# Make sure the model_dir is the same as defined in the TF_CONFIG.
$ python cifar10_main.py --data-dir=gs://path/cifar-10-batches-py \
    --job-dir=gs://path/model_dir/ \
    --num-gpus=4 \
    --train-steps=40000 \
    --sync
```

*Output:*
@@ -413,7 +431,7 @@ INFO:tensorflow:loss = 27.8453, step = 179 (18.893 sec)
```shell
# Run this on ps:
# The ps will not do training so most of the arguments won't affect the execution
$ python cifar10_main.py --job-dir=gs://path/model_dir/
# There are more command line flags to play with; check cifar10_main.py for details.
```
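Because the role comes entirely from `TF_CONFIG`, the same command line works on every host. If you want to confirm what a given machine will do, a small check like the following (illustrative, and assuming `TF_CONFIG` is already exported) prints the parsed role:

```python
# Print the task parsed from TF_CONFIG on this host.
import tensorflow as tf

run_config = tf.contrib.learn.RunConfig()
print('%s:%d (workers: %d)' % (run_config.task_type, run_config.task_id,
                               run_config.num_worker_replicas))
```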
@@ -446,12 +464,12 @@ You'll see something similar to this if you "point" TensorBoard to the `model_dir`
```
# Check TensorBoard during training or after it.
# Just point TensorBoard to the model dir you chose in the previous steps,
# e.g. /tmp/cifar10 for the local run or gs://path/model_dir/ for the distributed run.
$ tensorboard --logdir=/tmp/cifar10
```
## Warnings

When running `cifar10_main.py` with the `--sync` argument you may see an error similar to:

```python
  File "cifar10_main.py", line 538, in <module>
    ...
```
cifar10_estimator/cifar10_main.py
@@ -25,278 +25,121 @@ http://www.cs.toronto.edu/~kriz/cifar.html
"""
from __future__ import division
from __future__ import print_function

import argparse
import functools
import itertools
import os

import six

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import cifar10
import cifar10_model
import cifar10_utils

tf.logging.set_verbosity(tf.logging.INFO)


def get_model_fn(num_gpus, variable_strategy, num_workers):

  def _resnet_model_fn(features, labels, mode, params):
    """Resnet model body.

    Support single host, one or more GPU training. Parameter distribution can
    be either one of the following schemes.
    1. CPU is the parameter server and manages gradient updates.
    2. Parameters are distributed evenly across all GPUs, and the first GPU
       manages gradient updates.

    Args:
      features: a list of tensors, one for each tower
      labels: a list of tensors, one for each tower
      mode: ModeKeys.TRAIN or EVAL
      params: Dictionary of Hyperparameters suitable for tuning
    Returns:
      An EstimatorSpec object.
    """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    weight_decay = params['weight_decay']
    momentum = params['momentum']

    tower_features = features
    tower_labels = labels
    tower_losses = []
    tower_gradvars = []
    tower_preds = []

    if num_gpus != 0:
      for i in range(num_gpus):
        worker_device = '/gpu:{}'.format(i)
        if variable_strategy == 'CPU':
          device_setter = cifar10_utils.local_device_setter(
              worker_device=worker_device)
        elif variable_strategy == 'GPU':
          device_setter = cifar10_utils.local_device_setter(
              ps_device_type='gpu',
              worker_device=worker_device,
              ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
                  num_gpus, tf.contrib.training.byte_size_load_fn))
        with tf.variable_scope('resnet', reuse=bool(i != 0)):
          with tf.name_scope('tower_%d' % i) as name_scope:
            with tf.device(device_setter):
              loss, gradvars, preds = _tower_fn(
                  is_training, weight_decay, tower_features[i],
                  tower_labels[i], False, params['num_layers'],
                  params['batch_norm_decay'], params['batch_norm_epsilon'])
              tower_losses.append(loss)
              tower_gradvars.append(gradvars)
              tower_preds.append(preds)
              if i == 0:
                # Only trigger batch_norm moving mean and variance update from
                # the 1st tower. Ideally, we should grab the updates from all
                # towers but these stats accumulate extremely fast so we can
                # ignore the other stats from the other towers without
                # significant detriment.
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               name_scope)
    else:
      with tf.variable_scope('resnet'), tf.device('/cpu:0'):
        with tf.name_scope('tower_cpu') as name_scope:
          loss, gradvars, preds = _tower_fn(
              is_training, weight_decay, tower_features[0], tower_labels[0],
              True, params['num_layers'], params['batch_norm_decay'],
              params['batch_norm_epsilon'])
          tower_losses.append(loss)
          tower_gradvars.append(gradvars)
          tower_preds.append(preds)
          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, name_scope)

    # Now compute global loss and gradients.
    gradvars = []
    with tf.name_scope('gradient_averaging'):
      all_grads = {}
      for grad, var in itertools.chain(*tower_gradvars):
        if grad is not None:
          all_grads.setdefault(var, []).append(grad)
      for var, grads in six.iteritems(all_grads):
        # Average gradients on the same device as the variables
        # to which they apply.
        with tf.device(var.device):
          if len(grads) == 1:
            avg_grad = grads[0]
@@ -304,63 +147,75 @@ def _resnet_model_fn(features, labels, mode):
            avg_grad = tf.multiply(tf.add_n(grads), 1. / len(grads))
        gradvars.append((avg_grad, var))

    # Device that runs the ops to apply global gradient updates.
    consolidation_device = '/gpu:0' if variable_strategy == 'GPU' else '/cpu:0'
    with tf.device(consolidation_device):
      # Suggested learning rate scheduling from
      # https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py#L155
      # users could apply other scheduling.
      num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
          'train') // (params['train_batch_size'] * num_workers)
      boundaries = [
          num_batches_per_epoch * x
          for x in np.array([82, 123, 300], dtype=np.int64)
      ]
      staged_lr = [params['learning_rate'] * x for x in [1, 0.1, 0.01, 0.002]]

      learning_rate = tf.train.piecewise_constant(tf.train.get_global_step(),
                                                  boundaries, staged_lr)
      # Create a nicely-named tensor for logging
      learning_rate = tf.identity(learning_rate, name='learning_rate')

      optimizer = tf.train.MomentumOptimizer(
          learning_rate=learning_rate, momentum=momentum)

      chief_hooks = []
      if params['sync']:
        optimizer = tf.train.SyncReplicasOptimizer(
            optimizer, replicas_to_aggregate=num_workers)
        sync_replicas_hook = optimizer.make_session_run_hook(True)
        chief_hooks.append(sync_replicas_hook)

      # Create single grouped train op
      train_op = [
          optimizer.apply_gradients(
              gradvars, global_step=tf.train.get_global_step())
      ]
      train_op.extend(update_ops)
      train_op = tf.group(*train_op)

      predictions = {
          'classes':
              tf.concat([p['classes'] for p in tower_preds], axis=0),
          'probabilities':
              tf.concat([p['probabilities'] for p in tower_preds], axis=0)
      }
      stacked_labels = tf.concat(labels, axis=0)
      metrics = {
          'accuracy':
              tf.metrics.accuracy(stacked_labels, predictions['classes'])
      }
      loss = tf.reduce_mean(tower_losses, name='loss')

      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=predictions,
          loss=loss,
          train_op=train_op,
          training_chief_hooks=chief_hooks,
          eval_metric_ops=metrics)

  return _resnet_model_fn


def _tower_fn(is_training, weight_decay, feature, label, is_cpu, num_layers,
              batch_norm_decay, batch_norm_epsilon):
  """Build computation tower for each device (CPU or GPU).

  Args:
@@ -375,13 +230,15 @@ def _tower_fn(is_training, weight_decay, feature, label, tower_losses,
  """
  data_format = 'channels_last' if is_cpu else 'channels_first'
  model = cifar10_model.ResNetCifar10(
      num_layers,
      batch_norm_decay=batch_norm_decay,
      batch_norm_epsilon=batch_norm_epsilon,
      is_training=is_training, data_format=data_format)
  logits = model.forward_pass(feature, input_data_format='channels_last')
  tower_pred = {
      'classes': tf.argmax(input=logits, axis=1),
      'probabilities': tf.nn.softmax(logits)
  }

  tower_loss = tf.losses.sparse_softmax_cross_entropy(
      logits=logits, labels=label)
@@ -390,13 +247,14 @@ def _tower_fn(is_training, weight_decay, feature, label, tower_losses,
  model_params = tf.trainable_variables()
  tower_loss += weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in model_params])

  tower_grad = tf.gradients(tower_loss, model_params)

  return tower_loss, zip(tower_grad, model_params), tower_pred


def input_fn(data_dir, subset, num_shards, batch_size,
             use_distortion_for_training=True):
  """Create input graph for model.

  Args:
@@ -405,16 +263,9 @@ def input_fn(subset, num_shards):
  Returns:
    two lists of tensors for features and labels, each of num_shards length.
  """
  with tf.device('/cpu:0'):
    use_distortion = subset == 'train' and use_distortion_for_training
    dataset = cifar10.Cifar10DataSet(data_dir, subset, use_distortion)
    image_batch, label_batch = dataset.make_batch(batch_size)
    if num_shards <= 1:
      # No GPU available or only 1 GPU.
@@ -438,20 +289,73 @@ def input_fn(subset, num_shards):
# create experiment
def get_experiment_fn(data_dir, num_gpus, variable_strategy,
                      use_distortion_for_training=True):
  """Returns an Experiment function.

  Experiments perform training on several workers in parallel,
  in other words experiments know how to invoke train and eval in a sensible
  fashion for distributed training. Arguments passed directly to this
  function are not tunable, all other arguments should be passed within
  tf.contrib.training.HParams, passed to the enclosed function.

  Args:
    data_dir: str. Location of the data for input_fns.
    num_gpus: int. Number of GPUs on each worker.
    variable_strategy: String. 'CPU' to place variables and average gradients
      on the CPU, 'GPU' to spread variables across GPUs and average there.
    use_distortion_for_training: bool. See cifar10.Cifar10DataSet.
  Returns:
    A function (tf.estimator.RunConfig, tf.contrib.training.HParams) ->
    tf.contrib.learn.Experiment.

    Suitable for use by tf.contrib.learn.learn_runner, which will run various
    methods on Experiment (train, evaluate) based on information
    about the current runner in `run_config`.
  """

  def _experiment_fn(run_config, hparams):
    """Returns an Experiment."""
    # Create estimator.
    train_input_fn = functools.partial(
        input_fn,
        data_dir,
        subset='train',
        num_shards=num_gpus,
        batch_size=hparams.train_batch_size,
        use_distortion_for_training=use_distortion_for_training)

    eval_input_fn = functools.partial(
        input_fn,
        data_dir,
        subset='eval',
        batch_size=hparams.eval_batch_size,
        num_shards=num_gpus)

    num_eval_examples = cifar10.Cifar10DataSet.num_examples_per_epoch('eval')
    if num_eval_examples % hparams.eval_batch_size != 0:
      raise ValueError(
          'validation set size must be multiple of eval_batch_size')

    train_steps = hparams.train_steps
    eval_steps = num_eval_examples // hparams.eval_batch_size

    examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
        hparams.train_batch_size, every_n_steps=10)
    tensors_to_log = {'learning_rate': 'learning_rate', 'loss': 'loss'}
    logging_hook = tf.train.LoggingTensorHook(
        tensors=tensors_to_log, every_n_iter=100)
    hooks = [logging_hook, examples_sec_hook]

    classifier = tf.estimator.Estimator(
        model_fn=get_model_fn(num_gpus, variable_strategy,
                              run_config.num_worker_replicas or 1),
        config=run_config,
        params=vars(hparams))

    # Create experiment.
    experiment = tf.contrib.learn.Experiment(
        classifier,
@@ -460,86 +364,186 @@ def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps,
        train_steps=train_steps,
        eval_steps=eval_steps)

    # Adding hooks to be used by the estimator on training mode.
    experiment.extend_train_hooks(hooks)
    return experiment

  return _experiment_fn


def main(job_dir,
         data_dir,
         num_gpus,
         variable_strategy,
         use_distortion_for_training,
         log_device_placement,
         num_intra_threads,
         **hparams):
  # The env variable is on deprecation path, default is set to off.
  os.environ['TF_SYNC_ON_FINISH'] = '0'

  # Session configuration.
  sess_config = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=log_device_placement,
      intra_op_parallelism_threads=num_intra_threads,
      gpu_options=tf.GPUOptions(force_gpu_compatible=True))

  config = tf.contrib.learn.RunConfig(
      session_config=sess_config, model_dir=job_dir)
  tf.contrib.learn.learn_runner.run(
      get_experiment_fn(data_dir, num_gpus, variable_strategy,
                        use_distortion_for_training),
      run_config=config,
      hparams=tf.contrib.training.HParams(**hparams))

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data-dir',
      type=str,
      required=True,
      help='The directory where the CIFAR-10 input data is stored.')
  parser.add_argument(
      '--job-dir',
      type=str,
      required=True,
      help='The directory where the model will be stored.')
  parser.add_argument(
      '--variable-strategy',
      choices=['CPU', 'GPU'],
      type=str,
      default='CPU',
      help='Where to locate variable operations.')
  parser.add_argument(
      '--num-gpus',
      type=int,
      default=1,
      help='The number of gpus used. Uses only CPU if set to 0.')
  parser.add_argument(
      '--num-layers',
      type=int,
      default=44,
      help='The number of layers of the model.')
  parser.add_argument(
      '--train-steps',
      type=int,
      default=80000,
      help='The number of steps to use for training.')
  parser.add_argument(
      '--train-batch-size',
      type=int,
      default=128,
      help='Batch size for training.')
  parser.add_argument(
      '--eval-batch-size',
      type=int,
      default=100,
      help='Batch size for validation.')
  parser.add_argument(
      '--momentum',
      type=float,
      default=0.9,
      help='Momentum for MomentumOptimizer.')
  parser.add_argument(
      '--weight-decay',
      type=float,
      default=2e-4,
      help='Weight decay for convolutions.')
  parser.add_argument(
      '--learning-rate',
      type=float,
      default=0.1,
      help="""\
      This is the initial learning rate value. The learning rate will decrease
      during training. For more details check the model_fn implementation in
      this file.\
      """)
  parser.add_argument(
      '--use-distortion-for-training',
      type=bool,
      default=True,
      help='If doing image distortion for training.')
  parser.add_argument(
      '--sync',
      action='store_true',
      default=False,
      help="""\
      If present when running in a distributed environment will run on sync
      mode.\
      """)
  parser.add_argument(
      '--num-intra-threads',
      type=int,
      default=1,
      help="""\
      Number of threads to use for intra-op parallelism. If set to 0, the
      system will pick an appropriate number. The default is 1 since in this
      example CPU only handles the input pipeline and gradient aggregation
      (when --variable-strategy is CPU). Ops that could potentially benefit
      from intra-op parallelism are scheduled to run on GPUs.\
      """)
  parser.add_argument(
      '--num-inter-threads',
      type=int,
      default=0,
      help="""\
      Number of threads to use for inter-op parallelism. If set to 0, the
      system will pick an appropriate number.\
      """)
  parser.add_argument(
      '--log-device-placement',
      action='store_true',
      default=False,
      help='Whether to log device placement.')
  parser.add_argument(
      '--batch-norm-decay',
      type=float,
      default=0.997,
      help='Decay for batch norm.')
  parser.add_argument(
      '--batch-norm-epsilon',
      type=float,
      default=1e-5,
      help='Epsilon for batch norm.')
  args = parser.parse_args()
  if args.num_gpus < 0:
    raise ValueError(
        'Invalid GPU count: \"num_gpus\" must be 0 or a positive integer.')
  if args.num_gpus == 0 and args.variable_strategy == 'GPU':
    raise ValueError(
        'No GPU available for use, must use CPU to average gradients.')
  if (args.num_layers - 2) % 6 != 0:
    raise ValueError('Invalid num_layers parameter.')
  if args.num_gpus != 0 and args.train_batch_size % args.num_gpus != 0:
    raise ValueError('train_batch_size must be multiple of num_gpus.')
  if args.num_gpus != 0 and args.eval_batch_size % args.num_gpus != 0:
    raise ValueError('eval_batch_size must be multiple of num_gpus.')

  main(**vars(args))
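For a quick single-process smoke test it is also possible to skip `learn_runner`/`Experiment` and wire `get_model_fn` and `input_fn` straight into a plain `tf.estimator.Estimator`. The sketch below is illustrative only: the paths are placeholders, and the hyperparameter dictionary simply mirrors the argparse defaults above.

```python
# Minimal local sketch (assumed paths): train the CPU-only tower for a few steps.
import functools

import tensorflow as tf

import cifar10_main

params = {
    'weight_decay': 2e-4, 'momentum': 0.9, 'learning_rate': 0.1,
    'train_batch_size': 128, 'eval_batch_size': 100, 'num_layers': 44,
    'batch_norm_decay': 0.997, 'batch_norm_epsilon': 1e-5, 'sync': False,
}
estimator = tf.estimator.Estimator(
    model_fn=cifar10_main.get_model_fn(
        num_gpus=0, variable_strategy='CPU', num_workers=1),
    model_dir='/tmp/cifar10_smoke_test',  # placeholder
    params=params)
train_input_fn = functools.partial(
    cifar10_main.input_fn, '/tmp/cifar-10-data',  # placeholder data dir
    subset='train', num_shards=0, batch_size=params['train_batch_size'])
estimator.train(input_fn=train_input_fn, steps=10)
```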
cifar10_estimator/cifar10_model.py
@@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Model class for Cifar10 Dataset."""
from __future__ import division
from __future__ import print_function
@@ -25,8 +24,18 @@ import model_base
class ResNetCifar10(model_base.ResNet):
  """Cifar10 model with ResNetV1 and basic residual block."""

  def __init__(self,
               num_layers,
               is_training,
               batch_norm_decay,
               batch_norm_epsilon,
               data_format='channels_first'):
    super(ResNetCifar10, self).__init__(
        is_training,
        data_format,
        batch_norm_decay,
        batch_norm_epsilon)
    self.n = (num_layers - 2) // 6
    # Add one in case label starts with 1. No impact if label starts with 0.
    self.num_classes = 10 + 1
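`ResNetCifar10` now takes the batch-norm hyperparameters explicitly instead of reading global flags. An illustrative construction follows: the values match the argparse defaults above, and the placeholder shape is the standard CIFAR-10 image size.

```python
# Illustrative: build the 44-layer CIFAR-10 ResNet graph for inference.
import tensorflow as tf

import cifar10_model

images = tf.placeholder(tf.float32, [None, 32, 32, 3])
model = cifar10_model.ResNetCifar10(
    num_layers=44,
    is_training=False,
    batch_norm_decay=0.997,
    batch_norm_epsilon=1e-5,
    data_format='channels_last')
logits = model.forward_pass(images, input_data_format='channels_last')
# logits has num_classes (= 10 + 1) outputs per image, per the class above.
```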
cifar10_estimator/cifar10_utils.py
import six

from tensorflow.python.platform import tf_logging as logging
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training import device_setter


class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
  """Hook to print out examples per second.

  Total time is tracked and then divided by the total number of steps
  to get the average step time and then batch_size is used to determine
  the running average of examples per second. The examples per second for the
  most recent interval is also logged.
  """

  def __init__(
      self,
      batch_size,
      every_n_steps=100,
      every_n_secs=None,):
    """Initializer for ExamplesPerSecondHook.

    Args:
      batch_size: Total batch size used to calculate examples/second from
        global time.
      every_n_steps: Log stats every n steps.
      every_n_secs: Log stats every n seconds.
    """
    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError('exactly one of every_n_steps'
                       ' and every_n_secs should be provided.')
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._step_train_time = 0
    self._total_steps = 0
    self._batch_size = batch_size

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          'Global step should be created to use StepCounterHook.')

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return basic_session_run_hooks.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    _ = run_context

    global_step = run_values.results
    if self._timer.should_trigger_for_step(global_step):
      elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
          global_step)
      if elapsed_time is not None:
        steps_per_sec = elapsed_steps / elapsed_time
        self._step_train_time += elapsed_time
        self._total_steps += elapsed_steps

        average_examples_per_sec = self._batch_size * (
            self._total_steps / self._step_train_time)
        current_examples_per_sec = steps_per_sec * self._batch_size
        # Average examples/sec followed by current examples/sec
        logging.info('%s: %g (%g), step = %g', 'Average examples/sec',
                     average_examples_per_sec, current_examples_per_sec,
                     self._total_steps)
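The hook is a standard `SessionRunHook`, so it can be handed to an estimator's training call alongside other hooks; a brief, illustrative usage:

```python
# Illustrative: report average and most-recent examples/sec every 10 steps.
import cifar10_utils

examples_sec_hook = cifar10_utils.ExamplesPerSecondHook(
    batch_size=128, every_n_steps=10)
# estimator.train(input_fn=train_input_fn, hooks=[examples_sec_hook])
```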

def local_device_setter(num_devices=1,
                        ps_device_type='cpu',
                        worker_device='/cpu:0',
                        ps_ops=None,
                        ps_strategy=None):
  if ps_ops is None:
    ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']

  if ps_strategy is None:
    ps_strategy = device_setter._RoundRobinStrategy(num_devices)
  if not six.callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")

  def _local_device_chooser(op):
    current_device = pydev.DeviceSpec.from_string(op.device or "")

    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if node_def.op in ps_ops:
      ps_device_spec = pydev.DeviceSpec.from_string(
          '/{}:{}'.format(ps_device_type, ps_strategy(op)))

      ps_device_spec.merge_from(current_device)
      return ps_device_spec.to_string()
    else:
      worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
      worker_device_spec.merge_from(current_device)
      return worker_device_spec.to_string()
  return _local_device_chooser
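`local_device_setter` returns a device-choosing function meant for `tf.device()`: variable ops go to the parameter device picked by `ps_strategy`, while every other op stays on `worker_device`. A small sketch of the GPU-balanced configuration used by `cifar10_main.py` (the device count here is an assumption):

```python
# Illustrative: balance variables across 2 GPUs while compute stays on /gpu:0.
import tensorflow as tf

import cifar10_utils

device_setter = cifar10_utils.local_device_setter(
    ps_device_type='gpu',
    worker_device='/gpu:0',
    ps_strategy=tf.contrib.training.GreedyLoadBalancingStrategy(
        2, tf.contrib.training.byte_size_load_fn))

with tf.device(device_setter):
    weights = tf.get_variable('weights', shape=[1024, 1024])  # placed on a GPU chosen by ps_strategy
    product = tf.matmul(weights, weights)  # stays on /gpu:0
```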
cifar10_estimator/cmle_config.yaml
trainingInput:
  scaleTier: CUSTOM
  masterType: complex_model_m_gpu
  workerType: complex_model_m_gpu
  parameterServerType: complex_model_m
  workerCount: 1
cifar10_estimator/generate_cifar10_tfrecords.py
@@ -22,19 +22,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import cPickle
import os

import tensorflow as tf


def _int64_feature(value):
@@ -55,7 +47,7 @@ def _get_file_names():
def read_pickle_from_file(filename):
  with tf.gfile.Open(filename, 'r') as f:
    data_dict = cPickle.load(f)
  return data_dict
@@ -63,34 +55,49 @@ def read_pickle_from_file(filename):
def convert_to_tfrecord(input_files, output_file):
  """Converts a file to tfrecords."""
  print('Generating %s' % output_file)
  with tf.python_io.TFRecordWriter(output_file) as record_writer:
    for input_file in input_files:
      data_dict = read_pickle_from_file(input_file)
      data = data_dict['data']
      labels = data_dict['labels']

      num_entries_in_batch = len(labels)
      for i in range(num_entries_in_batch):
        example = tf.train.Example(
            features=tf.train.Features(feature={
                'image': _bytes_feature(data[i].tobytes()),
                'label': _int64_feature(labels[i])
            }))
        record_writer.write(example.SerializeToString())


def main(input_dir, output_dir):
  file_names = _get_file_names()
  for mode, files in file_names.items():
    input_files = [
        os.path.join(input_dir, f) for f in files]
    output_file = os.path.join(output_dir, mode + '.tfrecords')
    # Convert to Examples and write the result to TFRecords.
    convert_to_tfrecord(input_files, output_file)
  print('Done!')


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--input-dir',
      type=str,
      default='',
      help='Directory where CIFAR10 data is located.')
  parser.add_argument(
      '--output-dir',
      type=str,
      default='',
      help="""\
      Directory where TFRecords will be saved. The TFRecords will have the same
      name as the CIFAR10 inputs + .tfrecords.\
      """)
  args = parser.parse_args()
  main(args.input_dir, args.output_dir)
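A quick, illustrative sanity check after conversion is to read one record back and confirm it carries the same `image`/`label` feature keys the writer uses (the path below is a placeholder):

```python
# Read a single serialized Example back from a generated TFRecord file.
import tensorflow as tf

record_path = '/prefix/to/downloaded/data/cifar-10-batches-py/train.tfrecords'
serialized = next(tf.python_io.tf_record_iterator(record_path))
example = tf.train.Example.FromString(serialized)
print(sorted(example.features.feature.keys()))  # expect ['image', 'label']
```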
cifar10_estimator/model_base.py
@@ -25,16 +25,11 @@ from __future__ import print_function
import tensorflow as tf


class ResNet(object):
  """ResNet model."""

  def __init__(self, is_training, data_format, batch_norm_decay, batch_norm_epsilon):
    """ResNet constructor.

    Args:
@@ -42,6 +37,8 @@ class ResNet(object):
      data_format: the data_format used during computation.
        one of 'channels_first' or 'channels_last'.
    """
    self._batch_norm_decay = batch_norm_decay
    self._batch_norm_epsilon = batch_norm_epsilon
    self._is_training = is_training
    assert data_format in ('channels_first', 'channels_last')
    self._data_format = data_format
@@ -185,10 +182,10 @@ class ResNet(object):
      data_format = 'NHWC'
    return tf.contrib.layers.batch_norm(
        x,
        decay=self._batch_norm_decay,
        center=True,
        scale=True,
        epsilon=self._batch_norm_epsilon,
        is_training=self._is_training,
        fused=True,
        data_format=data_format)