Unverified Commit 32aa6563 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Resnet distribution strategies (#3887)

* begin transfer from contrib fork

more changes to resnet_run_loop

use AUTOTUNE in prefetch

first pass at resnet with functional distribution strategies

fix syntax error

delint

aesthetic tweaks

delint and fix typos

rip multi_gpu flag out of resnet entirely. Subject to saved model load verification

update cifar10 and imagenet tests to reflect that the model function no longer needs to know about multi_gpu

fix imagenet test

start addressing PR comments

more PR response work

* misc tweaks

* add a comment

* final pr tweaks

* fix parsers
parent ad7755c8
...@@ -244,7 +244,7 @@ class MNISTArgParser(argparse.ArgumentParser): ...@@ -244,7 +244,7 @@ class MNISTArgParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(MNISTArgParser, self).__init__(parents=[ super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(), parsers.BaseParser(multi_gpu=True, num_gpu=False),
parsers.ImageModelParser(), parsers.ImageModelParser(),
parsers.ExportParser(), parsers.ExportParser(),
]) ])
......
...@@ -59,3 +59,13 @@ Other versions and formats: ...@@ -59,3 +59,13 @@ Other versions and formats:
* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz) * [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz) * [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz) * [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)
## Compute Devices
Training is accomplished using the [DistributionStrategies API](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md).
The appropriate distribution strategy is chosen based on the `--num_gpus` flag. By default this flag is one if TensorFlow was compiled with CUDA, and zero otherwise.

`num_gpus`:
+ 0: Use OneDeviceStrategy and train on the CPU.
+ 1: Use OneDeviceStrategy and train on a single GPU.
+ 2+: Use MirroredStrategy (data parallelism) to distribute each batch between devices.
...@@ -103,8 +103,7 @@ def preprocess_image(image, is_training): ...@@ -103,8 +103,7 @@ def preprocess_image(image, is_training):
return image return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1):
num_parallel_calls=1, multi_gpu=False):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset. """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args: Args:
...@@ -112,12 +111,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -112,12 +111,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -125,12 +118,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -125,12 +118,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
filenames = get_filenames(is_training, data_dir) filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES) dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _NUM_IMAGES['train'], dataset, is_training, batch_size, _NUM_IMAGES['train'],
parse_record, num_epochs, num_parallel_calls, parse_record, num_epochs,
examples_per_epoch=num_images, multi_gpu=multi_gpu) )
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -221,7 +212,6 @@ def cifar10_model_fn(features, labels, mode, params): ...@@ -221,7 +212,6 @@ def cifar10_model_fn(features, labels, mode, params):
version=params['version'], version=params['version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn, loss_filter_fn=loss_filter_fn,
multi_gpu=params['multi_gpu'],
dtype=params['dtype'] dtype=params['dtype']
) )
......
...@@ -71,87 +71,63 @@ class BaseTest(tf.test.TestCase): ...@@ -71,87 +71,63 @@ class BaseTest(tf.test.TestCase):
for pixel in row: for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3) self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False): def cifar10_model_fn_helper(self, mode, version, dtype):
with tf.Graph().as_default() as g: input_fn = cifar10_main.get_synth_input_fn()
input_fn = cifar10_main.get_synth_input_fn() dataset = input_fn(True, '', _BATCH_SIZE)
dataset = input_fn(True, '', _BATCH_SIZE) iterator = dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next()
features, labels = iterator.get_next() spec = cifar10_main.cifar10_model_fn(
spec = cifar10_main.cifar10_model_fn( features, labels, mode, {
features, labels, mode, { 'dtype': dtype,
'dtype': dtype, 'resnet_size': 32,
'resnet_size': 32, 'data_format': 'channels_last',
'data_format': 'channels_last', 'batch_size': _BATCH_SIZE,
'batch_size': _BATCH_SIZE, 'version': version,
'version': version, 'loss_scale': 128 if dtype == tf.float16 else 1,
'loss_scale': 128 if dtype == tf.float16 else 1, })
'multi_gpu': multi_gpu
}) predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
predictions = spec.predictions (_BATCH_SIZE, 10))
self.assertAllEqual(predictions['probabilities'].shape, self.assertEqual(predictions['probabilities'].dtype, tf.float32)
(_BATCH_SIZE, 10)) self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['probabilities'].dtype, tf.float32) self.assertEqual(predictions['classes'].dtype, tf.int64)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64) if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
if mode != tf.estimator.ModeKeys.PREDICT: self.assertAllEqual(loss.shape, ())
loss = spec.loss self.assertEqual(loss.dtype, tf.float32)
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32) if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
if mode == tf.estimator.ModeKeys.EVAL: self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
eval_metric_ops = spec.eval_metric_ops self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
for v in tf.trainable_variables():
self.assertEqual(v.dtype.base_dtype, tf.float32)
tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'final_reduce_mean:0',
'final_dense:0')
for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))
def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def test_cifar10_model_fn_train_mode_v1(self): def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True) dtype=tf.float32)
def test_cifar10_model_fn_train_mode_multi_gpu_v2(self): def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True) dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v1(self): def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v2(self): def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v1(self): def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v2(self): def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)
def _test_cifar10model_shape(self, version): def _test_cifar10model_shape(self, version):
batch_size = 135 batch_size = 135
......
...@@ -154,8 +154,7 @@ def parse_record(raw_record, is_training): ...@@ -154,8 +154,7 @@ def parse_record(raw_record, is_training):
return image, label return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1):
num_parallel_calls=1, multi_gpu=False):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -163,12 +162,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -163,12 +162,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -180,15 +173,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -180,15 +173,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Shuffle the input files # Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES) dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
# Convert to individual records # Convert to individual records
dataset = dataset.flat_map(tf.data.TFRecordDataset) dataset = dataset.flat_map(tf.data.TFRecordDataset)
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record, dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
num_epochs, num_parallel_calls, examples_per_epoch=num_images, num_epochs
multi_gpu=multi_gpu) )
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -300,7 +291,6 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -300,7 +291,6 @@ def imagenet_model_fn(features, labels, mode, params):
version=params['version'], version=params['version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=None, loss_filter_fn=None,
multi_gpu=params['multi_gpu'],
dtype=params['dtype'] dtype=params['dtype']
) )
......
...@@ -180,88 +180,66 @@ class BaseTest(tf.test.TestCase): ...@@ -180,88 +180,66 @@ class BaseTest(tf.test.TestCase):
def test_tensor_shapes_resnet_200_with_gpu_v2(self): def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True) self.tensor_shapes_helper(200, version=2, with_gpu=True)
def _resnet_model_fn_helper(self, mode, version, dtype, multi_gpu): def resnet_model_fn_helper(self, mode, version, dtype):
"""Tests that the EstimatorSpec is given the appropriate arguments.""" """Tests that the EstimatorSpec is given the appropriate arguments."""
with tf.Graph().as_default() as g: tf.train.create_global_step()
tf.train.create_global_step()
input_fn = imagenet_main.get_synth_input_fn()
input_fn = imagenet_main.get_synth_input_fn() dataset = input_fn(True, '', _BATCH_SIZE)
dataset = input_fn(True, '', _BATCH_SIZE) iterator = dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next()
features, labels = iterator.get_next() spec = imagenet_main.imagenet_model_fn(
spec = imagenet_main.imagenet_model_fn( features, labels, mode, {
features, labels, mode, { 'dtype': dtype,
'dtype': dtype, 'resnet_size': 50,
'resnet_size': 50, 'data_format': 'channels_last',
'data_format': 'channels_last', 'batch_size': _BATCH_SIZE,
'batch_size': _BATCH_SIZE, 'version': version,
'version': version, 'loss_scale': 128 if dtype == tf.float16 else 1,
'loss_scale': 128 if dtype == tf.float16 else 1, })
'multi_gpu': multi_gpu,
}) predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
predictions = spec.predictions (_BATCH_SIZE, _LABEL_CLASSES))
self.assertAllEqual(predictions['probabilities'].shape, self.assertEqual(predictions['probabilities'].dtype, tf.float32)
(_BATCH_SIZE, _LABEL_CLASSES)) self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['probabilities'].dtype, tf.float32) self.assertEqual(predictions['classes'].dtype, tf.int64)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64) if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
if mode != tf.estimator.ModeKeys.PREDICT: self.assertAllEqual(loss.shape, ())
loss = spec.loss self.assertEqual(loss.dtype, tf.float32)
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32) if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
if mode == tf.estimator.ModeKeys.EVAL: self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
eval_metric_ops = spec.eval_metric_ops self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
tensors_to_check = ('initial_conv:0', 'initial_max_pool:0',
'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'block_layer4:0',
'final_reduce_mean:0', 'final_dense:0')
for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))
def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def test_resnet_model_fn_train_mode_v1(self): def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_resnet_model_fn_train_mode_multi_gpu_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True) dtype=tf.float32)
def test_resnet_model_fn_train_mode_multi_gpu_v2(self): def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True) dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v1(self): def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v2(self): def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v1(self): def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v2(self): def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)
def _test_imagenetmodel_shape(self, version): def _test_imagenetmodel_shape(self, version):
batch_size = 135 batch_size = 135
......
...@@ -40,8 +40,7 @@ from official.utils.misc import model_helpers ...@@ -40,8 +40,7 @@ from official.utils.misc import model_helpers
# Functions for input processing. # Functions for input processing.
################################################################################ ################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn, num_epochs=1, num_parallel_calls=1, parse_record_fn, num_epochs=1):
examples_per_epoch=0, multi_gpu=False):
"""Given a Dataset with raw records, return an iterator over the records. """Given a Dataset with raw records, return an iterator over the records.
Args: Args:
...@@ -54,19 +53,11 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -54,19 +53,11 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn: A function that takes a raw record and returns the parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair. corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
examples_per_epoch: The number of examples in the current set that
are processed each epoch. Note that this is only used for multi-GPU mode,
and only to handle what will eventually be handled inside of Estimator.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers (see below), and can be removed
when that is handled directly by Estimator.
Returns: Returns:
Dataset of (image, label) pairs ready for iteration. Dataset of (image, label) pairs ready for iteration.
""" """
# We prefetch a batch at a time, This can help smooth out the time taken to # We prefetch a batch at a time, This can help smooth out the time taken to
# load input files as we go through shuffling and processing. # load input files as we go through shuffling and processing.
dataset = dataset.prefetch(buffer_size=batch_size) dataset = dataset.prefetch(buffer_size=batch_size)
...@@ -79,29 +70,22 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -79,29 +70,22 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# dataset for the appropriate number of epochs. # dataset for the appropriate number of epochs.
dataset = dataset.repeat(num_epochs) dataset = dataset.repeat(num_epochs)
# Currently, if we are using multiple GPUs, we can't pass in uneven batches. # Parse the raw records into images and labels. Testing has shown that setting
# (For example, if we have 4 GPUs, the number of examples in each batch # num_parallel_batches > 1 produces no improvement in throughput, since
# must be divisible by 4.) We already ensured this for the batch_size, but # batch_size is almost always much greater than the number of CPU cores.
# we have to additionally ensure that any "leftover" examples-- the remainder dataset = dataset.apply(
# examples (total examples % batch_size) that get called a batch for the very tf.contrib.data.map_and_batch(
# last batch of an epoch-- do not raise an error when we try to split them lambda value: parse_record_fn(value, is_training),
# over the GPUs. This will likely be handled by Estimator during replication batch_size=batch_size,
# in the future, but for now, we just drop the leftovers here. num_parallel_batches=1))
if multi_gpu:
total_examples = num_epochs * examples_per_epoch
dataset = dataset.take(batch_size * (total_examples // batch_size))
# Parse the raw records into images and labels
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
num_parallel_calls=num_parallel_calls)
dataset = dataset.batch(batch_size)
# Operations between the final prefetch and the get_next call to the iterator # Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to # will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the # background all of the above processing work and keep it out of the
# critical training path. # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
dataset = dataset.prefetch(1) # allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present.
dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return dataset return dataset
...@@ -123,7 +107,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes): ...@@ -123,7 +107,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
An input_fn that can be used in place of a real one to return a dataset An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration. that can be used for iteration.
""" """
def input_fn(is_training, data_dir, batch_size, *args): # pylint: disable=unused-argument def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument
images = tf.zeros((batch_size, height, width, num_channels), tf.float32) images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
labels = tf.zeros((batch_size, num_classes), tf.int32) labels = tf.zeros((batch_size, num_classes), tf.int32)
return tf.data.Dataset.from_tensors((images, labels)).repeat() return tf.data.Dataset.from_tensors((images, labels)).repeat()
...@@ -171,8 +155,7 @@ def learning_rate_with_decay( ...@@ -171,8 +155,7 @@ def learning_rate_with_decay(
def resnet_model_fn(features, labels, mode, model_class, def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum, resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, version, loss_scale, data_format, version, loss_scale, loss_filter_fn=None,
loss_filter_fn=None, multi_gpu=False,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""Shared functionality for different resnet model_fns. """Shared functionality for different resnet model_fns.
...@@ -205,8 +188,6 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -205,8 +188,6 @@ def resnet_model_fn(features, labels, mode, model_class,
True if the var should be included in loss calculation, and False True if the var should be included in loss calculation, and False
otherwise. If None, batch_normalization variables will be excluded otherwise. If None, batch_normalization variables will be excluded
from the loss. from the loss.
multi_gpu: If True, wrap the optimizer in a TowerOptimizer suitable for
data-parallel distribution across multiple GPUs.
dtype: the TensorFlow dtype to use for calculations. dtype: the TensorFlow dtype to use for calculations.
Returns: Returns:
...@@ -275,11 +256,8 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -275,11 +256,8 @@ def resnet_model_fn(features, labels, mode, model_class,
optimizer = tf.train.MomentumOptimizer( optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=momentum) momentum=momentum
)
# If we are running multi-GPU, we need to wrap the optimizer.
if multi_gpu:
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
if loss_scale != 1: if loss_scale != 1:
# When computing fp16 gradients, often intermediate tensor values are # When computing fp16 gradients, often intermediate tensor values are
...@@ -300,8 +278,14 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -300,8 +278,14 @@ def resnet_model_fn(features, labels, mode, model_class,
else: else:
train_op = None train_op = None
accuracy = tf.metrics.accuracy( if not tf.contrib.distribute.has_distribution_strategy():
tf.argmax(labels, axis=1), predictions['classes']) accuracy = tf.metrics.accuracy(
tf.argmax(labels, axis=1), predictions['classes'])
else:
# Metrics are currently not compatible with distribution strategies during
# training. This does not affect the overall performance of the model.
accuracy = (tf.no_op(), tf.constant(0))
metrics = {'accuracy': accuracy} metrics = {'accuracy': accuracy}
# Create a tensor named train_accuracy for logging purposes # Create a tensor named train_accuracy for logging purposes
...@@ -316,34 +300,35 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -316,34 +300,35 @@ def resnet_model_fn(features, labels, mode, model_class,
eval_metric_ops=metrics) eval_metric_ops=metrics)
def validate_batch_size_for_multi_gpu(batch_size): def per_device_batch_size(batch_size, num_gpus):
"""For multi-gpu, batch-size must be a multiple of the number of GPUs. """For multi-gpu, batch-size must be a multiple of the number of GPUs.
Note that this should eventually be handled by replicate_model_fn Note that this should eventually be handled by DistributionStrategies
directly. Multi-GPU support is currently experimental, however, directly. Multi-GPU support is currently experimental, however,
so doing the work here until that feature is in place. so doing the work here until that feature is in place.
Args: Args:
batch_size: the number of examples processed in each training batch. batch_size: Global batch size to be divided among devices. This should be
equal to num_gpus times the single-GPU batch_size for multi-gpu training.
num_gpus: How many GPUs are used with DistributionStrategies.
Returns:
Batch size per device.
Raises: Raises:
ValueError: if no GPUs are found, or selected batch_size is invalid. ValueError: if batch_size is not divisible by number of devices
""" """
from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top if num_gpus <= 1:
return batch_size
local_device_protos = device_lib.list_local_devices()
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
if not num_gpus:
raise ValueError('Multi-GPU mode was specified, but no GPUs '
'were found. To use CPU, run without --multi_gpu.')
remainder = batch_size % num_gpus remainder = batch_size % num_gpus
if remainder: if remainder:
err = ('When running with multiple GPUs, batch size ' err = ('When running with multiple GPUs, batch size '
'must be a multiple of the number of available GPUs. ' 'must be a multiple of the number of available GPUs. Found {} '
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.' 'GPUs with a batch size of {}; try --batch_size={} instead.'
).format(num_gpus, batch_size, batch_size - remainder) ).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err) raise ValueError(err)
return int(batch_size / num_gpus)
def resnet_main(flags, model_function, input_function, shape=None): def resnet_main(flags, model_function, input_function, shape=None):
...@@ -364,16 +349,6 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -364,16 +349,6 @@ def resnet_main(flags, model_function, input_function, shape=None):
# Using the Winograd non-fused algorithms provides a small performance boost. # Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
if flags.multi_gpu:
validate_batch_size_for_multi_gpu(flags.batch_size)
# There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens
# in the model_fn itself when the optimizer is defined.
model_function = tf.contrib.estimator.replicate_model_fn(
model_function,
loss_reduction=tf.losses.Reduction.MEAN)
# Create session config based on values of inter_op_parallelism_threads and # Create session config based on values of inter_op_parallelism_threads and
# intra_op_parallelism_threads. Note that we default to having # intra_op_parallelism_threads. Note that we default to having
# allow_soft_placement = True, which is required for multi-GPU and not # allow_soft_placement = True, which is required for multi-GPU and not
...@@ -383,16 +358,24 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -383,16 +358,24 @@ def resnet_main(flags, model_function, input_function, shape=None):
intra_op_parallelism_threads=flags.intra_op_parallelism_threads, intra_op_parallelism_threads=flags.intra_op_parallelism_threads,
allow_soft_placement=True) allow_soft_placement=True)
# Set up a RunConfig to save checkpoint and set session config. if flags.num_gpus == 0:
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9, distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
session_config=session_config) elif flags.num_gpus == 1:
distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
else:
distribution = tf.contrib.distribute.MirroredStrategy(
num_gpus=flags.num_gpus
)
run_config = tf.estimator.RunConfig(train_distribute=distribution,
session_config=session_config)
classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags.model_dir, config=run_config, model_fn=model_function, model_dir=flags.model_dir, config=run_config,
params={ params={
'resnet_size': flags.resnet_size, 'resnet_size': flags.resnet_size,
'data_format': flags.data_format, 'data_format': flags.data_format,
'batch_size': flags.batch_size, 'batch_size': flags.batch_size,
'multi_gpu': flags.multi_gpu,
'version': flags.version, 'version': flags.version,
'loss_scale': flags.loss_scale, 'loss_scale': flags.loss_scale,
'dtype': flags.dtype 'dtype': flags.dtype
...@@ -413,9 +396,12 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -413,9 +396,12 @@ def resnet_main(flags, model_function, input_function, shape=None):
print('Starting a training cycle.') print('Starting a training cycle.')
def input_fn_train(): def input_fn_train():
return input_function(True, flags.data_dir, flags.batch_size, return input_function(
flags.epochs_between_evals, is_training=True,
flags.num_parallel_calls, flags.multi_gpu) data_dir=flags.data_dir,
batch_size=per_device_batch_size(flags.batch_size, flags.num_gpus),
num_epochs=flags.epochs_between_evals,
)
classifier.train(input_fn=input_fn_train, hooks=train_hooks, classifier.train(input_fn=input_fn_train, hooks=train_hooks,
max_steps=flags.max_train_steps) max_steps=flags.max_train_steps)
...@@ -423,8 +409,12 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -423,8 +409,12 @@ def resnet_main(flags, model_function, input_function, shape=None):
print('Starting to evaluate.') print('Starting to evaluate.')
# Evaluate the model and print results # Evaluate the model and print results
def input_fn_eval(): def input_fn_eval():
return input_function(False, flags.data_dir, flags.batch_size, return input_function(
1, flags.num_parallel_calls, flags.multi_gpu) is_training=False,
data_dir=flags.data_dir,
batch_size=per_device_batch_size(flags.batch_size, flags.num_gpus),
num_epochs=1,
)
# flags.max_train_steps is generally associated with testing and profiling. # flags.max_train_steps is generally associated with testing and profiling.
# As a result it is frequently called with synthetic data, which will # As a result it is frequently called with synthetic data, which will
...@@ -444,31 +434,19 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -444,31 +434,19 @@ def resnet_main(flags, model_function, input_function, shape=None):
break break
if flags.export_dir is not None: if flags.export_dir is not None:
warn_on_multi_gpu_export(flags.multi_gpu)
# Exports a saved model for the given classifier. # Exports a saved model for the given classifier.
input_receiver_fn = export.build_tensor_serving_input_receiver_fn( input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
shape, batch_size=flags.batch_size) shape, batch_size=flags.batch_size)
classifier.export_savedmodel(flags.export_dir, input_receiver_fn) classifier.export_savedmodel(flags.export_dir, input_receiver_fn)
def warn_on_multi_gpu_export(multi_gpu=False):
  """Log a warning when exporting a SavedModel from multi-GPU mode.

  A model replicated across GPUs (via replicate_model_fn) bakes device
  placements into the exported graph, so the SavedModel will only serve on a
  machine with the same GPU configuration. This does not block the export;
  it only warns the user.

  Args:
    multi_gpu: bool, whether the model was trained in multi-GPU mode. When
      falsy this function is a no-op.
  """
  if not multi_gpu:
    return
  # Fix: the original message concatenated '...available.' and 'If you...'
  # without a separating space, rendering as "available.If you wish".
  tf.logging.warning(
      'You are exporting a SavedModel while in multi-GPU mode. Note that '
      'the resulting SavedModel will require the same GPUs be available. '
      'If you wish to serve the SavedModel from a different device, '
      'try exporting the SavedModel with multi-GPU mode turned off.')
class ResnetArgParser(argparse.ArgumentParser): class ResnetArgParser(argparse.ArgumentParser):
"""Arguments for configuring and running a Resnet Model.""" """Arguments for configuring and running a Resnet Model."""
def __init__(self, resnet_size_choices=None): def __init__(self, resnet_size_choices=None):
super(ResnetArgParser, self).__init__(parents=[ super(ResnetArgParser, self).__init__(parents=[
parsers.BaseParser(), parsers.BaseParser(multi_gpu=False),
parsers.PerformanceParser(), parsers.PerformanceParser(num_parallel_calls=False),
parsers.ImageModelParser(), parsers.ImageModelParser(),
parsers.ExportParser(), parsers.ExportParser(),
parsers.BenchmarkParser(), parsers.BenchmarkParser(),
......
...@@ -101,15 +101,16 @@ class BaseParser(argparse.ArgumentParser): ...@@ -101,15 +101,16 @@ class BaseParser(argparse.ArgumentParser):
epochs_between_evals: Create a flag to specify the frequency of testing. epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training. eval metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size. batch_size: Create a flag to specify the global batch size.
multi_gpu: Create a flag to allow the use of all available GPUs. multi_gpu: Create a flag to allow the use of all available GPUs.
num_gpu: Create a flag to specify the number of GPUs used.
hooks: Create a flag to specify hooks for logging. hooks: Create a flag to specify hooks for logging.
""" """
def __init__(self, add_help=False, data_dir=True, model_dir=True, def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True, train_epochs=True, epochs_between_evals=True,
stop_threshold=True, batch_size=True, multi_gpu=True, stop_threshold=True, batch_size=True,
hooks=True): multi_gpu=False, num_gpu=True, hooks=True):
super(BaseParser, self).__init__(add_help=add_help) super(BaseParser, self).__init__(add_help=add_help)
if data_dir: if data_dir:
...@@ -154,16 +155,30 @@ class BaseParser(argparse.ArgumentParser): ...@@ -154,16 +155,30 @@ class BaseParser(argparse.ArgumentParser):
if batch_size: if batch_size:
self.add_argument( self.add_argument(
"--batch_size", "-bs", type=int, default=32, "--batch_size", "-bs", type=int, default=32,
help="[default: %(default)s] Batch size for training and evaluation.", help="[default: %(default)s] Global batch size for training and "
"evaluation.",
metavar="<BS>" metavar="<BS>"
) )
assert not (multi_gpu and num_gpu)
if multi_gpu: if multi_gpu:
self.add_argument( self.add_argument(
"--multi_gpu", action="store_true", "--multi_gpu", action="store_true",
help="If set, run across all available GPUs." help="If set, run across all available GPUs."
) )
if num_gpu:
self.add_argument(
"--num_gpus", "-ng",
type=int,
default=1 if tf.test.is_built_with_cuda() else 0,
help="[default: %(default)s] How many GPUs to use with the "
"DistributionStrategies API. The default is 1 if TensorFlow was"
"built with CUDA, and 0 otherwise.",
metavar="<NG>"
)
if hooks: if hooks:
self.add_argument( self.add_argument(
"--hooks", "-hk", nargs="+", default=["LoggingTensorHook"], "--hooks", "-hk", nargs="+", default=["LoggingTensorHook"],
......
...@@ -26,7 +26,7 @@ class TestParser(argparse.ArgumentParser): ...@@ -26,7 +26,7 @@ class TestParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(TestParser, self).__init__(parents=[ super(TestParser, self).__init__(parents=[
parsers.BaseParser(), parsers.BaseParser(multi_gpu=True, num_gpu=False),
parsers.PerformanceParser(num_parallel_calls=True, inter_op=True, parsers.PerformanceParser(num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True), intra_op=True, use_synthetic_data=True),
parsers.ImageModelParser(data_format=True), parsers.ImageModelParser(data_format=True),
......
...@@ -221,7 +221,8 @@ class WideDeepArgParser(argparse.ArgumentParser): ...@@ -221,7 +221,8 @@ class WideDeepArgParser(argparse.ArgumentParser):
"""Argument parser for running the wide deep model.""" """Argument parser for running the wide deep model."""
def __init__(self): def __init__(self):
super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()]) super(WideDeepArgParser, self).__init__(parents=[
parsers.BaseParser(multi_gpu=False, num_gpu=False)])
self.add_argument( self.add_argument(
'--model_type', '-mt', type=str, default='wide_deep', '--model_type', '-mt', type=str, default='wide_deep',
choices=['wide', 'deep', 'wide_deep'], choices=['wide', 'deep', 'wide_deep'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment