Unverified Commit a41f00ac authored by Karmel Allison, committed by GitHub

Multi-GPU Resnet (#3472)

* Refactor and use class-based model for Resnet

* Linting

* Cleanup

* Cleanup in lieu of MNIST changes

* Moving params to init function

* Making learning_rate function

* Testing remotely...

* Testing remotely...

* Testing remotely...

* Testing remotely...

* Testing remotely...

* Respond to CR pt 1

* Respond to CR pt 2

* Respond to CR pt 3

* Respond to CR pt 4

* Adding batch norm vars in

* Exclude batch norm vars as the default

* Fixing CIFAR-10 naming

* Adding multi-GPU code

* Git rewind

* Git add file

* Manual revert

* Fixing tests

* Allowing input thread specification

* Adding comments

* Allowing input thread specification

* Merging resnet files

* Merging resnet files

* Refactoring input methods to allow for Reed's improvements

* Adding comments

* Changing arg name

* Removing contrib shuffle_and_repeat

* Removing contrib shuffle_and_repeat

* Removing contrib shuffle_and_repeat

* Debugging

* Removing with dependency on update_op

* Updating comments.

* Returning dataset directly

* Adding newline

* Adding num_gpus flag

* Refining preprocessing, part 1

* Refinements to preprocessing resulting from multi-GPU tests

* Reviving one-hot labels

* Reviving one-hot labels

* Fixing label shapes

* Removing epoch leftovers

* Adding random flip back in

* Reverting unnecessary linting of test file

* Respond to CR

* Respond to CR

* Respond to CR

* Remove conversion to float

* Remove conversion to float- comment

* Making means full-scale

* Pulling data.take under multi_gpu flag

* Pulling data.take under multi_gpu flag- Cifar

* Pulling data.take under multi_gpu flag- Cifar

* Removing num_gpus

* Removing num_gpus
parent 4cdd9721
@@ -103,7 +103,7 @@ def preprocess_image(image, is_training):
 def input_fn(is_training, data_dir, batch_size, num_epochs=1,
-             num_parallel_calls=1):
+             num_parallel_calls=1, multi_gpu=False):
   """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.

   Args:
@@ -114,6 +114,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
     num_parallel_calls: The number of records that are processed in parallel.
       This can be optimized per data set but for generally homogeneous data
       sets, should be approximately the number of available CPU cores.
+    multi_gpu: Whether this is run multi-GPU. Note that this is only required
+      currently to handle the batch leftovers, and can be removed
+      when that is handled directly by Estimator.

   Returns:
     A dataset that can be used for iteration.
@@ -121,8 +124,11 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
   filenames = get_filenames(is_training, data_dir)
   dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)

+  num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
+
   return resnet.process_record_dataset(dataset, is_training, batch_size,
-      _NUM_IMAGES['train'], parse_record, num_epochs, num_parallel_calls)
+      _NUM_IMAGES['train'], parse_record, num_epochs, num_parallel_calls,
+      examples_per_epoch=num_images, multi_gpu=multi_gpu)

 ###############################################################################
@@ -189,7 +195,8 @@ def cifar10_model_fn(features, labels, mode, params):
       learning_rate_fn=learning_rate_fn,
       momentum=0.9,
       data_format=params['data_format'],
-      loss_filter_fn=loss_filter_fn)
+      loss_filter_fn=loss_filter_fn,
+      multi_gpu=params['multi_gpu'])

 def main(unused_argv):
......
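The new num_images line above uses Python's and/or conditional idiom. A tiny, self-contained sketch of what it evaluates to, assuming the standard CIFAR-10 counts (both non-zero, so the idiom is safe here):

_NUM_IMAGES = {'train': 50000, 'validation': 10000}  # CIFAR-10 counts

for is_training in (True, False):
  num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
  print(is_training, num_images)  # True 50000, then False 10000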
@@ -70,13 +70,14 @@ class BaseTest(tf.test.TestCase):
         [_BATCH_SIZE], maxval=9, dtype=tf.int32)
     return features, tf.one_hot(labels, 10)

-  def cifar10_model_fn_helper(self, mode):
+  def cifar10_model_fn_helper(self, mode, multi_gpu=False):
     features, labels = self.input_fn()
     spec = cifar10_main.cifar10_model_fn(
         features, labels, mode, {
             'resnet_size': 32,
             'data_format': 'channels_last',
             'batch_size': _BATCH_SIZE,
+            'multi_gpu': multi_gpu
         })

     predictions = spec.predictions
@@ -101,6 +102,9 @@ class BaseTest(tf.test.TestCase):
   def test_cifar10_model_fn_train_mode(self):
     self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN)

+  def test_cifar10_model_fn_train_mode_multi_gpu(self):
+    self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, multi_gpu=True)
+
   def test_cifar10_model_fn_eval_mode(self):
     self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL)
......
@@ -102,10 +102,9 @@ def parse_record(raw_record, is_training):
   # Note that the resulting image contains an unknown height and width
   # that is set dynamically by decode_jpeg. In other words, the height
   # and width of image is unknown at compile-time.
-  # Results in a 3-D int8 Tensor which we then convert to a float
-  # with values ranging from [0, 1).
+  # Results in a 3-D int8 Tensor. This will be converted to a float later,
+  # during resizing.
   image = tf.image.decode_jpeg(image, channels=_NUM_CHANNELS)
-  image = tf.image.convert_image_dtype(image, tf.float32)

   image = vgg_preprocessing.preprocess_image(
       image=image,
@@ -120,7 +119,7 @@ def parse_record(raw_record, is_training):
 def input_fn(is_training, data_dir, batch_size, num_epochs=1,
-             num_parallel_calls=1):
+             num_parallel_calls=1, multi_gpu=False):
   """Input function which provides batches for train or eval.

   Args:
     is_training: A boolean denoting whether the input is for training.
@@ -130,6 +129,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
     num_parallel_calls: The number of records that are processed in parallel.
       This can be optimized per data set but for generally homogeneous data
       sets, should be approximately the number of available CPU cores.
+    multi_gpu: Whether this is run multi-GPU. Note that this is only required
+      currently to handle the batch leftovers, and can be removed
+      when that is handled directly by Estimator.

   Returns:
     A dataset that can be used for iteration.
@@ -141,11 +143,14 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
     # Shuffle the input files
     dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

+  num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
+
   # Convert to individual records
   dataset = dataset.flat_map(tf.data.TFRecordDataset)

   return resnet.process_record_dataset(dataset, is_training, batch_size,
-      _SHUFFLE_BUFFER, parse_record, num_epochs, num_parallel_calls)
+      _SHUFFLE_BUFFER, parse_record, num_epochs, num_parallel_calls,
+      examples_per_epoch=num_images, multi_gpu=multi_gpu)

 ###############################################################################
@@ -225,7 +230,8 @@ def imagenet_model_fn(features, labels, mode, params):
       learning_rate_fn=learning_rate_fn,
       momentum=0.9,
       data_format=params['data_format'],
-      loss_filter_fn=None)
+      loss_filter_fn=None,
+      multi_gpu=params['multi_gpu'])

 def main(unused_argv):
......
@@ -136,7 +136,7 @@ class BaseTest(tf.test.TestCase):
     return features, labels

-  def resnet_model_fn_helper(self, mode):
+  def resnet_model_fn_helper(self, mode, multi_gpu=False):
     """Tests that the EstimatorSpec is given the appropriate arguments."""
     tf.train.create_global_step()
@@ -146,6 +146,7 @@ class BaseTest(tf.test.TestCase):
             'resnet_size': 50,
             'data_format': 'channels_last',
             'batch_size': _BATCH_SIZE,
+            'multi_gpu': multi_gpu,
         })

     predictions = spec.predictions
@@ -170,6 +171,9 @@ class BaseTest(tf.test.TestCase):
   def test_resnet_model_fn_train_mode(self):
     self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN)

+  def test_resnet_model_fn_train_mode_multi_gpu(self):
+    self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, multi_gpu=True)
+
   def test_resnet_model_fn_eval_mode(self):
     self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL)
@@ -190,3 +194,4 @@ class BaseTest(tf.test.TestCase):
 if __name__ == '__main__':
   tf.test.main()
@@ -46,9 +46,11 @@ _BATCH_NORM_EPSILON = 1e-5
 # Functions for input processing.
 ################################################################################
 def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
-                           parse_record_fn, num_epochs=1, num_parallel_calls=1):
+                           parse_record_fn, num_epochs=1, num_parallel_calls=1,
+                           examples_per_epoch=0, multi_gpu=False):
   """Given a Dataset with raw records, parse each record into images and labels,
   and return an iterator over the records.

   Args:
     dataset: A Dataset representing raw records
     is_training: A boolean denoting whether the input is for training.
@@ -62,6 +64,12 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
     num_parallel_calls: The number of records that are processed in parallel.
       This can be optimized per data set but for generally homogeneous data
       sets, should be approximately the number of available CPU cores.
+    examples_per_epoch: The number of examples in the current set that
+      are processed each epoch. Note that this is only used for multi-GPU mode,
+      and only to handle what will eventually be handled inside of Estimator.
+    multi_gpu: Whether this is run multi-GPU. Note that this is only required
+      currently to handle the batch leftovers (see below), and can be removed
+      when that is handled directly by Estimator.

   Returns:
     Dataset of (image, label) pairs ready for iteration.
@@ -78,6 +86,18 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
   # dataset for the appropriate number of epochs.
   dataset = dataset.repeat(num_epochs)

+  # Currently, if we are using multiple GPUs, we can't pass in uneven batches.
+  # (For example, if we have 4 GPUs, the number of examples in each batch
+  # must be divisible by 4.) We already ensured this for the batch_size, but
+  # we have to additionally ensure that any "leftover" examples -- the remainder
+  # examples (total examples % batch_size) that get called a batch for the very
+  # last batch of an epoch -- do not raise an error when we try to split them
+  # over the GPUs. This will likely be handled by Estimator during replication
+  # in the future, but for now, we just drop the leftovers here.
+  if multi_gpu:
+    total_examples = num_epochs * examples_per_epoch
+    dataset = dataset.take(batch_size * (total_examples // batch_size))
+
   # Parse the raw records into images and labels
   dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
                         num_parallel_calls=num_parallel_calls)
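The leftover-dropping logic above keeps only as many records as fill whole batches. A small self-contained sketch of the same arithmetic with hypothetical sizes (the numbers and the range dataset are illustrative, not taken from this change):

import tensorflow as tf

batch_size = 32           # assumed to already be a multiple of the GPU count
examples_per_epoch = 100  # hypothetical per-epoch example count
num_epochs = 2

dataset = tf.data.Dataset.range(num_epochs * examples_per_epoch)

# Keep only whole batches so the final batch is never smaller than batch_size
# and can always be split evenly across the GPUs.
total_examples = num_epochs * examples_per_epoch
dataset = dataset.take(batch_size * (total_examples // batch_size))
dataset = dataset.batch(batch_size)
# 200 total examples // 32 = 6 full batches; the trailing 8 examples are dropped.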
@@ -418,7 +438,7 @@ def learning_rate_with_decay(
 def resnet_model_fn(features, labels, mode, model_class,
                     resnet_size, weight_decay, learning_rate_fn, momentum,
-                    data_format, loss_filter_fn=None):
+                    data_format, loss_filter_fn=None, multi_gpu=False):
   """Shared functionality for different resnet model_fns.

   Initializes the ResnetModel representing the model layers
@@ -446,6 +466,9 @@ def resnet_model_fn(features, labels, mode, model_class,
       True if the var should be included in loss calculation, and False
       otherwise. If None, batch_normalization variables will be excluded
       from the loss.
+    multi_gpu: If True, wrap the optimizer in a TowerOptimizer suitable for
+      data-parallel distribution across multiple GPUs.

   Returns:
     EstimatorSpec parameterized according to the input params and the
     current mode.
@@ -497,10 +520,12 @@ def resnet_model_fn(features, labels, mode, model_class,
         learning_rate=learning_rate,
         momentum=momentum)

-    # Batch norm requires update ops to be added as a dependency to train_op
+    # If we are running multi-GPU, we need to wrap the optimizer.
+    if multi_gpu:
+      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
+
     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-    with tf.control_dependencies(update_ops):
-      train_op = optimizer.minimize(loss, global_step)
+    train_op = tf.group(optimizer.minimize(loss, global_step), update_ops)
   else:
     train_op = None
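In isolation, the optimizer handling above amounts to the pattern sketched below. This is a simplified reconstruction rather than a drop-in piece of resnet_model_fn; build_train_op and the hyperparameters are placeholders, and it assumes a TF 1.x build where tf.contrib.estimator is available.

import tensorflow as tf

def build_train_op(loss, multi_gpu=False):
  global_step = tf.train.get_or_create_global_step()
  optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)

  if multi_gpu:
    # TowerOptimizer aggregates gradients across the per-GPU towers that
    # replicate_model_fn creates around the model_fn.
    optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

  # Batch-norm moving-average updates live in UPDATE_OPS; grouping them with
  # minimize() runs them on every training step without a separate
  # control_dependencies block.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  return tf.group(optimizer.minimize(loss, global_step), update_ops)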
@@ -520,10 +545,45 @@ def resnet_model_fn(features, labels, mode, model_class,
       eval_metric_ops=metrics)

+def validate_batch_size_for_multi_gpu(batch_size):
+  """For multi-gpu, batch-size must be a multiple of the number of
+  available GPUs.
+
+  Note that this should eventually be handled by replicate_model_fn
+  directly. Multi-GPU support is currently experimental, however,
+  so doing the work here until that feature is in place.
+  """
+  from tensorflow.python.client import device_lib
+
+  local_device_protos = device_lib.list_local_devices()
+  num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
+  if not num_gpus:
+    raise ValueError('Multi-GPU mode was specified, but no GPUs '
+                     'were found. To use CPU, run without --multi_gpu.')
+
+  remainder = batch_size % num_gpus
+  if remainder:
+    err = ('When running with multiple GPUs, batch size '
+           'must be a multiple of the number of available GPUs. '
+           'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
+          ).format(num_gpus, batch_size, batch_size - remainder)
+    raise ValueError(err)
+
+
 def resnet_main(flags, model_function, input_function):
   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

+  if flags.multi_gpu:
+    validate_batch_size_for_multi_gpu(flags.batch_size)
+
+    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
+    # and (2) wrap the optimizer. The first happens here, and (2) happens
+    # in the model_fn itself when the optimizer is defined.
+    model_function = tf.contrib.estimator.replicate_model_fn(
+        model_function,
+        loss_reduction=tf.losses.Reduction.MEAN)
+
   # Set up a RunConfig to only save checkpoints once per training cycle.
   run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9)
   classifier = tf.estimator.Estimator(
@@ -532,6 +592,7 @@ def resnet_main(flags, model_function, input_function):
           'resnet_size': flags.resnet_size,
           'data_format': flags.data_format,
           'batch_size': flags.batch_size,
+          'multi_gpu': flags.multi_gpu,
       })

   for _ in range(flags.train_epochs // flags.epochs_per_eval):
@@ -548,7 +609,8 @@ def resnet_main(flags, model_function, input_function):
     def input_fn_train():
       return input_function(True, flags.data_dir, flags.batch_size,
-                            flags.epochs_per_eval, flags.num_parallel_calls)
+                            flags.epochs_per_eval, flags.num_parallel_calls,
+                            flags.multi_gpu)

     classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
@@ -556,7 +618,7 @@ def resnet_main(flags, model_function, input_function):
     # Evaluate the model and print results
     def input_fn_eval():
       return input_function(False, flags.data_dir, flags.batch_size,
-                            1, flags.num_parallel_calls)
+                            1, flags.num_parallel_calls, flags.multi_gpu)

     eval_results = classifier.evaluate(input_fn=input_fn_eval)
     print(eval_results)
@@ -608,3 +670,8 @@ class ResnetArgParser(argparse.ArgumentParser):
         'is not always compatible with CPU. If left unspecified, '
         'the data format will be chosen automatically based on '
         'whether TensorFlow was built for CPU or GPU.')
+
+    self.add_argument(
+        '--multi_gpu', action='store_true',
+        help='If set, run across all available GPUs. Note that this is '
+             'superseded by the --num_gpus flag.')
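The two-step pattern described in the resnet_main comment above, wrap the model_fn and wrap the optimizer, can be exercised end to end with a toy Estimator. A hedged sketch under the assumption that tf.contrib.estimator is available; toy_model_fn, the input data, and the hyperparameters are placeholders, not code from this change.

import numpy as np
import tensorflow as tf

def toy_model_fn(features, labels, mode, params):
  """Placeholder model_fn standing in for cifar10/imagenet_model_fn."""
  preds = tf.layers.dense(features, 1)
  loss = tf.losses.mean_squared_error(labels=labels, predictions=preds)
  optimizer = tf.train.GradientDescentOptimizer(0.01)
  if params.get('multi_gpu'):
    # Step (2): wrap the optimizer inside the model_fn.
    optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
  train_op = optimizer.minimize(loss, tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

multi_gpu = False  # flip to True on a machine with one or more GPUs
model_function = toy_model_fn
if multi_gpu:
  # Step (1): wrap the model_fn so each available GPU gets its own tower.
  model_function = tf.contrib.estimator.replicate_model_fn(
      toy_model_fn, loss_reduction=tf.losses.Reduction.MEAN)

classifier = tf.estimator.Estimator(
    model_fn=model_function, params={'multi_gpu': multi_gpu})

def input_fn():
  x = np.random.rand(64, 3).astype(np.float32)
  y = x.sum(axis=1, keepdims=True)
  return tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

classifier.train(input_fn=input_fn, steps=10)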
@@ -34,9 +34,9 @@ from __future__ import print_function

 import tensorflow as tf

-_R_MEAN = 123.68 / 255
-_G_MEAN = 116.78 / 255
-_B_MEAN = 103.94 / 255
+_R_MEAN = 123.68
+_G_MEAN = 116.78
+_B_MEAN = 103.94

 _RESIZE_SIDE_MIN = 256
 _RESIZE_SIDE_MAX = 512
@@ -147,7 +147,7 @@ def _smallest_size_at_least(height, width, smallest_side):
   Returns:
     new_height: an int32 scalar tensor indicating the new height.
-    new_width: and int32 scalar tensor indicating the new width.
+    new_width: an int32 scalar tensor indicating the new width.
   """
   smallest_side = tf.cast(smallest_side, tf.float32)
......
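The channel means return to the raw 0-255 pixel scale because the explicit convert_image_dtype step was dropped from parse_record above, so preprocessing now sees full-range values. A minimal sketch of mean subtraction at that scale (the helper name is illustrative, not the one used in vgg_preprocessing):

import tensorflow as tf

_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94

def subtract_channel_means(image):
  # `image` is assumed to be a float32 HxWx3 tensor with values in [0, 255],
  # matching the full-scale means above.
  means = tf.constant([_R_MEAN, _G_MEAN, _B_MEAN], dtype=tf.float32)
  return image - tf.reshape(means, [1, 1, 3])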