Commit b668f594 authored by pkulzc's avatar pkulzc
Browse files

Sync to latest master.

parents d5fc3ef0 32aa6563
/official/ @nealwu @k-w-w @karmel * @tensorflow/tf-garden-team
/official/ @tensorflow/tf-garden-team @karmel
/research/adversarial_crypto/ @dave-andersen /research/adversarial_crypto/ @dave-andersen
/research/adversarial_text/ @rsepassi @a-dai /research/adversarial_text/ @rsepassi @a-dai
/research/adv_imagenet_models/ @AlexeyKurakin /research/adv_imagenet_models/ @AlexeyKurakin
...@@ -38,6 +39,7 @@ ...@@ -38,6 +39,7 @@
/research/swivel/ @waterson /research/swivel/ @waterson
/research/syntaxnet/ @calberti @andorardo @bogatyy @markomernick /research/syntaxnet/ @calberti @andorardo @bogatyy @markomernick
/research/tcn/ @coreylynch @sermanet /research/tcn/ @coreylynch @sermanet
/research/tensorrt/ @karmel
/research/textsum/ @panyx0718 @peterjliu /research/textsum/ @panyx0718 @peterjliu
/research/transformer/ @daviddao /research/transformer/ @daviddao
/research/video_prediction/ @cbfinn /research/video_prediction/ @cbfinn
......
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import gzip import gzip
import os import os
import shutil import shutil
import tempfile
import numpy as np import numpy as np
from six.moves import urllib from six.moves import urllib
...@@ -67,10 +68,11 @@ def download(directory, filename): ...@@ -67,10 +68,11 @@ def download(directory, filename):
tf.gfile.MakeDirs(directory) tf.gfile.MakeDirs(directory)
# CVDF mirror of http://yann.lecun.com/exdb/mnist/ # CVDF mirror of http://yann.lecun.com/exdb/mnist/
url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz' url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
zipped_filepath = filepath + '.gz' _, zipped_filepath = tempfile.mkstemp(suffix='.gz')
print('Downloading %s to %s' % (url, zipped_filepath)) print('Downloading %s to %s' % (url, zipped_filepath))
urllib.request.urlretrieve(url, zipped_filepath) urllib.request.urlretrieve(url, zipped_filepath)
with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out: with gzip.open(zipped_filepath, 'rb') as f_in, \
tf.gfile.Open(filepath, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out) shutil.copyfileobj(f_in, f_out)
os.remove(zipped_filepath) os.remove(zipped_filepath)
return filepath return filepath
......
...@@ -25,11 +25,12 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -25,11 +25,12 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.mnist import dataset from official.mnist import dataset
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers
LEARNING_RATE = 1e-4 LEARNING_RATE = 1e-4
class Model(tf.keras.Model): def create_model(data_format):
"""Model to recognize digits in the MNIST dataset. """Model to recognize digits in the MNIST dataset.
Network structure is equivalent to: Network structure is equivalent to:
...@@ -37,60 +38,55 @@ class Model(tf.keras.Model): ...@@ -37,60 +38,55 @@ class Model(tf.keras.Model):
and and
https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
But written as a tf.keras.Model using the tf.layers API. But uses the tf.keras API.
"""
def __init__(self, data_format): Args:
"""Creates a model for classifying a hand-written digit. data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
typically faster on GPUs while 'channels_last' is typically faster on
Args: CPUs. See
data_format: Either 'channels_first' or 'channels_last'. https://www.tensorflow.org/performance/performance_guide#data_formats
'channels_first' is typically faster on GPUs while 'channels_last' is
typically faster on CPUs. See Returns:
https://www.tensorflow.org/performance/performance_guide#data_formats A tf.keras.Model.
""" """
super(Model, self).__init__() if data_format == 'channels_first':
if data_format == 'channels_first': input_shape = [1, 28, 28]
self._input_shape = [-1, 1, 28, 28] else:
else: assert data_format == 'channels_last'
assert data_format == 'channels_last' input_shape = [28, 28, 1]
self._input_shape = [-1, 28, 28, 1]
l = tf.keras.layers
self.conv1 = tf.layers.Conv2D( max_pool = l.MaxPooling2D(
32, 5, padding='same', data_format=data_format, activation=tf.nn.relu) (2, 2), (2, 2), padding='same', data_format=data_format)
self.conv2 = tf.layers.Conv2D( # The model consists of a sequential chain of layers, so tf.keras.Sequential
64, 5, padding='same', data_format=data_format, activation=tf.nn.relu) # (a subclass of tf.keras.Model) makes for a compact description.
self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu) return tf.keras.Sequential(
self.fc2 = tf.layers.Dense(10) [
self.dropout = tf.layers.Dropout(0.4) l.Reshape(input_shape),
self.max_pool2d = tf.layers.MaxPooling2D( l.Conv2D(
(2, 2), (2, 2), padding='same', data_format=data_format) 32,
5,
def __call__(self, inputs, training): padding='same',
"""Add operations to classify a batch of input images. data_format=data_format,
activation=tf.nn.relu),
Args: max_pool,
inputs: A Tensor representing a batch of input images. l.Conv2D(
training: A boolean. Set to True to add operations required only when 64,
training the classifier. 5,
padding='same',
Returns: data_format=data_format,
A logits Tensor with shape [<batch_size>, 10]. activation=tf.nn.relu),
""" max_pool,
y = tf.reshape(inputs, self._input_shape) l.Flatten(),
y = self.conv1(y) l.Dense(1024, activation=tf.nn.relu),
y = self.max_pool2d(y) l.Dropout(0.4),
y = self.conv2(y) l.Dense(10)
y = self.max_pool2d(y) ])
y = tf.layers.flatten(y)
y = self.fc1(y)
y = self.dropout(y, training=training)
return self.fc2(y)
def model_fn(features, labels, mode, params): def model_fn(features, labels, mode, params):
"""The model_fn argument for creating an Estimator.""" """The model_fn argument for creating an Estimator."""
model = Model(params['data_format']) model = create_model(params['data_format'])
image = features image = features
if isinstance(image, dict): if isinstance(image, dict):
image = features['image'] image = features['image']
...@@ -140,8 +136,7 @@ def model_fn(features, labels, mode, params): ...@@ -140,8 +136,7 @@ def model_fn(features, labels, mode, params):
eval_metric_ops={ eval_metric_ops={
'accuracy': 'accuracy':
tf.metrics.accuracy( tf.metrics.accuracy(
labels=labels, labels=labels, predictions=tf.argmax(logits, axis=1)),
predictions=tf.argmax(logits, axis=1)),
}) })
...@@ -231,6 +226,10 @@ def main(argv): ...@@ -231,6 +226,10 @@ def main(argv):
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results) print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold(flags.stop_threshold,
eval_results['accuracy']):
break
# Export the model # Export the model
if flags.export_dir is not None: if flags.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28]) image = tf.placeholder(tf.float32, [None, 28, 28])
...@@ -245,7 +244,7 @@ class MNISTArgParser(argparse.ArgumentParser): ...@@ -245,7 +244,7 @@ class MNISTArgParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(MNISTArgParser, self).__init__(parents=[ super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(), parsers.BaseParser(multi_gpu=True, num_gpu=False),
parsers.ImageModelParser(), parsers.ImageModelParser(),
parsers.ExportParser(), parsers.ExportParser(),
]) ])
......
...@@ -116,7 +116,7 @@ def main(argv): ...@@ -116,7 +116,7 @@ def main(argv):
test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size) test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
# Create the model and optimizer # Create the model and optimizer
model = mnist.Model(data_format) model = mnist.create_model(data_format)
optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum) optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
# Create file writers for writing TensorBoard summaries. # Create file writers for writing TensorBoard summaries.
...@@ -164,8 +164,7 @@ class MNISTEagerArgParser(argparse.ArgumentParser): ...@@ -164,8 +164,7 @@ class MNISTEagerArgParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(MNISTEagerArgParser, self).__init__(parents=[ super(MNISTEagerArgParser, self).__init__(parents=[
parsers.BaseParser( parsers.EagerParser(),
epochs_between_evals=False, multi_gpu=False, hooks=False),
parsers.ImageModelParser()]) parsers.ImageModelParser()])
self.add_argument( self.add_argument(
......
...@@ -40,7 +40,7 @@ def random_dataset(): ...@@ -40,7 +40,7 @@ def random_dataset():
def train(defun=False): def train(defun=False):
model = mnist.Model(data_format()) model = mnist.create_model(data_format())
if defun: if defun:
model.call = tfe.defun(model.call) model.call = tfe.defun(model.call)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
...@@ -51,7 +51,7 @@ def train(defun=False): ...@@ -51,7 +51,7 @@ def train(defun=False):
def evaluate(defun=False): def evaluate(defun=False):
model = mnist.Model(data_format()) model = mnist.create_model(data_format())
dataset = random_dataset() dataset = random_dataset()
if defun: if defun:
model.call = tfe.defun(model.call) model.call = tfe.defun(model.call)
......
...@@ -86,7 +86,7 @@ def model_fn(features, labels, mode, params): ...@@ -86,7 +86,7 @@ def model_fn(features, labels, mode, params):
if isinstance(image, dict): if isinstance(image, dict):
image = features["image"] image = features["image"]
model = mnist.Model("channels_last") model = mnist.create_model("channels_last")
logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN)) logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
......
...@@ -55,9 +55,17 @@ You can download 190 MB pre-trained versions of ResNet-50 achieving 76.3% and 75 ...@@ -55,9 +55,17 @@ You can download 190 MB pre-trained versions of ResNet-50 achieving 76.3% and 75
Other versions and formats: Other versions and formats:
* [ResNet-v2-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnetv2_imagenet_checkpoint.tar.gz) * [ResNet-v2-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v2_imagenet_checkpoint.tar.gz)
* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnetv2_imagenet_savedmodel.tar.gz) * [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
* [ResNet-v2-ImageNet Frozen Graph](http://download.tensorflow.org/models/official/resnetv2_imagenet_frozen_graph.pb) * [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnetv1_imagenet_checkpoint.tar.gz) * [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)
* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnetv1_imagenet_savedmodel.tar.gz)
* [ResNet-v1-ImageNet Frozen Graph](http://download.tensorflow.org/models/official/resnetv1_imagenet_frozen_graph.pb) ## Compute Devices
Training is accomplished using the DistributionStrategies API. (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md)
The appropriate distribution strategy is chosen based on the `--num_gpus` flag. By default this flag is one if TensorFlow is compiled with CUDA, and zero otherwise.
num_gpus:
+ 0: Use OneDeviceStrategy and train on CPU.
+ 1: Use OneDeviceStrategy and train on GPU.
+ 2+: Use MirroredStrategy (data parallelism) to distribute a batch between devices.
...@@ -103,8 +103,7 @@ def preprocess_image(image, is_training): ...@@ -103,8 +103,7 @@ def preprocess_image(image, is_training):
return image return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1):
num_parallel_calls=1, multi_gpu=False):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset. """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args: Args:
...@@ -112,12 +111,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -112,12 +111,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -125,12 +118,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -125,12 +118,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
filenames = get_filenames(is_training, data_dir) filenames = get_filenames(is_training, data_dir)
dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES) dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _NUM_IMAGES['train'], dataset, is_training, batch_size, _NUM_IMAGES['train'],
parse_record, num_epochs, num_parallel_calls, parse_record, num_epochs,
examples_per_epoch=num_images, multi_gpu=multi_gpu) )
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -145,7 +136,8 @@ class Cifar10Model(resnet_model.Model): ...@@ -145,7 +136,8 @@ class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data.""" """Model class with appropriate defaults for CIFAR-10 data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION): version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for CIFAR-10 data. """These are the parameters that work for CIFAR-10 data.
Args: Args:
...@@ -156,6 +148,7 @@ class Cifar10Model(resnet_model.Model): ...@@ -156,6 +148,7 @@ class Cifar10Model(resnet_model.Model):
enables users to extend the same model to their own datasets. enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use. version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2] See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations.
Raises: Raises:
ValueError: if invalid resnet_size is chosen ValueError: if invalid resnet_size is chosen
...@@ -180,7 +173,9 @@ class Cifar10Model(resnet_model.Model): ...@@ -180,7 +173,9 @@ class Cifar10Model(resnet_model.Model):
block_strides=[1, 2, 2], block_strides=[1, 2, 2],
final_size=64, final_size=64,
version=version, version=version,
data_format=data_format) data_format=data_format,
dtype=dtype
)
def cifar10_model_fn(features, labels, mode, params): def cifar10_model_fn(features, labels, mode, params):
...@@ -204,15 +199,21 @@ def cifar10_model_fn(features, labels, mode, params): ...@@ -204,15 +199,21 @@ def cifar10_model_fn(features, labels, mode, params):
def loss_filter_fn(_): def loss_filter_fn(_):
return True return True
return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model, return resnet_run_loop.resnet_model_fn(
resnet_size=params['resnet_size'], features=features,
weight_decay=weight_decay, labels=labels,
learning_rate_fn=learning_rate_fn, mode=mode,
momentum=0.9, model_class=Cifar10Model,
data_format=params['data_format'], resnet_size=params['resnet_size'],
version=params['version'], weight_decay=weight_decay,
loss_filter_fn=loss_filter_fn, learning_rate_fn=learning_rate_fn,
multi_gpu=params['multi_gpu']) momentum=0.9,
data_format=params['data_format'],
version=params['version'],
loss_scale=params['loss_scale'],
loss_filter_fn=loss_filter_fn,
dtype=params['dtype']
)
def main(argv): def main(argv):
......
...@@ -71,18 +71,19 @@ class BaseTest(tf.test.TestCase): ...@@ -71,18 +71,19 @@ class BaseTest(tf.test.TestCase):
for pixel in row: for pixel in row:
self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3) self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
def cifar10_model_fn_helper(self, mode, version, multi_gpu=False): def cifar10_model_fn_helper(self, mode, version, dtype):
input_fn = cifar10_main.get_synth_input_fn() input_fn = cifar10_main.get_synth_input_fn()
dataset = input_fn(True, '', _BATCH_SIZE) dataset = input_fn(True, '', _BATCH_SIZE)
iterator = dataset.make_one_shot_iterator() iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next() features, labels = iterator.get_next()
spec = cifar10_main.cifar10_model_fn( spec = cifar10_main.cifar10_model_fn(
features, labels, mode, { features, labels, mode, {
'dtype': dtype,
'resnet_size': 32, 'resnet_size': 32,
'data_format': 'channels_last', 'data_format': 'channels_last',
'batch_size': _BATCH_SIZE, 'batch_size': _BATCH_SIZE,
'version': version, 'version': version,
'multi_gpu': multi_gpu 'loss_scale': 128 if dtype == tf.float16 else 1,
}) })
predictions = spec.predictions predictions = spec.predictions
...@@ -105,44 +106,45 @@ class BaseTest(tf.test.TestCase): ...@@ -105,44 +106,45 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_cifar10_model_fn_train_mode_v1(self): def test_cifar10_model_fn_train_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_cifar10_model_fn_train_mode_multi_gpu_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True) dtype=tf.float32)
def test_cifar10_model_fn_train_mode_multi_gpu_v2(self): def test_cifar10_model_fn_trainmode__v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True) dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v1(self): def test_cifar10_model_fn_eval_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
dtype=tf.float32)
def test_cifar10_model_fn_eval_mode_v2(self): def test_cifar10_model_fn_eval_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v1(self): def test_cifar10_model_fn_predict_mode_v1(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)
def test_cifar10_model_fn_predict_mode_v2(self): def test_cifar10_model_fn_predict_mode_v2(self):
self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2) self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)
def test_cifar10model_shape(self): def _test_cifar10model_shape(self, version):
batch_size = 135 batch_size = 135
num_classes = 246 num_classes = 246
for version in (1, 2): model = cifar10_main.Cifar10Model(32, data_format='channels_last',
model = cifar10_main.Cifar10Model( num_classes=num_classes, version=version)
32, data_format='channels_last', num_classes=num_classes, fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
version=version) output = model(fake_input, training=True)
fake_input = tf.random_uniform(
[batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS]) self.assertAllEqual(output.shape, (batch_size, num_classes))
output = model(fake_input, training=True)
def test_cifar10model_shape_v1(self):
self._test_cifar10model_shape(version=1)
self.assertAllEqual(output.shape, (batch_size, num_classes)) def test_cifar10model_shape_v2(self):
self._test_cifar10model_shape(version=2)
def test_cifar10_end_to_end_synthetic_v1(self): def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
......
...@@ -154,8 +154,7 @@ def parse_record(raw_record, is_training): ...@@ -154,8 +154,7 @@ def parse_record(raw_record, is_training):
return image, label return image, label
def input_fn(is_training, data_dir, batch_size, num_epochs=1, def input_fn(is_training, data_dir, batch_size, num_epochs=1):
num_parallel_calls=1, multi_gpu=False):
"""Input function which provides batches for train or eval. """Input function which provides batches for train or eval.
Args: Args:
...@@ -163,12 +162,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -163,12 +162,6 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
data_dir: The directory containing the input data. data_dir: The directory containing the input data.
batch_size: The number of samples per batch. batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers, and can be removed
when that is handled directly by Estimator.
Returns: Returns:
A dataset that can be used for iteration. A dataset that can be used for iteration.
...@@ -180,15 +173,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -180,15 +173,13 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
# Shuffle the input files # Shuffle the input files
dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES) dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']
# Convert to individual records # Convert to individual records
dataset = dataset.flat_map(tf.data.TFRecordDataset) dataset = dataset.flat_map(tf.data.TFRecordDataset)
return resnet_run_loop.process_record_dataset( return resnet_run_loop.process_record_dataset(
dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record, dataset, is_training, batch_size, _SHUFFLE_BUFFER, parse_record,
num_epochs, num_parallel_calls, examples_per_epoch=num_images, num_epochs
multi_gpu=multi_gpu) )
def get_synth_input_fn(): def get_synth_input_fn():
...@@ -203,7 +194,8 @@ class ImagenetModel(resnet_model.Model): ...@@ -203,7 +194,8 @@ class ImagenetModel(resnet_model.Model):
"""Model class with appropriate defaults for Imagenet data.""" """Model class with appropriate defaults for Imagenet data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
version=resnet_model.DEFAULT_VERSION): version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for Imagenet data. """These are the parameters that work for Imagenet data.
Args: Args:
...@@ -214,6 +206,7 @@ class ImagenetModel(resnet_model.Model): ...@@ -214,6 +206,7 @@ class ImagenetModel(resnet_model.Model):
enables users to extend the same model to their own datasets. enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use. version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2] See README for details. Valid values: [1, 2]
dtype: The TensorFlow dtype to use for calculations.
""" """
# For bigger models, we want to use "bottleneck" layers # For bigger models, we want to use "bottleneck" layers
...@@ -239,7 +232,9 @@ class ImagenetModel(resnet_model.Model): ...@@ -239,7 +232,9 @@ class ImagenetModel(resnet_model.Model):
block_strides=[1, 2, 2, 2], block_strides=[1, 2, 2, 2],
final_size=final_size, final_size=final_size,
version=version, version=version,
data_format=data_format) data_format=data_format,
dtype=dtype
)
def _get_block_sizes(resnet_size): def _get_block_sizes(resnet_size):
...@@ -283,15 +278,21 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -283,15 +278,21 @@ def imagenet_model_fn(features, labels, mode, params):
num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90], num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
decay_rates=[1, 0.1, 0.01, 0.001, 1e-4]) decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])
return resnet_run_loop.resnet_model_fn(features, labels, mode, ImagenetModel, return resnet_run_loop.resnet_model_fn(
resnet_size=params['resnet_size'], features=features,
weight_decay=1e-4, labels=labels,
learning_rate_fn=learning_rate_fn, mode=mode,
momentum=0.9, model_class=ImagenetModel,
data_format=params['data_format'], resnet_size=params['resnet_size'],
version=params['version'], weight_decay=1e-4,
loss_filter_fn=None, learning_rate_fn=learning_rate_fn,
multi_gpu=params['multi_gpu']) momentum=0.9,
data_format=params['data_format'],
version=params['version'],
loss_scale=params['loss_scale'],
loss_filter_fn=None,
dtype=params['dtype']
)
def main(argv): def main(argv):
......
...@@ -36,7 +36,7 @@ class BaseTest(tf.test.TestCase): ...@@ -36,7 +36,7 @@ class BaseTest(tf.test.TestCase):
super(BaseTest, self).tearDown() super(BaseTest, self).tearDown()
tf.gfile.DeleteRecursively(self.get_temp_dir()) tf.gfile.DeleteRecursively(self.get_temp_dir())
def tensor_shapes_helper(self, resnet_size, version, with_gpu=False): def _tensor_shapes_helper(self, resnet_size, version, dtype, with_gpu):
"""Checks the tensor shapes after each phase of the ResNet model.""" """Checks the tensor shapes after each phase of the ResNet model."""
def reshape(shape): def reshape(shape):
"""Returns the expected dimensions depending on if a GPU is being used.""" """Returns the expected dimensions depending on if a GPU is being used."""
...@@ -50,22 +50,24 @@ class BaseTest(tf.test.TestCase): ...@@ -50,22 +50,24 @@ class BaseTest(tf.test.TestCase):
graph = tf.Graph() graph = tf.Graph()
with graph.as_default(), self.test_session( with graph.as_default(), self.test_session(
use_gpu=with_gpu, force_gpu=with_gpu): graph=graph, use_gpu=with_gpu, force_gpu=with_gpu):
model = imagenet_main.ImagenetModel( model = imagenet_main.ImagenetModel(
resnet_size, resnet_size=resnet_size,
data_format='channels_first' if with_gpu else 'channels_last', data_format='channels_first' if with_gpu else 'channels_last',
version=version) version=version,
dtype=dtype
)
inputs = tf.random_uniform([1, 224, 224, 3]) inputs = tf.random_uniform([1, 224, 224, 3])
output = model(inputs, training=True) output = model(inputs, training=True)
initial_conv = graph.get_tensor_by_name('initial_conv:0') initial_conv = graph.get_tensor_by_name('resnet_model/initial_conv:0')
max_pool = graph.get_tensor_by_name('initial_max_pool:0') max_pool = graph.get_tensor_by_name('resnet_model/initial_max_pool:0')
block_layer1 = graph.get_tensor_by_name('block_layer1:0') block_layer1 = graph.get_tensor_by_name('resnet_model/block_layer1:0')
block_layer2 = graph.get_tensor_by_name('block_layer2:0') block_layer2 = graph.get_tensor_by_name('resnet_model/block_layer2:0')
block_layer3 = graph.get_tensor_by_name('block_layer3:0') block_layer3 = graph.get_tensor_by_name('resnet_model/block_layer3:0')
block_layer4 = graph.get_tensor_by_name('block_layer4:0') block_layer4 = graph.get_tensor_by_name('resnet_model/block_layer4:0')
reduce_mean = graph.get_tensor_by_name('final_reduce_mean:0') reduce_mean = graph.get_tensor_by_name('resnet_model/final_reduce_mean:0')
dense = graph.get_tensor_by_name('final_dense:0') dense = graph.get_tensor_by_name('resnet_model/final_dense:0')
self.assertAllEqual(initial_conv.shape, reshape((1, 64, 112, 112))) self.assertAllEqual(initial_conv.shape, reshape((1, 64, 112, 112)))
self.assertAllEqual(max_pool.shape, reshape((1, 64, 56, 56))) self.assertAllEqual(max_pool.shape, reshape((1, 64, 56, 56)))
...@@ -88,6 +90,12 @@ class BaseTest(tf.test.TestCase): ...@@ -88,6 +90,12 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES)) self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
self.assertAllEqual(output.shape, (1, _LABEL_CLASSES)) self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))
def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
self._tensor_shapes_helper(resnet_size=resnet_size, version=version,
dtype=tf.float32, with_gpu=with_gpu)
self._tensor_shapes_helper(resnet_size=resnet_size, version=version,
dtype=tf.float16, with_gpu=with_gpu)
def test_tensor_shapes_resnet_18_v1(self): def test_tensor_shapes_resnet_18_v1(self):
self.tensor_shapes_helper(18, version=1) self.tensor_shapes_helper(18, version=1)
...@@ -172,7 +180,7 @@ class BaseTest(tf.test.TestCase): ...@@ -172,7 +180,7 @@ class BaseTest(tf.test.TestCase):
def test_tensor_shapes_resnet_200_with_gpu_v2(self): def test_tensor_shapes_resnet_200_with_gpu_v2(self):
self.tensor_shapes_helper(200, version=2, with_gpu=True) self.tensor_shapes_helper(200, version=2, with_gpu=True)
def resnet_model_fn_helper(self, mode, version, multi_gpu=False): def resnet_model_fn_helper(self, mode, version, dtype):
"""Tests that the EstimatorSpec is given the appropriate arguments.""" """Tests that the EstimatorSpec is given the appropriate arguments."""
tf.train.create_global_step() tf.train.create_global_step()
...@@ -182,11 +190,12 @@ class BaseTest(tf.test.TestCase): ...@@ -182,11 +190,12 @@ class BaseTest(tf.test.TestCase):
features, labels = iterator.get_next() features, labels = iterator.get_next()
spec = imagenet_main.imagenet_model_fn( spec = imagenet_main.imagenet_model_fn(
features, labels, mode, { features, labels, mode, {
'dtype': dtype,
'resnet_size': 50, 'resnet_size': 50,
'data_format': 'channels_last', 'data_format': 'channels_last',
'batch_size': _BATCH_SIZE, 'batch_size': _BATCH_SIZE,
'version': version, 'version': version,
'multi_gpu': multi_gpu, 'loss_scale': 128 if dtype == tf.float16 else 1,
}) })
predictions = spec.predictions predictions = spec.predictions
...@@ -209,43 +218,47 @@ class BaseTest(tf.test.TestCase): ...@@ -209,43 +218,47 @@ class BaseTest(tf.test.TestCase):
self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32) self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
def test_resnet_model_fn_train_mode_v1(self): def test_resnet_model_fn_train_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2)
def test_resnet_model_fn_train_mode_multi_gpu_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1,
multi_gpu=True) dtype=tf.float32)
def test_resnet_model_fn_train_mode_multi_gpu_v2(self): def test_resnet_model_fn_train_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2, self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=2,
multi_gpu=True) dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v1(self): def test_resnet_model_fn_eval_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=1,
dtype=tf.float32)
def test_resnet_model_fn_eval_mode_v2(self): def test_resnet_model_fn_eval_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.EVAL, version=2,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v1(self): def test_resnet_model_fn_predict_mode_v1(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=1,
dtype=tf.float32)
def test_resnet_model_fn_predict_mode_v2(self): def test_resnet_model_fn_predict_mode_v2(self):
self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2) self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2,
dtype=tf.float32)
def test_imagenetmodel_shape(self): def _test_imagenetmodel_shape(self, version):
batch_size = 135 batch_size = 135
num_classes = 246 num_classes = 246
for version in (1, 2): model = imagenet_main.ImagenetModel(
model = imagenet_main.ImagenetModel( 50, data_format='channels_last', num_classes=num_classes,
50, data_format='channels_last', num_classes=num_classes, version=version)
version=version)
fake_input = tf.random_uniform([batch_size, 224, 224, 3]) fake_input = tf.random_uniform([batch_size, 224, 224, 3])
output = model(fake_input, training=True) output = model(fake_input, training=True)
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenetmodel_shape_v1(self):
self._test_imagenetmodel_shape(version=1)
def test_imagenetmodel_shape_v2(self):
self._test_imagenetmodel_shape(version=2)
def test_imagenet_end_to_end_synthetic_v1(self): def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
...@@ -283,5 +296,6 @@ class BaseTest(tf.test.TestCase): ...@@ -283,5 +296,6 @@ class BaseTest(tf.test.TestCase):
extra_flags=['-v', '2', '-rs', '200'] extra_flags=['-v', '2', '-rs', '200']
) )
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -36,6 +36,9 @@ import tensorflow as tf ...@@ -36,6 +36,9 @@ import tensorflow as tf
_BATCH_NORM_DECAY = 0.997 _BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5 _BATCH_NORM_EPSILON = 1e-5
DEFAULT_VERSION = 2 DEFAULT_VERSION = 2
DEFAULT_DTYPE = tf.float32
CASTABLE_TYPES = (tf.float16,)
ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES
################################################################################ ################################################################################
...@@ -351,7 +354,8 @@ class Model(object): ...@@ -351,7 +354,8 @@ class Model(object):
kernel_size, kernel_size,
conv_stride, first_pool_size, first_pool_stride, conv_stride, first_pool_size, first_pool_stride,
second_pool_size, second_pool_stride, block_sizes, block_strides, second_pool_size, second_pool_stride, block_sizes, block_strides,
final_size, version=DEFAULT_VERSION, data_format=None): final_size, version=DEFAULT_VERSION, data_format=None,
dtype=DEFAULT_DTYPE):
"""Creates a model for classifying an image. """Creates a model for classifying an image.
Args: Args:
...@@ -379,6 +383,8 @@ class Model(object): ...@@ -379,6 +383,8 @@ class Model(object):
See README for details. Valid values: [1, 2] See README for details. Valid values: [1, 2]
data_format: Input format ('channels_last', 'channels_first', or None). data_format: Input format ('channels_last', 'channels_first', or None).
If set to None, the format is dependent on whether a GPU is available. If set to None, the format is dependent on whether a GPU is available.
dtype: The TensorFlow dtype to use for calculations. If not specified
tf.float32 is used.
Raises: Raises:
ValueError: if invalid version is selected. ValueError: if invalid version is selected.
...@@ -406,6 +412,9 @@ class Model(object): ...@@ -406,6 +412,9 @@ class Model(object):
else: else:
self.block_fn = _building_block_v2 self.block_fn = _building_block_v2
if dtype not in ALLOWED_TYPES:
raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES))
self.data_format = data_format self.data_format = data_format
self.num_classes = num_classes self.num_classes = num_classes
self.num_filters = num_filters self.num_filters = num_filters
...@@ -418,6 +427,61 @@ class Model(object): ...@@ -418,6 +427,61 @@ class Model(object):
self.block_sizes = block_sizes self.block_sizes = block_sizes
self.block_strides = block_strides self.block_strides = block_strides
self.final_size = final_size self.final_size = final_size
self.dtype = dtype
def _custom_dtype_getter(self, getter, name, shape=None, dtype=DEFAULT_DTYPE,
*args, **kwargs):
"""Creates variables in fp32, then casts to fp16 if necessary.
This function is a custom getter. A custom getter is a function with the
same signature as tf.get_variable, except it has an additional getter
parameter. Custom getters can be passed as the `custom_getter` parameter of
tf.variable_scope. Then, tf.get_variable will call the custom getter,
instead of directly getting a variable itself. This can be used to change
the types of variables that are retrieved with tf.get_variable.
The `getter` parameter is the underlying variable getter, that would have
been called if no custom getter was used. Custom getters typically get a
variable with `getter`, then modify it in some way.
This custom getter will create an fp32 variable. If a low precision
(e.g. float16) variable was requested it will then cast the variable to the
requested dtype. The reason we do not directly create variables in low
precision dtypes is that applying small gradients to such variables may
cause the variable not to change.
Args:
getter: The underlying variable getter, that has the same signature as
tf.get_variable and returns a variable.
name: The name of the variable to get.
shape: The shape of the variable to get.
dtype: The dtype of the variable to get. Note that if this is a low
precision dtype, the variable will be created as a tf.float32 variable,
then cast to the appropriate dtype
*args: Additional arguments to pass unmodified to getter.
**kwargs: Additional keyword arguments to pass unmodified to getter.
Returns:
A variable which is cast to fp16 if necessary.
"""
if dtype in CASTABLE_TYPES:
var = getter(name, shape, tf.float32, *args, **kwargs)
return tf.cast(var, dtype=dtype, name=name + '_cast')
else:
return getter(name, shape, dtype, *args, **kwargs)
def _model_variable_scope(self):
"""Returns a variable scope that the model should be created under.
If self.dtype is a castable type, model variable will be created in fp32
then cast to self.dtype before being used.
Returns:
A variable scope for the model.
"""
return tf.variable_scope('resnet_model',
custom_getter=self._custom_dtype_getter)
def __call__(self, inputs, training): def __call__(self, inputs, training):
"""Add operations to classify a batch of input images. """Add operations to classify a batch of input images.
...@@ -431,46 +495,46 @@ class Model(object): ...@@ -431,46 +495,46 @@ class Model(object):
A logits Tensor with shape [<batch_size>, self.num_classes]. A logits Tensor with shape [<batch_size>, self.num_classes].
""" """
if self.data_format == 'channels_first': with self._model_variable_scope():
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW). if self.data_format == 'channels_first':
# This provides a large performance boost on GPU. See # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
# https://www.tensorflow.org/performance/performance_guide#data_formats # This provides a large performance boost on GPU. See
inputs = tf.transpose(inputs, [0, 3, 1, 2]) # https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2])
inputs = conv2d_fixed_padding(
inputs=inputs, filters=self.num_filters, kernel_size=self.kernel_size, inputs = conv2d_fixed_padding(
strides=self.conv_stride, data_format=self.data_format) inputs=inputs, filters=self.num_filters, kernel_size=self.kernel_size,
inputs = tf.identity(inputs, 'initial_conv') strides=self.conv_stride, data_format=self.data_format)
inputs = tf.identity(inputs, 'initial_conv')
if self.first_pool_size:
inputs = tf.layers.max_pooling2d( if self.first_pool_size:
inputs=inputs, pool_size=self.first_pool_size, inputs = tf.layers.max_pooling2d(
strides=self.first_pool_stride, padding='SAME', inputs=inputs, pool_size=self.first_pool_size,
data_format=self.data_format) strides=self.first_pool_stride, padding='SAME',
inputs = tf.identity(inputs, 'initial_max_pool') data_format=self.data_format)
inputs = tf.identity(inputs, 'initial_max_pool')
for i, num_blocks in enumerate(self.block_sizes):
num_filters = self.num_filters * (2**i) for i, num_blocks in enumerate(self.block_sizes):
inputs = block_layer( num_filters = self.num_filters * (2**i)
inputs=inputs, filters=num_filters, bottleneck=self.bottleneck, inputs = block_layer(
block_fn=self.block_fn, blocks=num_blocks, inputs=inputs, filters=num_filters, bottleneck=self.bottleneck,
strides=self.block_strides[i], training=training, block_fn=self.block_fn, blocks=num_blocks,
name='block_layer{}'.format(i + 1), data_format=self.data_format) strides=self.block_strides[i], training=training,
name='block_layer{}'.format(i + 1), data_format=self.data_format)
inputs = batch_norm(inputs, training, self.data_format)
inputs = tf.nn.relu(inputs) inputs = batch_norm(inputs, training, self.data_format)
inputs = tf.nn.relu(inputs)
# The current top layer has shape
# `batch_size x pool_size x pool_size x final_size`. # The current top layer has shape
# ResNet does an Average Pooling layer over pool_size, # `batch_size x pool_size x pool_size x final_size`.
# but that is the same as doing a reduce_mean. We do a reduce_mean # ResNet does an Average Pooling layer over pool_size,
# here because it performs better than AveragePooling2D. # but that is the same as doing a reduce_mean. We do a reduce_mean
axes = [2, 3] if self.data_format == 'channels_first' else [1, 2] # here because it performs better than AveragePooling2D.
inputs = tf.reduce_mean(inputs, axes, keepdims=True) axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
inputs = tf.identity(inputs, 'final_reduce_mean') inputs = tf.reduce_mean(inputs, axes, keepdims=True)
inputs = tf.identity(inputs, 'final_reduce_mean')
inputs = tf.reshape(inputs, [-1, self.final_size])
inputs = tf.layers.dense(inputs=inputs, units=self.num_classes) inputs = tf.reshape(inputs, [-1, self.final_size])
inputs = tf.identity(inputs, 'final_dense') inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
inputs = tf.identity(inputs, 'final_dense')
return inputs return inputs
...@@ -33,14 +33,14 @@ from official.utils.arg_parsers import parsers ...@@ -33,14 +33,14 @@ from official.utils.arg_parsers import parsers
from official.utils.export import export from official.utils.export import export
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import model_helpers
################################################################################ ################################################################################
# Functions for input processing. # Functions for input processing.
################################################################################ ################################################################################
def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn, num_epochs=1, num_parallel_calls=1, parse_record_fn, num_epochs=1):
examples_per_epoch=0, multi_gpu=False):
"""Given a Dataset with raw records, return an iterator over the records. """Given a Dataset with raw records, return an iterator over the records.
Args: Args:
...@@ -53,19 +53,11 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -53,19 +53,11 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn: A function that takes a raw record and returns the parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair. corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset. num_epochs: The number of epochs to repeat the dataset.
num_parallel_calls: The number of records that are processed in parallel.
This can be optimized per data set but for generally homogeneous data
sets, should be approximately the number of available CPU cores.
examples_per_epoch: The number of examples in the current set that
are processed each epoch. Note that this is only used for multi-GPU mode,
and only to handle what will eventually be handled inside of Estimator.
multi_gpu: Whether this is run multi-GPU. Note that this is only required
currently to handle the batch leftovers (see below), and can be removed
when that is handled directly by Estimator.
Returns: Returns:
Dataset of (image, label) pairs ready for iteration. Dataset of (image, label) pairs ready for iteration.
""" """
# We prefetch a batch at a time, This can help smooth out the time taken to # We prefetch a batch at a time, This can help smooth out the time taken to
# load input files as we go through shuffling and processing. # load input files as we go through shuffling and processing.
dataset = dataset.prefetch(buffer_size=batch_size) dataset = dataset.prefetch(buffer_size=batch_size)
...@@ -78,29 +70,22 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer, ...@@ -78,29 +70,22 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# dataset for the appropriate number of epochs. # dataset for the appropriate number of epochs.
dataset = dataset.repeat(num_epochs) dataset = dataset.repeat(num_epochs)
# Currently, if we are using multiple GPUs, we can't pass in uneven batches. # Parse the raw records into images and labels. Testing has shown that setting
# (For example, if we have 4 GPUs, the number of examples in each batch # num_parallel_batches > 1 produces no improvement in throughput, since
# must be divisible by 4.) We already ensured this for the batch_size, but # batch_size is almost always much greater than the number of CPU cores.
# we have to additionally ensure that any "leftover" examples-- the remainder dataset = dataset.apply(
# examples (total examples % batch_size) that get called a batch for the very tf.contrib.data.map_and_batch(
# last batch of an epoch-- do not raise an error when we try to split them lambda value: parse_record_fn(value, is_training),
# over the GPUs. This will likely be handled by Estimator during replication batch_size=batch_size,
# in the future, but for now, we just drop the leftovers here. num_parallel_batches=1))
if multi_gpu:
total_examples = num_epochs * examples_per_epoch
dataset = dataset.take(batch_size * (total_examples // batch_size))
# Parse the raw records into images and labels
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
num_parallel_calls=num_parallel_calls)
dataset = dataset.batch(batch_size)
# Operations between the final prefetch and the get_next call to the iterator # Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to # will happen synchronously during run time. We prefetch here again to
# background all of the above processing work and keep it out of the # background all of the above processing work and keep it out of the
# critical training path. # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE
dataset = dataset.prefetch(1) # allows DistributionStrategies to adjust how many batches to fetch based
# on how many devices are present.
dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return dataset return dataset
...@@ -122,7 +107,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes): ...@@ -122,7 +107,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
An input_fn that can be used in place of a real one to return a dataset An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration. that can be used for iteration.
""" """
def input_fn(is_training, data_dir, batch_size, *args): # pylint: disable=unused-argument def input_fn(is_training, data_dir, batch_size, *args, **kwargs): # pylint: disable=unused-argument
images = tf.zeros((batch_size, height, width, num_channels), tf.float32) images = tf.zeros((batch_size, height, width, num_channels), tf.float32)
labels = tf.zeros((batch_size, num_classes), tf.int32) labels = tf.zeros((batch_size, num_classes), tf.int32)
return tf.data.Dataset.from_tensors((images, labels)).repeat() return tf.data.Dataset.from_tensors((images, labels)).repeat()
...@@ -170,7 +155,8 @@ def learning_rate_with_decay( ...@@ -170,7 +155,8 @@ def learning_rate_with_decay(
def resnet_model_fn(features, labels, mode, model_class, def resnet_model_fn(features, labels, mode, model_class,
resnet_size, weight_decay, learning_rate_fn, momentum, resnet_size, weight_decay, learning_rate_fn, momentum,
data_format, version, loss_filter_fn=None, multi_gpu=False): data_format, version, loss_scale, loss_filter_fn=None,
dtype=resnet_model.DEFAULT_DTYPE):
"""Shared functionality for different resnet model_fns. """Shared functionality for different resnet model_fns.
Initializes the ResnetModel representing the model layers Initializes the ResnetModel representing the model layers
...@@ -196,12 +182,13 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -196,12 +182,13 @@ def resnet_model_fn(features, labels, mode, model_class,
If set to None, the format is dependent on whether a GPU is available. If set to None, the format is dependent on whether a GPU is available.
version: Integer representing which version of the ResNet network to use. version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2] See README for details. Valid values: [1, 2]
loss_scale: The factor to scale the loss for numerical stability. A detailed
summary is present in the arg parser help text.
loss_filter_fn: function that takes a string variable name and returns loss_filter_fn: function that takes a string variable name and returns
True if the var should be included in loss calculation, and False True if the var should be included in loss calculation, and False
otherwise. If None, batch_normalization variables will be excluded otherwise. If None, batch_normalization variables will be excluded
from the loss. from the loss.
multi_gpu: If True, wrap the optimizer in a TowerOptimizer suitable for dtype: the TensorFlow dtype to use for calculations.
data-parallel distribution across multiple GPUs.
Returns: Returns:
EstimatorSpec parameterized according to the input params and the EstimatorSpec parameterized according to the input params and the
...@@ -211,9 +198,17 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -211,9 +198,17 @@ def resnet_model_fn(features, labels, mode, model_class,
# Generate a summary node for the images # Generate a summary node for the images
tf.summary.image('images', features, max_outputs=6) tf.summary.image('images', features, max_outputs=6)
model = model_class(resnet_size, data_format, version=version) features = tf.cast(features, dtype)
model = model_class(resnet_size, data_format, version=version, dtype=dtype)
logits = model(features, mode == tf.estimator.ModeKeys.TRAIN) logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)
# This acts as a no-op if the logits are already in fp32 (provided logits are
# not a SparseTensor). If dtype is is low precision, logits must be cast to
# fp32 for numerical stability.
logits = tf.cast(logits, tf.float32)
predictions = { predictions = {
'classes': tf.argmax(logits, axis=1), 'classes': tf.argmax(logits, axis=1),
'probabilities': tf.nn.softmax(logits, name='softmax_tensor') 'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
...@@ -244,7 +239,8 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -244,7 +239,8 @@ def resnet_model_fn(features, labels, mode, model_class,
# Add weight decay to the loss. # Add weight decay to the loss.
l2_loss = weight_decay * tf.add_n( l2_loss = weight_decay * tf.add_n(
[tf.nn.l2_loss(v) for v in tf.trainable_variables() # loss is computed using fp32 for numerical stability.
[tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()
if loss_filter_fn(v.name)]) if loss_filter_fn(v.name)])
tf.summary.scalar('l2_loss', l2_loss) tf.summary.scalar('l2_loss', l2_loss)
loss = cross_entropy + l2_loss loss = cross_entropy + l2_loss
...@@ -260,19 +256,36 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -260,19 +256,36 @@ def resnet_model_fn(features, labels, mode, model_class,
optimizer = tf.train.MomentumOptimizer( optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, learning_rate=learning_rate,
momentum=momentum) momentum=momentum
)
# If we are running multi-GPU, we need to wrap the optimizer. if loss_scale != 1:
if multi_gpu: # When computing fp16 gradients, often intermediate tensor values are
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer) # so small, they underflow to 0. To avoid this, we multiply the loss by
# loss_scale to make these tensor values loss_scale times bigger.
scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
# Once the gradient computation is complete we can scale the gradients
# back to the correct scale before passing them to the optimizer.
unscaled_grad_vars = [(grad / loss_scale, var)
for grad, var in scaled_grad_vars]
minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
else:
minimize_op = optimizer.minimize(loss, global_step)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.group(optimizer.minimize(loss, global_step), update_ops) train_op = tf.group(minimize_op, update_ops)
else: else:
train_op = None train_op = None
accuracy = tf.metrics.accuracy( if not tf.contrib.distribute.has_distribution_strategy():
tf.argmax(labels, axis=1), predictions['classes']) accuracy = tf.metrics.accuracy(
tf.argmax(labels, axis=1), predictions['classes'])
else:
# Metrics are currently not compatible with distribution strategies during
# training. This does not affect the overall performance of the model.
accuracy = (tf.no_op(), tf.constant(0))
metrics = {'accuracy': accuracy} metrics = {'accuracy': accuracy}
# Create a tensor named train_accuracy for logging purposes # Create a tensor named train_accuracy for logging purposes
...@@ -287,34 +300,35 @@ def resnet_model_fn(features, labels, mode, model_class, ...@@ -287,34 +300,35 @@ def resnet_model_fn(features, labels, mode, model_class,
eval_metric_ops=metrics) eval_metric_ops=metrics)
def validate_batch_size_for_multi_gpu(batch_size): def per_device_batch_size(batch_size, num_gpus):
"""For multi-gpu, batch-size must be a multiple of the number of GPUs. """For multi-gpu, batch-size must be a multiple of the number of GPUs.
Note that this should eventually be handled by replicate_model_fn Note that this should eventually be handled by DistributionStrategies
directly. Multi-GPU support is currently experimental, however, directly. Multi-GPU support is currently experimental, however,
so doing the work here until that feature is in place. so doing the work here until that feature is in place.
Args: Args:
batch_size: the number of examples processed in each training batch. batch_size: Global batch size to be divided among devices. This should be
equal to num_gpus times the single-GPU batch_size for multi-gpu training.
num_gpus: How many GPUs are used with DistributionStrategies.
Returns:
Batch size per device.
Raises: Raises:
ValueError: if no GPUs are found, or selected batch_size is invalid. ValueError: if batch_size is not divisible by number of devices
""" """
from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top if num_gpus <= 1:
return batch_size
local_device_protos = device_lib.list_local_devices()
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
if not num_gpus:
raise ValueError('Multi-GPU mode was specified, but no GPUs '
'were found. To use CPU, run without --multi_gpu.')
remainder = batch_size % num_gpus remainder = batch_size % num_gpus
if remainder: if remainder:
err = ('When running with multiple GPUs, batch size ' err = ('When running with multiple GPUs, batch size '
'must be a multiple of the number of available GPUs. ' 'must be a multiple of the number of available GPUs. Found {} '
'Found {} GPUs with a batch size of {}; try --batch_size={} instead.' 'GPUs with a batch size of {}; try --batch_size={} instead.'
).format(num_gpus, batch_size, batch_size - remainder) ).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err) raise ValueError(err)
return int(batch_size / num_gpus)
def resnet_main(flags, model_function, input_function, shape=None): def resnet_main(flags, model_function, input_function, shape=None):
...@@ -335,16 +349,6 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -335,16 +349,6 @@ def resnet_main(flags, model_function, input_function, shape=None):
# Using the Winograd non-fused algorithms provides a small performance boost. # Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
if flags.multi_gpu:
validate_batch_size_for_multi_gpu(flags.batch_size)
# There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens
# in the model_fn itself when the optimizer is defined.
model_function = tf.contrib.estimator.replicate_model_fn(
model_function,
loss_reduction=tf.losses.Reduction.MEAN)
# Create session config based on values of inter_op_parallelism_threads and # Create session config based on values of inter_op_parallelism_threads and
# intra_op_parallelism_threads. Note that we default to having # intra_op_parallelism_threads. Note that we default to having
# allow_soft_placement = True, which is required for multi-GPU and not # allow_soft_placement = True, which is required for multi-GPU and not
...@@ -354,22 +358,32 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -354,22 +358,32 @@ def resnet_main(flags, model_function, input_function, shape=None):
intra_op_parallelism_threads=flags.intra_op_parallelism_threads, intra_op_parallelism_threads=flags.intra_op_parallelism_threads,
allow_soft_placement=True) allow_soft_placement=True)
# Set up a RunConfig to save checkpoint and set session config. if flags.num_gpus == 0:
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9, distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
session_config=session_config) elif flags.num_gpus == 1:
distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
else:
distribution = tf.contrib.distribute.MirroredStrategy(
num_gpus=flags.num_gpus
)
run_config = tf.estimator.RunConfig(train_distribute=distribution,
session_config=session_config)
classifier = tf.estimator.Estimator( classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags.model_dir, config=run_config, model_fn=model_function, model_dir=flags.model_dir, config=run_config,
params={ params={
'resnet_size': flags.resnet_size, 'resnet_size': flags.resnet_size,
'data_format': flags.data_format, 'data_format': flags.data_format,
'batch_size': flags.batch_size, 'batch_size': flags.batch_size,
'multi_gpu': flags.multi_gpu,
'version': flags.version, 'version': flags.version,
'loss_scale': flags.loss_scale,
'dtype': flags.dtype
}) })
if flags.benchmark_log_dir is not None: if flags.benchmark_log_dir is not None:
benchmark_logger = logger.BenchmarkLogger(flags.benchmark_log_dir) benchmark_logger = logger.BenchmarkLogger(flags.benchmark_log_dir)
benchmark_logger.log_run_info("resnet") benchmark_logger.log_run_info('resnet')
else: else:
benchmark_logger = None benchmark_logger = None
...@@ -382,9 +396,12 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -382,9 +396,12 @@ def resnet_main(flags, model_function, input_function, shape=None):
print('Starting a training cycle.') print('Starting a training cycle.')
def input_fn_train(): def input_fn_train():
return input_function(True, flags.data_dir, flags.batch_size, return input_function(
flags.epochs_between_evals, is_training=True,
flags.num_parallel_calls, flags.multi_gpu) data_dir=flags.data_dir,
batch_size=per_device_batch_size(flags.batch_size, flags.num_gpus),
num_epochs=flags.epochs_between_evals,
)
classifier.train(input_fn=input_fn_train, hooks=train_hooks, classifier.train(input_fn=input_fn_train, hooks=train_hooks,
max_steps=flags.max_train_steps) max_steps=flags.max_train_steps)
...@@ -392,8 +409,12 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -392,8 +409,12 @@ def resnet_main(flags, model_function, input_function, shape=None):
print('Starting to evaluate.') print('Starting to evaluate.')
# Evaluate the model and print results # Evaluate the model and print results
def input_fn_eval(): def input_fn_eval():
return input_function(False, flags.data_dir, flags.batch_size, return input_function(
1, flags.num_parallel_calls, flags.multi_gpu) is_training=False,
data_dir=flags.data_dir,
batch_size=per_device_batch_size(flags.batch_size, flags.num_gpus),
num_epochs=1,
)
# flags.max_train_steps is generally associated with testing and profiling. # flags.max_train_steps is generally associated with testing and profiling.
# As a result it is frequently called with synthetic data, which will # As a result it is frequently called with synthetic data, which will
...@@ -408,32 +429,24 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -408,32 +429,24 @@ def resnet_main(flags, model_function, input_function, shape=None):
if benchmark_logger: if benchmark_logger:
benchmark_logger.log_estimator_evaluation_result(eval_results) benchmark_logger.log_estimator_evaluation_result(eval_results)
if flags.export_dir is not None: if model_helpers.past_stop_threshold(
warn_on_multi_gpu_export(flags.multi_gpu) flags.stop_threshold, eval_results['accuracy']):
break
if flags.export_dir is not None:
# Exports a saved model for the given classifier. # Exports a saved model for the given classifier.
input_receiver_fn = export.build_tensor_serving_input_receiver_fn( input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
shape, batch_size=flags.batch_size) shape, batch_size=flags.batch_size)
classifier.export_savedmodel(flags.export_dir, input_receiver_fn) classifier.export_savedmodel(flags.export_dir, input_receiver_fn)
def warn_on_multi_gpu_export(multi_gpu=False):
"""For the time being, multi-GPU mode does not play nicely with exporting."""
if multi_gpu:
tf.logging.warning(
'You are exporting a SavedModel while in multi-GPU mode. Note that '
'the resulting SavedModel will require the same GPUs be available.'
'If you wish to serve the SavedModel from a different device, '
'try exporting the SavedModel with multi-GPU mode turned off.')
class ResnetArgParser(argparse.ArgumentParser): class ResnetArgParser(argparse.ArgumentParser):
"""Arguments for configuring and running a Resnet Model.""" """Arguments for configuring and running a Resnet Model."""
def __init__(self, resnet_size_choices=None): def __init__(self, resnet_size_choices=None):
super(ResnetArgParser, self).__init__(parents=[ super(ResnetArgParser, self).__init__(parents=[
parsers.BaseParser(), parsers.BaseParser(multi_gpu=False),
parsers.PerformanceParser(), parsers.PerformanceParser(num_parallel_calls=False),
parsers.ImageModelParser(), parsers.ImageModelParser(),
parsers.ExportParser(), parsers.ExportParser(),
parsers.BenchmarkParser(), parsers.BenchmarkParser(),
...@@ -451,3 +464,12 @@ class ResnetArgParser(argparse.ArgumentParser): ...@@ -451,3 +464,12 @@ class ResnetArgParser(argparse.ArgumentParser):
help='[default: %(default)s] The size of the ResNet model to use.', help='[default: %(default)s] The size of the ResNet model to use.',
metavar='<RS>' if resnet_size_choices is None else None metavar='<RS>' if resnet_size_choices is None else None
) )
def parse_args(self, args=None, namespace=None):
  """Parse arguments, then resolve the coupling between dtype and loss_scale.

  Args:
    args: list of argument strings, or None to parse sys.argv.
    namespace: optional namespace object to populate.

  Returns:
    The parsed namespace. `dtype` is converted from its string form and
    `loss_scale` receives a dtype-dependent default when not set explicitly
    (see parsers.parse_dtype_info).
  """
  parsed = super(ResnetArgParser, self).parse_args(
      args=args, namespace=namespace)
  # handle coupling between dtype and loss_scale
  parsers.parse_dtype_info(parsed)
  return parsed
...@@ -58,9 +58,37 @@ from __future__ import absolute_import ...@@ -58,9 +58,37 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import tensorflow as tf
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
"fp16": (tf.float16, 128),
"fp32": (tf.float32, 1),
}
def parse_dtype_info(flags):
  """Convert dtype string to tf dtype, and set loss_scale default as needed.

  Args:
    flags: namespace object returned by arg parser.

  Raises:
    ValueError: If an invalid dtype is provided.
  """
  # If the flag already holds a tf dtype, a previous call converted it.
  if flags.dtype in (dtype for dtype, _ in DTYPE_MAP.values()):
    return  # Make function idempotent

  if flags.dtype not in DTYPE_MAP:
    raise ValueError("Invalid dtype: {}".format(flags.dtype))

  flags.dtype, default_loss_scale = DTYPE_MAP[flags.dtype]
  # An explicit --loss_scale wins; otherwise fall back to the dtype default.
  flags.loss_scale = flags.loss_scale or default_loss_scale
class BaseParser(argparse.ArgumentParser): class BaseParser(argparse.ArgumentParser):
"""Parser to contain flags which will be nearly universal across models. """Parser to contain flags which will be nearly universal across models.
...@@ -71,14 +99,18 @@ class BaseParser(argparse.ArgumentParser): ...@@ -71,14 +99,18 @@ class BaseParser(argparse.ArgumentParser):
model_dir: Create a flag for specifying the model file directory. model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs. train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing. epochs_between_evals: Create a flag to specify the frequency of testing.
batch_size: Create a flag to specify the batch size. stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training.
batch_size: Create a flag to specify the global batch size.
multi_gpu: Create a flag to allow the use of all available GPUs. multi_gpu: Create a flag to allow the use of all available GPUs.
num_gpu: Create a flag to specify the number of GPUs used.
hooks: Create a flag to specify hooks for logging. hooks: Create a flag to specify hooks for logging.
""" """
def __init__(self, add_help=False, data_dir=True, model_dir=True, def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True, batch_size=True, train_epochs=True, epochs_between_evals=True,
multi_gpu=True, hooks=True): stop_threshold=True, batch_size=True,
multi_gpu=False, num_gpu=True, hooks=True):
super(BaseParser, self).__init__(add_help=add_help) super(BaseParser, self).__init__(add_help=add_help)
if data_dir: if data_dir:
...@@ -111,19 +143,42 @@ class BaseParser(argparse.ArgumentParser): ...@@ -111,19 +143,42 @@ class BaseParser(argparse.ArgumentParser):
metavar="<EBE>" metavar="<EBE>"
) )
if stop_threshold:
self.add_argument(
"--stop_threshold", "-st", type=float, default=None,
help="[default: %(default)s] If passed, training will stop at "
"the earlier of train_epochs and when the evaluation metric is "
"greater than or equal to stop_threshold.",
metavar="<ST>"
)
if batch_size: if batch_size:
self.add_argument( self.add_argument(
"--batch_size", "-bs", type=int, default=32, "--batch_size", "-bs", type=int, default=32,
help="[default: %(default)s] Batch size for training and evaluation.", help="[default: %(default)s] Global batch size for training and "
"evaluation.",
metavar="<BS>" metavar="<BS>"
) )
assert not (multi_gpu and num_gpu)
if multi_gpu: if multi_gpu:
self.add_argument( self.add_argument(
"--multi_gpu", action="store_true", "--multi_gpu", action="store_true",
help="If set, run across all available GPUs." help="If set, run across all available GPUs."
) )
if num_gpu:
self.add_argument(
"--num_gpus", "-ng",
type=int,
default=1 if tf.test.is_built_with_cuda() else 0,
help="[default: %(default)s] How many GPUs to use with the "
"DistributionStrategies API. The default is 1 if TensorFlow was"
"built with CUDA, and 0 otherwise.",
metavar="<NG>"
)
if hooks: if hooks:
self.add_argument( self.add_argument(
"--hooks", "-hk", nargs="+", default=["LoggingTensorHook"], "--hooks", "-hk", nargs="+", default=["LoggingTensorHook"],
...@@ -148,7 +203,8 @@ class PerformanceParser(argparse.ArgumentParser): ...@@ -148,7 +203,8 @@ class PerformanceParser(argparse.ArgumentParser):
""" """
def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True, def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True, max_train_steps=True): intra_op=True, use_synthetic_data=True, max_train_steps=True,
dtype=True):
super(PerformanceParser, self).__init__(add_help=add_help) super(PerformanceParser, self).__init__(add_help=add_help)
if num_parallel_calls: if num_parallel_calls:
...@@ -201,6 +257,31 @@ class PerformanceParser(argparse.ArgumentParser): ...@@ -201,6 +257,31 @@ class PerformanceParser(argparse.ArgumentParser):
metavar="<MTS>" metavar="<MTS>"
) )
if dtype:
self.add_argument(
"--dtype", "-dt",
default="fp32",
choices=list(DTYPE_MAP.keys()),
help="[default: %(default)s] {%(choices)s} The TensorFlow datatype "
"used for calculations. Variables may be cast to a higher"
"precision on a case-by-case basis for numerical stability.",
metavar="<DT>"
)
self.add_argument(
"--loss_scale", "-ls",
type=int,
help="[default: %(default)s] The amount to scale the loss by when "
"the model is run. Before gradients are computed, the loss is "
"multiplied by the loss scale, making all gradients loss_scale "
"times larger. To adjust for this, gradients are divided by the "
"loss scale before being applied to variables. This is "
"mathematically equivalent to training without a loss scale, "
"but the loss scale helps avoid some intermediate gradients "
"from underflowing to zero. If not provided the default for "
"fp16 is 128 and 1 for all other dtypes.",
)
class ImageModelParser(argparse.ArgumentParser): class ImageModelParser(argparse.ArgumentParser):
"""Default parser for specification image specific behavior. """Default parser for specification image specific behavior.
...@@ -291,3 +372,15 @@ class BenchmarkParser(argparse.ArgumentParser): ...@@ -291,3 +372,15 @@ class BenchmarkParser(argparse.ArgumentParser):
" benchmark metric information will be uploaded.", " benchmark metric information will be uploaded.",
metavar="<BMT>" metavar="<BMT>"
) )
class EagerParser(BaseParser):
  """Remove options not relevant for Eager from the BaseParser."""

  def __init__(self, add_help=False, data_dir=True, model_dir=True,
               train_epochs=True, batch_size=True):
    # Eager training loops drive evaluation, stopping, device placement,
    # and logging themselves, so those BaseParser flags are suppressed.
    suppressed = dict(epochs_between_evals=False, stop_threshold=False,
                      multi_gpu=False, hooks=False)
    super(EagerParser, self).__init__(
        add_help=add_help, data_dir=data_dir, model_dir=model_dir,
        train_epochs=train_epochs, batch_size=batch_size, **suppressed)
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import argparse import argparse
import unittest import unittest
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
...@@ -25,7 +26,7 @@ class TestParser(argparse.ArgumentParser): ...@@ -25,7 +26,7 @@ class TestParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(TestParser, self).__init__(parents=[ super(TestParser, self).__init__(parents=[
parsers.BaseParser(), parsers.BaseParser(multi_gpu=True, num_gpu=False),
parsers.PerformanceParser(num_parallel_calls=True, inter_op=True, parsers.PerformanceParser(num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True), intra_op=True, use_synthetic_data=True),
parsers.ImageModelParser(data_format=True), parsers.ImageModelParser(data_format=True),
...@@ -83,6 +84,24 @@ class BaseTester(unittest.TestCase): ...@@ -83,6 +84,24 @@ class BaseTester(unittest.TestCase):
assert namespace.multi_gpu assert namespace.multi_gpu
assert namespace.use_synthetic_data assert namespace.use_synthetic_data
def test_parse_dtype_info(self):
  """dtype strings map to tf dtypes with matching loss_scale defaults."""
  parser = TestParser()
  cases = [("fp16", tf.float16, 128), ("fp32", tf.float32, 1)]

  for dtype_str, tf_dtype, default_loss_scale in cases:
    # The loss scale default is derived from the dtype.
    args = parser.parse_args(["--dtype", dtype_str])
    parsers.parse_dtype_info(args)
    assert args.dtype == tf_dtype
    assert args.loss_scale == default_loss_scale

    # An explicit loss scale overrides the dtype default.
    args = parser.parse_args(["--dtype", dtype_str, "--loss_scale", "5"])
    parsers.parse_dtype_info(args)
    assert args.loss_scale == 5

  # Unknown dtypes are rejected by argparse (choices=...), which exits.
  with self.assertRaises(SystemExit):
    parser.parse_args(["--dtype", "int8"])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellaneous functions that can be called by models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numbers
import tensorflow as tf
def past_stop_threshold(stop_threshold, eval_metric):
  """Return a boolean representing whether a model should be stopped.

  Args:
    stop_threshold: float, the threshold above which a model should stop
      training.
    eval_metric: float, the current value of the relevant metric to check.

  Returns:
    True if training should stop, False otherwise.

  Raises:
    ValueError: if either stop_threshold or eval_metric is not a number
  """
  # No threshold configured means never stop early.
  if stop_threshold is None:
    return False

  if not isinstance(stop_threshold, numbers.Number):
    raise ValueError("Threshold for checking stop conditions must be a number.")
  if not isinstance(eval_metric, numbers.Number):
    raise ValueError("Eval metric being checked against stop conditions "
                     "must be a number.")

  if eval_metric < stop_threshold:
    return False

  tf.logging.info(
      "Stop threshold of {} was passed with metric value {}.".format(
          stop_threshold, eval_metric))
  return True
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Tests for Model Helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.misc import model_helpers
class PastStopThresholdTest(tf.test.TestCase):
  """Tests for past_stop_threshold."""

  def test_past_stop_threshold(self):
    """Tests for normal operating conditions."""
    # Metric at or above the threshold should stop training.
    for threshold, metric in [(0.54, 1), (54, 100), (-0.54, 0), (0, 0),
                              (0.54, 0.54)]:
      self.assertTrue(model_helpers.past_stop_threshold(threshold, metric))
    # Metric below the threshold should not.
    for threshold, metric in [(0.54, 0.1), (-0.54, -1.5)]:
      self.assertFalse(model_helpers.past_stop_threshold(threshold, metric))

  def test_past_stop_threshold_none_false(self):
    """Tests that check None returns false."""
    for metric in [-1.5, None, 1.5]:
      self.assertFalse(model_helpers.past_stop_threshold(None, metric))

    # Zero should be okay, though.
    self.assertTrue(model_helpers.past_stop_threshold(0, 1.5))

  def test_past_stop_threshold_not_number(self):
    """Tests for error conditions."""
    # Any non-numeric threshold or metric must raise ValueError.
    bad_pairs = [("str", 1), ("str", tf.constant(5)), ("str", "another"),
                 (0, None), (0.7, "str"), (tf.constant(4), None)]
    for threshold, metric in bad_pairs:
      with self.assertRaises(ValueError):
        model_helpers.past_stop_threshold(threshold, metric)
# Run the tests under TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()
...@@ -170,7 +170,7 @@ class BaseTest(tf.test.TestCase): ...@@ -170,7 +170,7 @@ class BaseTest(tf.test.TestCase):
# Serialize graph for comparison. # Serialize graph for comparison.
graph_bytes = graph.as_graph_def().SerializeToString() graph_bytes = graph.as_graph_def().SerializeToString()
expected_file = os.path.join(data_dir, "expected_graph") expected_file = os.path.join(data_dir, "expected_graph")
with open(expected_file, "wb") as f: with tf.gfile.Open(expected_file, "wb") as f:
f.write(graph_bytes) f.write(graph_bytes)
with graph.as_default(): with graph.as_default():
...@@ -191,10 +191,10 @@ class BaseTest(tf.test.TestCase): ...@@ -191,10 +191,10 @@ class BaseTest(tf.test.TestCase):
if correctness_function is not None: if correctness_function is not None:
results = correctness_function(*eval_results) results = correctness_function(*eval_results)
with open(os.path.join(data_dir, "results.json"), "wt") as f: with tf.gfile.Open(os.path.join(data_dir, "results.json"), "w") as f:
json.dump(results, f) json.dump(results, f)
with open(os.path.join(data_dir, "tf_version.json"), "wt") as f: with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "w") as f:
json.dump([tf.VERSION, tf.GIT_VERSION], f) json.dump([tf.VERSION, tf.GIT_VERSION], f)
def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function): def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
...@@ -216,7 +216,7 @@ class BaseTest(tf.test.TestCase): ...@@ -216,7 +216,7 @@ class BaseTest(tf.test.TestCase):
# Serialize graph for comparison. # Serialize graph for comparison.
graph_bytes = graph.as_graph_def().SerializeToString() graph_bytes = graph.as_graph_def().SerializeToString()
expected_file = os.path.join(data_dir, "expected_graph") expected_file = os.path.join(data_dir, "expected_graph")
with open(expected_file, "rb") as f: with tf.gfile.Open(expected_file, "rb") as f:
expected_graph_bytes = f.read() expected_graph_bytes = f.read()
# The serialization is non-deterministic byte-for-byte. Instead there is # The serialization is non-deterministic byte-for-byte. Instead there is
# a utility which evaluates the semantics of the two graphs to test for # a utility which evaluates the semantics of the two graphs to test for
...@@ -231,7 +231,7 @@ class BaseTest(tf.test.TestCase): ...@@ -231,7 +231,7 @@ class BaseTest(tf.test.TestCase):
init = tf.global_variables_initializer() init = tf.global_variables_initializer()
saver = tf.train.Saver() saver = tf.train.Saver()
with open(os.path.join(data_dir, "tf_version.json"), "rt") as f: with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "r") as f:
tf_version_reference, tf_git_version_reference = json.load(f) # pylint: disable=unpacking-non-sequence tf_version_reference, tf_git_version_reference = json.load(f) # pylint: disable=unpacking-non-sequence
tf_version_comparison = "" tf_version_comparison = ""
...@@ -262,7 +262,7 @@ class BaseTest(tf.test.TestCase): ...@@ -262,7 +262,7 @@ class BaseTest(tf.test.TestCase):
eval_results = [op.eval() for op in ops_to_eval] eval_results = [op.eval() for op in ops_to_eval]
if correctness_function is not None: if correctness_function is not None:
results = correctness_function(*eval_results) results = correctness_function(*eval_results)
with open(os.path.join(data_dir, "results.json"), "rt") as f: with tf.gfile.Open(os.path.join(data_dir, "results.json"), "r") as f:
expected_results = json.load(f) expected_results = json.load(f)
self.assertAllClose(results, expected_results) self.assertAllClose(results, expected_results)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment