Unverified Commit 18d05ad3 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Restore ResNet Distribution Strategies (#4134)

* Revert 823da318. This restores distribution strategies for resnet.

This commit is not a direct revert due to significant merge conflict
resolution.

* fix flags test

* num_parallel_calls ("npc") is no longer used in resnet
parent 51a2b441
...@@ -87,7 +87,7 @@ def create_model(data_format): ...@@ -87,7 +87,7 @@ def create_model(data_format):
def define_mnist_flags(): def define_mnist_flags():
flags_core.define_base() flags_core.define_base(multi_gpu=True, num_gpu=False)
flags_core.define_image() flags_core.define_image()
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(data_dir='/tmp/mnist_data', flags_core.set_defaults(data_dir='/tmp/mnist_data',
......
...@@ -59,3 +59,13 @@ Other versions and formats: ...@@ -59,3 +59,13 @@ Other versions and formats:
* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz) * [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz) * [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz) * [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)
## Compute Devices
Training is accomplished using the DistributionStrategies API. (See the [DistributionStrategies README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md).)
The appropriate distribution strategy is chosen based on the `--num_gpus` flag. By default this flag is 1 if TensorFlow was compiled with CUDA support, and 0 otherwise.
`num_gpus`:
+ 0: Use OneDeviceStrategy and train on CPU.
+ 1: Use OneDeviceStrategy and train on GPU.
+ 2+: Use MirroredStrategy (data parallelism) to distribute a batch between devices.
...@@ -105,8 +105,7 @@ def preprocess_image(image, is_training): ...@@ -105,8 +105,7 @@ def preprocess_image(image, is_training):
return image return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1):
num_parallel_calls=1, multi_gpu=False):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset. """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args: Args:
...@@ -114,12 +113,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -114,12 +113,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -127,12 +120,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -127,12 +120,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
filenames = get_filenames(is_training, data_dir) filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES) dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _NUM_IMAGES['train'], dataset, is_training, batch_size, _NUM_IMAGES['train'],
parse_record, num_epochs, num_parallel_calls, parse_record, num_epochs,
examples_per_epoch=num_images, multi_gpu=multi_gpu) )
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -221,7 +212,6 @@ def cifar10_model_fn(features, labels, mode, params): ...@@ -221,7 +212,6 @@ def cifar10_model_fn(features, labels, mode, params):
version=params['version'], version=params['version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn, loss_filter_fn=loss_filter_fn,
multi_gpu=params['multi_gpu'],
dtype=params['dtype'] dtype=params['dtype']
) )
......
...@@ -76,87 +76,63 @@ class BaseTest(tf.test.TestCase): ...@@ -76,87 +76,63 @@ class BaseTest(tf.test.TestCase):
for pixel in row: for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3) self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False): def cifar10_model_fn_helper(self, mode, version, dtype):
with tf.Graph().as_default() as g: input_fn = cifar10_main.get_synth_input_fn()
input_fn = cifar10_main.get_synth_input_fn() dataset = input_fn(True, '', _BATCH_SIZE)
dataset = input_fn(True, '', _BATCH_SIZE) iterator = dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next()
features, labels = iterator.get_next() spec = cifar10_main.cifar10_model_fn(
spec = cifar10_main.cifar10_model_fn( features, labels, mode, {
features, labels, mode, { 'dtype': dtype,
'dtype': dtype, 'resnet_size': 32,
'resnet_size': 32, 'data_format': 'channels_last',
'data_format': 'channels_last', 'batch_size': _BATCH_SIZE,
'batch_size': _BATCH_SIZE, 'version': version,
'version': version, 'loss_scale': 128 if dtype == tf.float16 else 1,
'loss_scale': 128 if dtype == tf.float16 else 1, })
'multi_gpu': multi_gpu
}) predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
predictions = spec.predictions (_BATCH_SIZE, 10))
self.assertAllEqual(predictions['probabilities'].shape, self.assertEqual(predictions['probabilities'].dtype, tf.float32)
(_BATCH_SIZE, 10)) self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['probabilities'].dtype, tf.float32) self.assertEqual(predictions['classes'].dtype, tf.int64)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64) if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
if mode != tf.estimator.ModeKeys.PREDICT: self.assertAllEqual(loss.shape, ())
loss = spec.loss self.assertEqual(loss.dtype, tf.float32)
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32) if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
if mode == tf.estimator.ModeKeys.EVAL: self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
eval_metric_ops = spec.eval_metric_ops self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
for v in tf.trainable_variables():
self.assertEqual(v.dtype.base_dtype, tf.float32)
tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'final_reduce_mean:0',
'final_dense:0')
for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))
def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def test_cifar10_model_fn_train_mode_v1(self): def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True) dtype=tf.float32)
def test_cifar10_model_fn_train_mode_multi_gpu_v2(self): def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True) dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v1(self): def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v2(self): def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v1(self): def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v2(self): def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)
def _test_cifar10model_shape(self, version): def _test_cifar10model_shape(self, version):
batch_size = 135 batch_size = 135
......
...@@ -156,8 +156,7 @@ def parse_record(raw_record, is_training): ...@@ -156,8 +156,7 @@ def parse_record(raw_record, is_training):
return image, label return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1):
num_parallel_calls=1, multi_gpu=False):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -165,12 +164,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -165,12 +164,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -182,15 +175,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -182,15 +175,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Shuffle the input files # Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES) dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
# Convert to individual records # Convert to individual records
dataset = dataset.flat_map(tf.data.TFRecordDataset) dataset = dataset.flat_map(tf.data.TFRecordDataset)
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record, dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
num_epochs, num_parallel_calls, examples_per_epoch=num_images, num_epochs
multi_gpu=multi_gpu) )
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -300,7 +291,6 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -300,7 +291,6 @@ def imagenet_model_fn(features, labels, mode, params):
version=params['version'], version=params['version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=None, loss_filter_fn=None,
multi_gpu=params['multi_gpu'],
dtype=params['dtype'] dtype=params['dtype']
) )
......
...@@ -185,88 +185,66 @@ class BaseTest(tf.test.TestCase): ...@@ -185,88 +185,66 @@ class BaseTest(tf.test.TestCase):
def test_tensor_shapes_resnet_200_with_gpu_v2(self): def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True) self.tensor_shapes_helper(200, version=2, with_gpu=True)
def _resnet_model_fn_helper(self, mode, version, dtype, multi_gpu): def resnet_model_fn_helper(self, mode, version, dtype):
"""Tests that the EstimatorSpec is given the appropriate arguments.""" """Tests that the EstimatorSpec is given the appropriate arguments."""
with tf.Graph().as_default() as g: tf.train.create_global_step()
tf.train.create_global_step()
input_fn = imagenet_main.get_synth_input_fn()
input_fn = imagenet_main.get_synth_input_fn() dataset = input_fn(True, '', _BATCH_SIZE)
dataset = input_fn(True, '', _BATCH_SIZE) iterator = dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator() features, labels = iterator.get_next()
features, labels = iterator.get_next() spec = imagenet_main.imagenet_model_fn(
spec = imagenet_main.imagenet_model_fn( features, labels, mode, {
features, labels, mode, { 'dtype': dtype,
'dtype': dtype, 'resnet_size': 50,
'resnet_size': 50, 'data_format': 'channels_last',
'data_format': 'channels_last', 'batch_size': _BATCH_SIZE,
'batch_size': _BATCH_SIZE, 'version': version,
'version': version, 'loss_scale': 128 if dtype == tf.float16 else 1,
'loss_scale': 128 if dtype == tf.float16 else 1, })
'multi_gpu': multi_gpu,
}) predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape,
predictions = spec.predictions (_BATCH_SIZE, _LABEL_CLASSES))
self.assertAllEqual(predictions['probabilities'].shape, self.assertEqual(predictions['probabilities'].dtype, tf.float32)
(_BATCH_SIZE, _LABEL_CLASSES)) self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['probabilities'].dtype, tf.float32) self.assertEqual(predictions['classes'].dtype, tf.int64)
self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
self.assertEqual(predictions['classes'].dtype, tf.int64) if mode != tf.estimator.ModeKeys.PREDICT:
loss = spec.loss
if mode != tf.estimator.ModeKeys.PREDICT: self.assertAllEqual(loss.shape, ())
loss = spec.loss self.assertEqual(loss.dtype, tf.float32)
self.assertAllEqual(loss.shape, ())
self.assertEqual(loss.dtype, tf.float32) if mode == tf.estimator.ModeKeys.EVAL:
eval_metric_ops = spec.eval_metric_ops
if mode == tf.estimator.ModeKeys.EVAL: self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
eval_metric_ops = spec.eval_metric_ops self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ()) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
tensors_to_check = ('initial_conv:0', 'initial_max_pool:0',
'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'block_layer4:0',
'final_reduce_mean:0', 'final_dense:0')
for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))
def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def test_resnet_model_fn_train_mode_v1(self): def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_resnet_model_fn_train_mode_multi_gpu_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True) dtype=tf.float32)
def test_resnet_model_fn_train_mode_multi_gpu_v2(self): def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True) dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v1(self): def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v2(self): def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v1(self): def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v2(self): def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)
def _test_imagenetmodel_shape(self, version): def _test_imagenetmodel_shape(self, version):
batch_size = 135 batch_size = 135
......
...@@ -36,15 +36,11 @@ from official.utils.logs import logger ...@@ -36,15 +36,11 @@ from official.utils.logs import logger
from official.utils.misc import model_helpers from official.utils.misc import model_helpers
FLAGS = flags.FLAGS
################################################################################ ################################################################################
# Functions for input processing. # Functions for input processing.
################################################################################ ################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn, num_epochs=1, num_parallel_calls=1, parse_record_fn, num_epochs=1):
examples_per_epoch=0, multi_gpu=False):
"""Given a Dataset with raw records, return an iterator over the records. """Given a Dataset with raw records, return an iterator over the records.
Args: Args:
...@@ -57,19 +53,11 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -57,19 +53,11 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn: A function that takes a raw record and returns the parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair. corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
examples_per_epoch: The number of examples in the current set that
are processed each epoch. Note that this is only used for multi-GPU mode,
and only to handle what will eventually be handled inside of Estimator.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers (see below), and can be removed
when that is handled directly by Estimator.
Returns: Returns:
Dataset of (image, label) pairs ready for iteration. Dataset of (image, label) pairs ready for iteration.
""" """
# We prefetch a batch at a time, This can help smooth out the time taken to # We prefetch a batch at a time, This can help smooth out the time taken to
# load input files as we go through shuffling and processing. # load input files as we go through shuffling and processing.
dataset = dataset.prefetch(buffer_size=batch_size) dataset = dataset.prefetch(buffer_size=batch_size)
...@@ -82,29 +70,22 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -82,29 +70,22 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# dataset for the appropriate number of epochs. # dataset for the appropriate number of epochs.
dataset = dataset.repeat(num_epochs) dataset = dataset.repeat(num_epochs)
# Currently, if we are using multiple GPUs, we can't pass in uneven batches. # Parse the raw records into images and labels. Testing has shown that setting
# (For example, if we have 4 GPUs, the number of examples in each batch # num_parallel_batches > 1 produces no improvement in throughput, since
# must be divisible by 4.) We already ensured this for the batch_size, but # batch_size is almost always much greater than the number of CPU cores.
# we have to additionally ensure that any "leftover" examples-- the remainder dataset = dataset.apply(
# examples (total examples % batch_size) that get called a batch for the very tf.contrib.data.map_and_batch(
# last batch of an epoch-- do not raise an error when we try to split them lambda value: parse_record_fn(value, is_training),
# over the GPUs. This will likely be handled by Estimator during replication batch_size=batch_size,
# in the future, but for now, we just drop the leftovers here. num_parallel_batches=1))
if multi_gpu:
total_examples = num_epochs * examples_per_epoch
dataset = dataset.take(batch_size * (total_examples // batch_size))
# Parse the raw records into images and labels
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
num_parallel_calls=num_parallel_calls)
dataset = dataset.batch(batch_size)
# Operations between the final prefetch and the get_next call to the iterator # Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to # will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the # background all of the above processing work and keep it out of the
# critical training path. # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
dataset = dataset.prefetch(1) # allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present.
dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return dataset return dataset
...@@ -126,7 +107,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes): ...@@ -126,7 +107,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
An input_fn that can be used in place of a real one to return a dataset An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration. that can be used for iteration.
""" """
def input_fn(is_training, data_dir, batch_size, *args): # pylint: disable=unused-argument def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument
images = tf.zeros((batch_size, height, width, num_channels), tf.float32) images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
labels = tf.zeros((batch_size, num_classes), tf.int32) labels = tf.zeros((batch_size, num_classes), tf.int32)
return tf.data.Dataset.from_tensors((images, labels)).repeat() return tf.data.Dataset.from_tensors((images, labels)).repeat()
...@@ -174,8 +155,7 @@ def learning_rate_with_decay( ...@@ -174,8 +155,7 @@ def learning_rate_with_decay(
def resnet_model_fn(features, labels, mode, model_class, def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum, resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, version, loss_scale, data_format, version, loss_scale, loss_filter_fn=None,
loss_filter_fn=None, multi_gpu=False,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""Shared functionality for different resnet model_fns. """Shared functionality for different resnet model_fns.
...@@ -208,8 +188,6 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -208,8 +188,6 @@ def resnet_model_fn(features, labels, mode, model_class,
True if the var should be included in loss calculation, and False True if the var should be included in loss calculation, and False
otherwise. If None, batch_normalization variables will be excluded otherwise. If None, batch_normalization variables will be excluded
from the loss. from the loss.
multi_gpu: If True, wrap the optimizer in a TowerOptimizer suitable for
data-parallel distribution across multiple GPUs.
dtype: the TensorFlow dtype to use for calculations. dtype: the TensorFlow dtype to use for calculations.
Returns: Returns:
...@@ -278,11 +256,8 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -278,11 +256,8 @@ def resnet_model_fn(features, labels, mode, model_class,
optimizer = tf.train.MomentumOptimizer( optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=momentum) momentum=momentum
)
# If we are running multi-GPU, we need to wrap the optimizer.
if multi_gpu:
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
if loss_scale != 1: if loss_scale != 1:
# When computing fp16 gradients, often intermediate tensor values are # When computing fp16 gradients, often intermediate tensor values are
...@@ -303,8 +278,14 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -303,8 +278,14 @@ def resnet_model_fn(features, labels, mode, model_class,
else: else:
train_op = None train_op = None
accuracy = tf.metrics.accuracy( if not tf.contrib.distribute.has_distribution_strategy():
tf.argmax(labels, axis=1), predictions['classes']) accuracy = tf.metrics.accuracy(
tf.argmax(labels, axis=1), predictions['classes'])
else:
# Metrics are currently not compatible with distribution strategies during
# training. This does not affect the overall performance of the model.
accuracy = (tf.no_op(), tf.constant(0))
metrics = {'accuracy': accuracy} metrics = {'accuracy': accuracy}
# Create a tensor named train_accuracy for logging purposes # Create a tensor named train_accuracy for logging purposes
...@@ -319,34 +300,35 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -319,34 +300,35 @@ def resnet_model_fn(features, labels, mode, model_class,
eval_metric_ops=metrics) eval_metric_ops=metrics)
def validate_batch_size_for_multi_gpu(batch_size): def per_device_batch_size(batch_size, num_gpus):
"""For multi-gpu, batch-size must be a multiple of the number of GPUs. """For multi-gpu, batch-size must be a multiple of the number of GPUs.
Note that this should eventually be handled by replicate_model_fn Note that this should eventually be handled by DistributionStrategies
directly. Multi-GPU support is currently experimental, however, directly. Multi-GPU support is currently experimental, however,
so doing the work here until that feature is in place. so doing the work here until that feature is in place.
Args: Args:
batch_size: the number of examples processed in each training batch. batch_size: Global batch size to be divided among devices. This should be
equal to num_gpus times the single-GPU batch_size for multi-gpu training.
num_gpus: How many GPUs are used with DistributionStrategies.
Returns:
Batch size per device.
Raises: Raises:
ValueError: if no GPUs are found, or selected batch_size is invalid. ValueError: if batch_size is not divisible by number of devices
""" """
from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top if num_gpus <= 1:
return batch_size
local_device_protos = device_lib.list_local_devices()
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
if not num_gpus:
raise ValueError('Multi-GPU mode was specified, but no GPUs '
'were found. To use CPU, run without --multi_gpu.')
remainder = batch_size % num_gpus remainder = batch_size % num_gpus
if remainder: if remainder:
err = ('When running with multiple GPUs, batch size ' err = ('When running with multiple GPUs, batch size '
'must be a multiple of the number of available GPUs. ' 'must be a multiple of the number of available GPUs. Found {} '
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.' 'GPUs with a batch size of {}; try --batch_size={} instead.'
).format(num_gpus, batch_size, batch_size - remainder) ).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err) raise ValueError(err)
return int(batch_size / num_gpus)
def resnet_main(flags_obj, model_function, input_function, shape=None): def resnet_main(flags_obj, model_function, input_function, shape=None):
...@@ -361,22 +343,12 @@ def resnet_main(flags_obj, model_function, input_function, shape=None): ...@@ -361,22 +343,12 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
dataset that the estimator can train on. This will be wrapped with dataset that the estimator can train on. This will be wrapped with
all the relevant flags for running and passed to estimator. all the relevant flags for running and passed to estimator.
shape: list of ints representing the shape of the images used for training. shape: list of ints representing the shape of the images used for training.
This is only used if flags.export_dir is passed. This is only used if flags_obj.export_dir is passed.
""" """
# Using the Winograd non-fused algorithms provides a small performance boost. # Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
if flags_obj.multi_gpu:
validate_batch_size_for_multi_gpu(flags_obj.batch_size)
# There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens
# in the model_fn itself when the optimizer is defined.
model_function = tf.contrib.estimator.replicate_model_fn(
model_function,
loss_reduction=tf.losses.Reduction.MEAN)
# Create session config based on values of inter_op_parallelism_threads and # Create session config based on values of inter_op_parallelism_threads and
# intra_op_parallelism_threads. Note that we default to having # intra_op_parallelism_threads. Note that we default to having
# allow_soft_placement = True, which is required for multi-GPU and not # allow_soft_placement = True, which is required for multi-GPU and not
...@@ -386,16 +358,24 @@ def resnet_main(flags_obj, model_function, input_function, shape=None): ...@@ -386,16 +358,24 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads, intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
allow_soft_placement=True) allow_soft_placement=True)
# Set up a RunConfig to save checkpoint and set session config. if flags_core.get_num_gpus(flags_obj) == 0:
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9, distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
session_config=session_config) elif flags_core.get_num_gpus(flags_obj) == 1:
distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
else:
distribution = tf.contrib.distribute.MirroredStrategy(
num_gpus=flags_core.get_num_gpus(flags_obj)
)
run_config = tf.estimator.RunConfig(train_distribute=distribution,
session_config=session_config)
classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config, model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
params={ params={
'resnet_size': int(flags_obj.resnet_size), 'resnet_size': int(flags_obj.resnet_size),
'data_format': flags_obj.data_format, 'data_format': flags_obj.data_format,
'batch_size': flags_obj.batch_size, 'batch_size': flags_obj.batch_size,
'multi_gpu': flags_obj.multi_gpu,
'version': int(flags_obj.version), 'version': int(flags_obj.version),
'loss_scale': flags_core.get_loss_scale(flags_obj), 'loss_scale': flags_core.get_loss_scale(flags_obj),
'dtype': flags_core.get_tf_dtype(flags_obj) 'dtype': flags_core.get_tf_dtype(flags_obj)
...@@ -410,13 +390,18 @@ def resnet_main(flags_obj, model_function, input_function, shape=None): ...@@ -410,13 +390,18 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
benchmark_log_dir=flags_obj.benchmark_log_dir) benchmark_log_dir=flags_obj.benchmark_log_dir)
def input_fn_train(): def input_fn_train():
return input_function(True, flags_obj.data_dir, flags_obj.batch_size, return input_function(
flags_obj.epochs_between_evals, is_training=True, data_dir=flags_obj.data_dir,
flags_obj.num_parallel_calls, flags_obj.multi_gpu) batch_size=per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=flags_obj.epochs_between_evals)
def input_fn_eval(): def input_fn_eval():
return input_function(False, flags_obj.data_dir, flags_obj.batch_size, return input_function(
1, flags_obj.num_parallel_calls, flags_obj.multi_gpu) is_training=False, data_dir=flags_obj.data_dir,
batch_size=per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
num_epochs=1)
total_training_cycle = (flags_obj.train_epochs // total_training_cycle = (flags_obj.train_epochs //
flags_obj.epochs_between_evals) flags_obj.epochs_between_evals)
...@@ -428,10 +413,11 @@ def resnet_main(flags_obj, model_function, input_function, shape=None): ...@@ -428,10 +413,11 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
max_steps=flags_obj.max_train_steps) max_steps=flags_obj.max_train_steps)
tf.logging.info('Starting to evaluate.') tf.logging.info('Starting to evaluate.')
# flags.max_train_steps is generally associated with testing and profiling.
# As a result it is frequently called with synthetic data, which will # flags_obj.max_train_steps is generally associated with testing and
# iterate forever. Passing steps=flags.max_train_steps allows the eval # profiling. As a result it is frequently called with synthetic data, which
# (which is generally unimportant in those circumstances) to terminate. # will iterate forever. Passing steps=flags_obj.max_train_steps allows the
# eval (which is generally unimportant in those circumstances) to terminate.
# Note that eval will run for max_train_steps each loop, regardless of the # Note that eval will run for max_train_steps each loop, regardless of the
# global_step count. # global_step count.
eval_results = classifier.evaluate(input_fn=input_fn_eval, eval_results = classifier.evaluate(input_fn=input_fn_eval,
...@@ -444,28 +430,16 @@ def resnet_main(flags_obj, model_function, input_function, shape=None): ...@@ -444,28 +430,16 @@ def resnet_main(flags_obj, model_function, input_function, shape=None):
break break
if flags_obj.export_dir is not None: if flags_obj.export_dir is not None:
warn_on_multi_gpu_export(flags_obj.multi_gpu)
# Exports a saved model for the given classifier. # Exports a saved model for the given classifier.
input_receiver_fn = export.build_tensor_serving_input_receiver_fn( input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
shape, batch_size=flags_obj.batch_size) shape, batch_size=flags_obj.batch_size)
classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn) classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
def warn_on_multi_gpu_export(multi_gpu=False):
"""For the time being, multi-GPU mode does not play nicely with exporting."""
if multi_gpu:
tf.logging.warning(
'You are exporting a SavedModel while in multi-GPU mode. Note that '
'the resulting SavedModel will require the same GPUs be available.'
'If you wish to serve the SavedModel from a different device, '
'try exporting the SavedModel with multi-GPU mode turned off.')
def define_resnet_flags(resnet_size_choices=None): def define_resnet_flags(resnet_size_choices=None):
"""Add flags and validators for ResNet.""" """Add flags and validators for ResNet."""
flags_core.define_base() flags_core.define_base()
flags_core.define_performance() flags_core.define_performance(num_parallel_calls=False)
flags_core.define_image() flags_core.define_image()
flags_core.define_benchmark() flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core) flags.adopt_module_key_flags(flags_core)
......
...@@ -19,6 +19,7 @@ from __future__ import division ...@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl import flags from absl import flags
import tensorflow as tf
from official.utils.flags._conventions import help_wrap from official.utils.flags._conventions import help_wrap
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
...@@ -26,7 +27,7 @@ from official.utils.logs import hooks_helper ...@@ -26,7 +27,7 @@ from official.utils.logs import hooks_helper
def define_base(data_dir=True, model_dir=True, train_epochs=True, def define_base(data_dir=True, model_dir=True, train_epochs=True,
epochs_between_evals=True, stop_threshold=True, batch_size=True, epochs_between_evals=True, stop_threshold=True, batch_size=True,
multi_gpu=True, hooks=True, export_dir=True): multi_gpu=False, num_gpu=True, hooks=True, export_dir=True):
"""Register base flags. """Register base flags.
Args: Args:
...@@ -38,6 +39,7 @@ def define_base(data_dir=True, model_dir=True, train_epochs=True, ...@@ -38,6 +39,7 @@ def define_base(data_dir=True, model_dir=True, train_epochs=True,
eval metric which should trigger the end of training. eval metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size. batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs. multi_gpu: Create a flag to allow the use of all available GPUs.
num_gpu: Create a flag to specify the number of GPUs used.
hooks: Create a flag to specify hooks for logging. hooks: Create a flag to specify hooks for logging.
export_dir: Create a flag to specify where a SavedModel should be exported. export_dir: Create a flag to specify where a SavedModel should be exported.
...@@ -85,12 +87,22 @@ def define_base(data_dir=True, model_dir=True, train_epochs=True, ...@@ -85,12 +87,22 @@ def define_base(data_dir=True, model_dir=True, train_epochs=True,
help=help_wrap("Batch size for training and evaluation.")) help=help_wrap("Batch size for training and evaluation."))
key_flags.append("batch_size") key_flags.append("batch_size")
assert not (multi_gpu and num_gpu)
if multi_gpu: if multi_gpu:
flags.DEFINE_bool( flags.DEFINE_bool(
name="multi_gpu", default=False, name="multi_gpu", default=False,
help=help_wrap("If set, run across all available GPUs.")) help=help_wrap("If set, run across all available GPUs."))
key_flags.append("multi_gpu") key_flags.append("multi_gpu")
if num_gpu:
flags.DEFINE_integer(
name="num_gpus", short_name="ng",
default=1 if tf.test.is_gpu_available() else 0,
help=help_wrap(
"How many GPUs to use with the DistributionStrategies API. The "
"default is 1 if TensorFlow can detect a GPU, and 0 otherwise."))
if hooks: if hooks:
# Construct a pretty summary of hooks. # Construct a pretty summary of hooks.
hook_list_str = ( hook_list_str = (
...@@ -116,3 +128,13 @@ def define_base(data_dir=True, model_dir=True, train_epochs=True, ...@@ -116,3 +128,13 @@ def define_base(data_dir=True, model_dir=True, train_epochs=True,
key_flags.append("export_dir") key_flags.append("export_dir")
return key_flags return key_flags
def get_num_gpus(flags_obj):
"""Treat num_gpus=-1 as 'use all'."""
if flags_obj.num_gpus != -1:
return flags_obj.num_gpus
from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top
local_device_protos = device_lib.list_local_devices()
return sum([1 for d in local_device_protos if d.device_type == "GPU"])
...@@ -79,5 +79,6 @@ define_performance = register_key_flags_in_core(_performance.define_performance) ...@@ -79,5 +79,6 @@ define_performance = register_key_flags_in_core(_performance.define_performance)
help_wrap = _conventions.help_wrap help_wrap = _conventions.help_wrap
get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale get_loss_scale = _performance.get_loss_scale
...@@ -22,7 +22,7 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp ...@@ -22,7 +22,7 @@ from official.utils.flags import core as flags_core # pylint: disable=g-bad-imp
def define_flags(): def define_flags():
flags_core.define_base() flags_core.define_base(multi_gpu=True, num_gpu=False)
flags_core.define_performance() flags_core.define_performance()
flags_core.define_image() flags_core.define_image()
flags_core.define_benchmark() flags_core.define_benchmark()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment