Unverified Commit 823da318 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Revert "Resnet distribution strategies (#3887)" (#4033)

This reverts commit 32aa6563.
parent 07a7584e
...@@ -244,7 +244,7 @@ class MNISTArgParser(argparse.ArgumentParser): ...@@ -244,7 +244,7 @@ class MNISTArgParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(MNISTArgParser, self).__init__(parents=[ super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(multi_gpu=True, num_gpu=False), parsers.BaseParser(),
parsers.ImageModelParser(), parsers.ImageModelParser(),
parsers.ExportParser(), parsers.ExportParser(),
]) ])
......
...@@ -59,13 +59,3 @@ Other versions and formats: ...@@ -59,13 +59,3 @@ Other versions and formats:
* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz) * [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz) * [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz) * [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)
## Compute Devices
Training is accomplished using the DistributionStrategies API. (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md)
The appropriate distribution strategy is chosen based on the `--num_gpus` flag. By default this flag is one if TensorFlow is compiled with CUDA, and zero otherwise.
num_gpus:
+ 0: Use OneDeviceStrategy and train on CPU.
+ 1: Use OneDeviceStrategy and train on GPU.
+ 2+: Use MirroredStrategy (data parallelism) to distribute a batch between devices.
...@@ -103,7 +103,8 @@ def preprocess_image(image, is_training): ...@@ -103,7 +103,8 @@ def preprocess_image(image, is_training):
return image return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1): def input_fn(is_training, data_dir, batch_size, num_epochs=1,
num_parallel_calls=1, multi_gpu=False):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset. """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args: Args:
...@@ -111,6 +112,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1): ...@@ -111,6 +112,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -118,10 +125,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1): ...@@ -118,10 +125,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
filenames = get_filenames(is_training, data_dir) filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES) dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _NUM_IMAGES['train'], dataset, is_training, batch_size, _NUM_IMAGES['train'],
parse_record, num_epochs, parse_record, num_epochs, num_parallel_calls,
) examples_per_epoch=num_images, multi_gpu=multi_gpu)
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -212,6 +221,7 @@ def cifar10_model_fn(features, labels, mode, params): ...@@ -212,6 +221,7 @@ def cifar10_model_fn(features, labels, mode, params):
version=params['version'], version=params['version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn, loss_filter_fn=loss_filter_fn,
multi_gpu=params['multi_gpu'],
dtype=params['dtype'] dtype=params['dtype']
) )
......
...@@ -71,7 +71,8 @@ class BaseTest(tf.test.TestCase): ...@@ -71,7 +71,8 @@ class BaseTest(tf.test.TestCase):
for pixel in row: for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3) self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def cifar10_model_fn_helper(self, mode, version, dtype): def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False):
with tf.Graph().as_default() as g:
input_fn = cifar10_main.get_synth_input_fn() input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE) dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator() iterator = dataset.make_one_shot_iterator()
...@@ -84,6 +85,7 @@ class BaseTest(tf.test.TestCase): ...@@ -84,6 +85,7 @@ class BaseTest(tf.test.TestCase):
'batch_size': _BATCH_SIZE, 'batch_size': _BATCH_SIZE,
'version': version, 'version': version,
'loss_scale': 128 if dtype == tf.float16 else 1, 'loss_scale': 128 if dtype == tf.float16 else 1,
'multi_gpu': multi_gpu
}) })
predictions = spec.predictions predictions = spec.predictions
...@@ -105,29 +107,51 @@ class BaseTest(tf.test.TestCase): ...@@ -105,29 +107,51 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
for v in tf.trainable_variables():
self.assertEqual(v.dtype.base_dtype, tf.float32)
tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'final_reduce_mean:0',
'final_dense:0')
for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))
def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def test_cifar10_model_fn_train_mode_v1(self): def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
dtype=tf.float32)
def test_cifar10_model_fn_trainmode__v2(self): def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True)
def test_cifar10_model_fn_train_mode_multi_gpu_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
dtype=tf.float32) multi_gpu=True)
def test_cifar10_model_fn_eval_mode_v1(self): def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1)
dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v2(self): def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2)
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v1(self): def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1)
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v2(self): def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
dtype=tf.float32)
def _test_cifar10model_shape(self, version): def _test_cifar10model_shape(self, version):
batch_size = 135 batch_size = 135
......
...@@ -154,7 +154,8 @@ def parse_record(raw_record, is_training): ...@@ -154,7 +154,8 @@ def parse_record(raw_record, is_training):
return image, label return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1): def input_fn(is_training, data_dir, batch_size, num_epochs=1,
num_parallel_calls=1, multi_gpu=False):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -162,6 +163,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1): ...@@ -162,6 +163,12 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -173,13 +180,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1): ...@@ -173,13 +180,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
# Shuffle the input files # Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES) dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
# Convert to individual records # Convert to individual records
dataset = dataset.flat_map(tf.data.TFRecordDataset) dataset = dataset.flat_map(tf.data.TFRecordDataset)
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record, dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
num_epochs num_epochs, num_parallel_calls, examples_per_epoch=num_images,
) multi_gpu=multi_gpu)
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -291,6 +300,7 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -291,6 +300,7 @@ def imagenet_model_fn(features, labels, mode, params):
version=params['version'], version=params['version'],
loss_scale=params['loss_scale'], loss_scale=params['loss_scale'],
loss_filter_fn=None, loss_filter_fn=None,
multi_gpu=params['multi_gpu'],
dtype=params['dtype'] dtype=params['dtype']
) )
......
...@@ -180,8 +180,9 @@ class BaseTest(tf.test.TestCase): ...@@ -180,8 +180,9 @@ class BaseTest(tf.test.TestCase):
def test_tensor_shapes_resnet_200_with_gpu_v2(self): def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True) self.tensor_shapes_helper(200, version=2, with_gpu=True)
def resnet_model_fn_helper(self, mode, version, dtype): def _resnet_model_fn_helper(self, mode, version, dtype, multi_gpu):
"""Tests that the EstimatorSpec is given the appropriate arguments.""" """Tests that the EstimatorSpec is given the appropriate arguments."""
with tf.Graph().as_default() as g:
tf.train.create_global_step() tf.train.create_global_step()
input_fn = imagenet_main.get_synth_input_fn() input_fn = imagenet_main.get_synth_input_fn()
...@@ -196,6 +197,7 @@ class BaseTest(tf.test.TestCase): ...@@ -196,6 +197,7 @@ class BaseTest(tf.test.TestCase):
'batch_size': _BATCH_SIZE, 'batch_size': _BATCH_SIZE,
'version': version, 'version': version,
'loss_scale': 128 if dtype == tf.float16 else 1, 'loss_scale': 128 if dtype == tf.float16 else 1,
'multi_gpu': multi_gpu,
}) })
predictions = spec.predictions predictions = spec.predictions
...@@ -217,29 +219,49 @@ class BaseTest(tf.test.TestCase): ...@@ -217,29 +219,49 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
tensors_to_check = ('initial_conv:0', 'initial_max_pool:0',
'block_layer1:0', 'block_layer2:0',
'block_layer3:0', 'block_layer4:0',
'final_reduce_mean:0', 'final_dense:0')
for tensor_name in tensors_to_check:
tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
self.assertEqual(tensor.dtype, dtype,
'Tensor {} has dtype {}, while dtype {} was '
'expected'.format(tensor, tensor.dtype,
dtype))
def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
multi_gpu=multi_gpu)
self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
multi_gpu=multi_gpu)
def test_resnet_model_fn_train_mode_v1(self): def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
dtype=tf.float32)
def test_resnet_model_fn_train_mode_v2(self): def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_resnet_model_fn_train_mode_multi_gpu_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True)
def test_resnet_model_fn_train_mode_multi_gpu_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
dtype=tf.float32) multi_gpu=True)
def test_resnet_model_fn_eval_mode_v1(self): def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1)
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v2(self): def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2)
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v1(self): def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1)
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v2(self): def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
dtype=tf.float32)
def _test_imagenetmodel_shape(self, version): def _test_imagenetmodel_shape(self, version):
batch_size = 135 batch_size = 135
......
...@@ -40,7 +40,8 @@ from official.utils.misc import model_helpers ...@@ -40,7 +40,8 @@ from official.utils.misc import model_helpers
# Functions for input processing. # Functions for input processing.
################################################################################ ################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn, num_epochs=1): parse_record_fn, num_epochs=1, num_parallel_calls=1,
examples_per_epoch=0, multi_gpu=False):
"""Given a Dataset with raw records, return an iterator over the records. """Given a Dataset with raw records, return an iterator over the records.
Args: Args:
...@@ -53,11 +54,19 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -53,11 +54,19 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn: A function that takes a raw record and returns the parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair. corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
examples_per_epoch: The number of examples in the current set that
are processed each epoch. Note that this is only used for multi-GPU mode,
and only to handle what will eventually be handled inside of Estimator.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers (see below), and can be removed
when that is handled directly by Estimator.
Returns: Returns:
Dataset of (image, label) pairs ready for iteration. Dataset of (image, label) pairs ready for iteration.
""" """
# We prefetch a batch at a time, This can help smooth out the time taken to # We prefetch a batch at a time, This can help smooth out the time taken to
# load input files as we go through shuffling and processing. # load input files as we go through shuffling and processing.
dataset = dataset.prefetch(buffer_size=batch_size) dataset = dataset.prefetch(buffer_size=batch_size)
...@@ -70,22 +79,29 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -70,22 +79,29 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# dataset for the appropriate number of epochs. # dataset for the appropriate number of epochs.
dataset = dataset.repeat(num_epochs) dataset = dataset.repeat(num_epochs)
# Parse the raw records into images and labels. Testing has shown that setting # Currently, if we are using multiple GPUs, we can't pass in uneven batches.
# num_parallel_batches > 1 produces no improvement in throughput, since # (For example, if we have 4 GPUs, the number of examples in each batch
# batch_size is almost always much greater than the number of CPU cores. # must be divisible by 4.) We already ensured this for the batch_size, but
dataset = dataset.apply( # we have to additionally ensure that any "leftover" examples-- the remainder
tf.contrib.data.map_and_batch( # examples (total examples % batch_size) that get called a batch for the very
lambda value: parse_record_fn(value, is_training), # last batch of an epoch-- do not raise an error when we try to split them
batch_size=batch_size, # over the GPUs. This will likely be handled by Estimator during replication
num_parallel_batches=1)) # in the future, but for now, we just drop the leftovers here.
if multi_gpu:
total_examples = num_epochs * examples_per_epoch
dataset = dataset.take(batch_size * (total_examples // batch_size))
# Parse the raw records into images and labels
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
num_parallel_calls=num_parallel_calls)
dataset = dataset.batch(batch_size)
# Operations between the final prefetch and the get_next call to the iterator # Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to # will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the # background all of the above processing work and keep it out of the
# critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE # critical training path.
# allows DistributionStrategies to adjust how many batches to fetch based dataset = dataset.prefetch(1)
# on how many devices are present.
dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return dataset return dataset
...@@ -107,7 +123,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes): ...@@ -107,7 +123,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
An input_fn that can be used in place of a real one to return a dataset An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration. that can be used for iteration.
""" """
def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument def input_fn(is_training, data_dir, batch_size, *args): # pylint: disable=unused-argument
images = tf.zeros((batch_size, height, width, num_channels), tf.float32) images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
labels = tf.zeros((batch_size, num_classes), tf.int32) labels = tf.zeros((batch_size, num_classes), tf.int32)
return tf.data.Dataset.from_tensors((images, labels)).repeat() return tf.data.Dataset.from_tensors((images, labels)).repeat()
...@@ -155,7 +171,8 @@ def learning_rate_with_decay( ...@@ -155,7 +171,8 @@ def learning_rate_with_decay(
def resnet_model_fn(features, labels, mode, model_class, def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum, resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, version, loss_scale, loss_filter_fn=None, data_format, version, loss_scale,
loss_filter_fn=None, multi_gpu=False,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""Shared functionality for different resnet model_fns. """Shared functionality for different resnet model_fns.
...@@ -188,6 +205,8 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -188,6 +205,8 @@ def resnet_model_fn(features, labels, mode, model_class,
True if the var should be included in loss calculation, and False True if the var should be included in loss calculation, and False
otherwise. If None, batch_normalization variables will be excluded otherwise. If None, batch_normalization variables will be excluded
from the loss. from the loss.
multi_gpu: If True, wrap the optimizer in a TowerOptimizer suitable for
data-parallel distribution across multiple GPUs.
dtype: the TensorFlow dtype to use for calculations. dtype: the TensorFlow dtype to use for calculations.
Returns: Returns:
...@@ -256,8 +275,11 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -256,8 +275,11 @@ def resnet_model_fn(features, labels, mode, model_class,
optimizer = tf.train.MomentumOptimizer( optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=momentum momentum=momentum)
)
# If we are running multi-GPU, we need to wrap the optimizer.
if multi_gpu:
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
if loss_scale != 1: if loss_scale != 1:
# When computing fp16 gradients, often intermediate tensor values are # When computing fp16 gradients, often intermediate tensor values are
...@@ -278,14 +300,8 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -278,14 +300,8 @@ def resnet_model_fn(features, labels, mode, model_class,
else: else:
train_op = None train_op = None
if not tf.contrib.distribute.has_distribution_strategy():
accuracy = tf.metrics.accuracy( accuracy = tf.metrics.accuracy(
tf.argmax(labels, axis=1), predictions['classes']) tf.argmax(labels, axis=1), predictions['classes'])
else:
# Metrics are currently not compatible with distribution strategies during
# training. This does not affect the overall performance of the model.
accuracy = (tf.no_op(), tf.constant(0))
metrics = {'accuracy': accuracy} metrics = {'accuracy': accuracy}
# Create a tensor named train_accuracy for logging purposes # Create a tensor named train_accuracy for logging purposes
...@@ -300,35 +316,34 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -300,35 +316,34 @@ def resnet_model_fn(features, labels, mode, model_class,
eval_metric_ops=metrics) eval_metric_ops=metrics)
def per_device_batch_size(batch_size, num_gpus): def validate_batch_size_for_multi_gpu(batch_size):
"""For multi-gpu, batch-size must be a multiple of the number of GPUs. """For multi-gpu, batch-size must be a multiple of the number of GPUs.
Note that this should eventually be handled by DistributionStrategies Note that this should eventually be handled by replicate_model_fn
directly. Multi-GPU support is currently experimental, however, directly. Multi-GPU support is currently experimental, however,
so doing the work here until that feature is in place. so doing the work here until that feature is in place.
Args: Args:
batch_size: Global batch size to be divided among devices. This should be batch_size: the number of examples processed in each training batch.
equal to num_gpus times the single-GPU batch_size for multi-gpu training.
num_gpus: How many GPUs are used with DistributionStrategies.
Returns:
Batch size per device.
Raises: Raises:
ValueError: if batch_size is not divisible by number of devices ValueError: if no GPUs are found, or selected batch_size is invalid.
""" """
if num_gpus <= 1: from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top
return batch_size
local_device_protos = device_lib.list_local_devices()
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
if not num_gpus:
raise ValueError('Multi-GPU mode was specified, but no GPUs '
'were found. To use CPU, run without --multi_gpu.')
remainder = batch_size % num_gpus remainder = batch_size % num_gpus
if remainder: if remainder:
err = ('When running with multiple GPUs, batch size ' err = ('When running with multiple GPUs, batch size '
'must be a multiple of the number of available GPUs. Found {} ' 'must be a multiple of the number of available GPUs. '
'GPUs with a batch size of {}; try --batch_size={} instead.' 'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
).format(num_gpus, batch_size, batch_size - remainder) ).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err) raise ValueError(err)
return int(batch_size / num_gpus)
def resnet_main(flags, model_function, input_function, shape=None): def resnet_main(flags, model_function, input_function, shape=None):
...@@ -349,6 +364,16 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -349,6 +364,16 @@ def resnet_main(flags, model_function, input_function, shape=None):
# Using the Winograd non-fused algorithms provides a small performance boost. # Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
if flags.multi_gpu:
validate_batch_size_for_multi_gpu(flags.batch_size)
# There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens
# in the model_fn itself when the optimizer is defined.
model_function = tf.contrib.estimator.replicate_model_fn(
model_function,
loss_reduction=tf.losses.Reduction.MEAN)
# Create session config based on values of inter_op_parallelism_threads and # Create session config based on values of inter_op_parallelism_threads and
# intra_op_parallelism_threads. Note that we default to having # intra_op_parallelism_threads. Note that we default to having
# allow_soft_placement = True, which is required for multi-GPU and not # allow_soft_placement = True, which is required for multi-GPU and not
...@@ -358,24 +383,16 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -358,24 +383,16 @@ def resnet_main(flags, model_function, input_function, shape=None):
intra_op_parallelism_threads=flags.intra_op_parallelism_threads, intra_op_parallelism_threads=flags.intra_op_parallelism_threads,
allow_soft_placement=True) allow_soft_placement=True)
if flags.num_gpus == 0: # Set up a RunConfig to save checkpoint and set session config.
distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0') run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9,
elif flags.num_gpus == 1:
distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
else:
distribution = tf.contrib.distribute.MirroredStrategy(
num_gpus=flags.num_gpus
)
run_config = tf.estimator.RunConfig(train_distribute=distribution,
session_config=session_config) session_config=session_config)
classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags.model_dir, config=run_config, model_fn=model_function, model_dir=flags.model_dir, config=run_config,
params={ params={
'resnet_size': flags.resnet_size, 'resnet_size': flags.resnet_size,
'data_format': flags.data_format, 'data_format': flags.data_format,
'batch_size': flags.batch_size, 'batch_size': flags.batch_size,
'multi_gpu': flags.multi_gpu,
'version': flags.version, 'version': flags.version,
'loss_scale': flags.loss_scale, 'loss_scale': flags.loss_scale,
'dtype': flags.dtype 'dtype': flags.dtype
...@@ -396,12 +413,9 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -396,12 +413,9 @@ def resnet_main(flags, model_function, input_function, shape=None):
print('Starting a training cycle.') print('Starting a training cycle.')
def input_fn_train(): def input_fn_train():
return input_function( return input_function(True, flags.data_dir, flags.batch_size,
is_training=True, flags.epochs_between_evals,
data_dir=flags.data_dir, flags.num_parallel_calls, flags.multi_gpu)
batch_size=per_device_batch_size(flags.batch_size, flags.num_gpus),
num_epochs=flags.epochs_between_evals,
)
classifier.train(input_fn=input_fn_train, hooks=train_hooks, classifier.train(input_fn=input_fn_train, hooks=train_hooks,
max_steps=flags.max_train_steps) max_steps=flags.max_train_steps)
...@@ -409,12 +423,8 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -409,12 +423,8 @@ def resnet_main(flags, model_function, input_function, shape=None):
print('Starting to evaluate.') print('Starting to evaluate.')
# Evaluate the model and print results # Evaluate the model and print results
def input_fn_eval(): def input_fn_eval():
return input_function( return input_function(False, flags.data_dir, flags.batch_size,
is_training=False, 1, flags.num_parallel_calls, flags.multi_gpu)
data_dir=flags.data_dir,
batch_size=per_device_batch_size(flags.batch_size, flags.num_gpus),
num_epochs=1,
)
# flags.max_train_steps is generally associated with testing and profiling. # flags.max_train_steps is generally associated with testing and profiling.
# As a result it is frequently called with synthetic data, which will # As a result it is frequently called with synthetic data, which will
...@@ -434,19 +444,31 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -434,19 +444,31 @@ def resnet_main(flags, model_function, input_function, shape=None):
break break
if flags.export_dir is not None: if flags.export_dir is not None:
warn_on_multi_gpu_export(flags.multi_gpu)
# Exports a saved model for the given classifier. # Exports a saved model for the given classifier.
input_receiver_fn = export.build_tensor_serving_input_receiver_fn( input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
shape, batch_size=flags.batch_size) shape, batch_size=flags.batch_size)
classifier.export_savedmodel(flags.export_dir, input_receiver_fn) classifier.export_savedmodel(flags.export_dir, input_receiver_fn)
def warn_on_multi_gpu_export(multi_gpu=False):
"""For the time being, multi-GPU mode does not play nicely with exporting."""
if multi_gpu:
tf.logging.warning(
'You are exporting a SavedModel while in multi-GPU mode. Note that '
'the resulting SavedModel will require the same GPUs be available.'
'If you wish to serve the SavedModel from a different device, '
'try exporting the SavedModel with multi-GPU mode turned off.')
class ResnetArgParser(argparse.ArgumentParser): class ResnetArgParser(argparse.ArgumentParser):
"""Arguments for configuring and running a Resnet Model.""" """Arguments for configuring and running a Resnet Model."""
def __init__(self, resnet_size_choices=None): def __init__(self, resnet_size_choices=None):
super(ResnetArgParser, self).__init__(parents=[ super(ResnetArgParser, self).__init__(parents=[
parsers.BaseParser(multi_gpu=False), parsers.BaseParser(),
parsers.PerformanceParser(num_parallel_calls=False), parsers.PerformanceParser(),
parsers.ImageModelParser(), parsers.ImageModelParser(),
parsers.ExportParser(), parsers.ExportParser(),
parsers.BenchmarkParser(), parsers.BenchmarkParser(),
......
...@@ -101,16 +101,15 @@ class BaseParser(argparse.ArgumentParser): ...@@ -101,16 +101,15 @@ class BaseParser(argparse.ArgumentParser):
epochs_between_evals: Create a flag to specify the frequency of testing. epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training. eval metric which should trigger the end of training.
batch_size: Create a flag to specify the global batch size. batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs. multi_gpu: Create a flag to allow the use of all available GPUs.
num_gpu: Create a flag to specify the number of GPUs used.
hooks: Create a flag to specify hooks for logging. hooks: Create a flag to specify hooks for logging.
""" """
def __init__(self, add_help=False, data_dir=True, model_dir=True, def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True, train_epochs=True, epochs_between_evals=True,
stop_threshold=True, batch_size=True, stop_threshold=True, batch_size=True, multi_gpu=True,
multi_gpu=False, num_gpu=True, hooks=True): hooks=True):
super(BaseParser, self).__init__(add_help=add_help) super(BaseParser, self).__init__(add_help=add_help)
if data_dir: if data_dir:
...@@ -155,30 +154,16 @@ class BaseParser(argparse.ArgumentParser): ...@@ -155,30 +154,16 @@ class BaseParser(argparse.ArgumentParser):
if batch_size: if batch_size:
self.add_argument( self.add_argument(
"--batch_size", "-bs", type=int, default=32, "--batch_size", "-bs", type=int, default=32,
help="[default: %(default)s] Global batch size for training and " help="[default: %(default)s] Batch size for training and evaluation.",
"evaluation.",
metavar="<BS>" metavar="<BS>"
) )
assert not (multi_gpu and num_gpu)
if multi_gpu: if multi_gpu:
self.add_argument( self.add_argument(
"--multi_gpu", action="store_true", "--multi_gpu", action="store_true",
help="If set, run across all available GPUs." help="If set, run across all available GPUs."
) )
if num_gpu:
self.add_argument(
"--num_gpus", "-ng",
type=int,
default=1 if tf.test.is_built_with_cuda() else 0,
help="[default: %(default)s] How many GPUs to use with the "
"DistributionStrategies API. The default is 1 if TensorFlow was"
"built with CUDA, and 0 otherwise.",
metavar="<NG>"
)
if hooks: if hooks:
self.add_argument( self.add_argument(
"--hooks", "-hk", nargs="+", default=["LoggingTensorHook"], "--hooks", "-hk", nargs="+", default=["LoggingTensorHook"],
......
...@@ -26,7 +26,7 @@ class TestParser(argparse.ArgumentParser): ...@@ -26,7 +26,7 @@ class TestParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(TestParser, self).__init__(parents=[ super(TestParser, self).__init__(parents=[
parsers.BaseParser(multi_gpu=True, num_gpu=False), parsers.BaseParser(),
parsers.PerformanceParser(num_parallel_calls=True, inter_op=True, parsers.PerformanceParser(num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True), intra_op=True, use_synthetic_data=True),
parsers.ImageModelParser(data_format=True), parsers.ImageModelParser(data_format=True),
......
...@@ -221,8 +221,7 @@ class WideDeepArgParser(argparse.ArgumentParser): ...@@ -221,8 +221,7 @@ class WideDeepArgParser(argparse.ArgumentParser):
"""Argument parser for running the wide deep model.""" """Argument parser for running the wide deep model."""
def __init__(self): def __init__(self):
super(WideDeepArgParser, self).__init__(parents=[ super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])
parsers.BaseParser(multi_gpu=False, num_gpu=False)])
self.add_argument( self.add_argument(
'--model_type', '-mt', type=str, default='wide_deep', '--model_type', '-mt', type=str, default='wide_deep',
choices=['wide', 'deep', 'wide_deep'], choices=['wide', 'deep', 'wide_deep'],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment