Unverified commit fbb27cf3 authored by Taylor Robie, committed by GitHub

Add fp16 support to official ResNet. (#3687)

* Add fp16 support to resnet.

* address PR comments

* add dtype checking to model definition

* delint

* more PR comments

* few more tweaks

* update resnet checkpoints
parent 6741cfce
@@ -55,9 +55,7 @@ You can download 190 MB pre-trained versions of ResNet-50 achieving 76.3% and 75
 Other versions and formats:
-* [ResNet-v2-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnetv2_imagenet_checkpoint.tar.gz)
-* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnetv2_imagenet_savedmodel.tar.gz)
-* [ResNet-v2-ImageNet Frozen Graph](http://download.tensorflow.org/models/official/resnetv2_imagenet_frozen_graph.pb)
-* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnetv1_imagenet_checkpoint.tar.gz)
-* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnetv1_imagenet_savedmodel.tar.gz)
-* [ResNet-v1-ImageNet Frozen Graph](http://download.tensorflow.org/models/official/resnetv1_imagenet_frozen_graph.pb)
+* [ResNet-v2-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v2_imagenet_checkpoint.tar.gz)
+* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
+* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
+* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)
@@ -145,7 +145,8 @@ class Cifar10Model(resnet_model.Model):
   """Model class with appropriate defaults for CIFAR-10 data."""

   def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
-               version=resnet_model.DEFAULT_VERSION):
+               version=resnet_model.DEFAULT_VERSION,
+               dtype=resnet_model.DEFAULT_DTYPE):
     """These are the parameters that work for CIFAR-10 data.

     Args:
@@ -156,6 +157,7 @@ class Cifar10Model(resnet_model.Model):
         enables users to extend the same model to their own datasets.
       version: Integer representing which version of the ResNet network to use.
         See README for details. Valid values: [1, 2]
+      dtype: The TensorFlow dtype to use for calculations.

     Raises:
       ValueError: if invalid resnet_size is chosen
@@ -180,7 +182,9 @@ class Cifar10Model(resnet_model.Model):
         block_strides=[1, 2, 2],
         final_size=64,
         version=version,
-        data_format=data_format)
+        data_format=data_format,
+        dtype=dtype
+    )


 def cifar10_model_fn(features, labels, mode, params):
@@ -204,15 +208,22 @@ def cifar10_model_fn(features, labels, mode, params):
   def loss_filter_fn(_):
     return True

-  return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model,
-                                         resnet_size=params['resnet_size'],
-                                         weight_decay=weight_decay,
-                                         learning_rate_fn=learning_rate_fn,
-                                         momentum=0.9,
-                                         data_format=params['data_format'],
-                                         version=params['version'],
-                                         loss_filter_fn=loss_filter_fn,
-                                         multi_gpu=params['multi_gpu'])
+  return resnet_run_loop.resnet_model_fn(
+      features=features,
+      labels=labels,
+      mode=mode,
+      model_class=Cifar10Model,
+      resnet_size=params['resnet_size'],
+      weight_decay=weight_decay,
+      learning_rate_fn=learning_rate_fn,
+      momentum=0.9,
+      data_format=params['data_format'],
+      version=params['version'],
+      loss_scale=params['loss_scale'],
+      loss_filter_fn=loss_filter_fn,
+      multi_gpu=params['multi_gpu'],
+      dtype=params['dtype']
+  )


 def main(argv):
......
@@ -71,38 +71,61 @@ class BaseTest(tf.test.TestCase):
       for pixel in row:
         self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)

+  def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False):
+    with tf.Graph().as_default() as g:
+      input_fn = cifar10_main.get_synth_input_fn()
+      dataset = input_fn(True, '', _BATCH_SIZE)
+      iterator = dataset.make_one_shot_iterator()
+      features, labels = iterator.get_next()
+      spec = cifar10_main.cifar10_model_fn(
+          features, labels, mode, {
+              'dtype': dtype,
+              'resnet_size': 32,
+              'data_format': 'channels_last',
+              'batch_size': _BATCH_SIZE,
+              'version': version,
+              'loss_scale': 128 if dtype == tf.float16 else 1,
+              'multi_gpu': multi_gpu
+          })
+
+      predictions = spec.predictions
+      self.assertAllEqual(predictions['probabilities'].shape,
+                          (_BATCH_SIZE, 10))
+      self.assertEqual(predictions['probabilities'].dtype, tf.float32)
+      self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
+      self.assertEqual(predictions['classes'].dtype, tf.int64)
+
+      if mode != tf.estimator.ModeKeys.PREDICT:
+        loss = spec.loss
+        self.assertAllEqual(loss.shape, ())
+        self.assertEqual(loss.dtype, tf.float32)
+
+      if mode == tf.estimator.ModeKeys.EVAL:
+        eval_metric_ops = spec.eval_metric_ops
+        self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
+        self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
+        self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
+        self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
+
+      for v in tf.trainable_variables():
+        self.assertEqual(v.dtype.base_dtype, tf.float32)
+
+      tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
+                          'block_layer3:0', 'final_reduce_mean:0',
+                          'final_dense:0')
+
+      for tensor_name in tensors_to_check:
+        tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
+        self.assertEqual(tensor.dtype, dtype,
+                         'Tensor {} has dtype {}, while dtype {} was '
+                         'expected'.format(tensor, tensor.dtype,
+                                           dtype))
+
   def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
-    input_fn = cifar10_main.get_synth_input_fn()
-    dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
-    features, labels = iterator.get_next()
-    spec = cifar10_main.cifar10_model_fn(
-        features, labels, mode, {
-            'resnet_size': 32,
-            'data_format': 'channels_last',
-            'batch_size': _BATCH_SIZE,
-            'version': version,
-            'multi_gpu': multi_gpu
-        })
-
-    predictions = spec.predictions
-    self.assertAllEqual(predictions['probabilities'].shape,
-                        (_BATCH_SIZE, 10))
-    self.assertEqual(predictions['probabilities'].dtype, tf.float32)
-    self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
-    self.assertEqual(predictions['classes'].dtype, tf.int64)
-
-    if mode != tf.estimator.ModeKeys.PREDICT:
-      loss = spec.loss
-      self.assertAllEqual(loss.shape, ())
-      self.assertEqual(loss.dtype, tf.float32)
-
-    if mode == tf.estimator.ModeKeys.EVAL:
-      eval_metric_ops = spec.eval_metric_ops
-      self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
-      self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
-      self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
-      self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
+    self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
+                                  multi_gpu=multi_gpu)
+    self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
+                                  multi_gpu=multi_gpu)

   def test_cifar10_model_fn_train_mode_v1(self):
     self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
@@ -130,19 +153,22 @@ class BaseTest(tf.test.TestCase):
   def test_cifar10_model_fn_predict_mode_v2(self):
     self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)

-  def test_cifar10model_shape(self):
+  def _test_cifar10model_shape(self, version):
     batch_size = 135
     num_classes = 246

-    for version in (1, 2):
-      model = cifar10_main.Cifar10Model(
-          32, data_format='channels_last', num_classes=num_classes,
-          version=version)
-      fake_input = tf.random_uniform(
-          [batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
-      output = model(fake_input, training=True)
+    model = cifar10_main.Cifar10Model(32, data_format='channels_last',
+                                      num_classes=num_classes, version=version)
+    fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
+    output = model(fake_input, training=True)

-      self.assertAllEqual(output.shape, (batch_size, num_classes))
+    self.assertAllEqual(output.shape, (batch_size, num_classes))
+
+  def test_cifar10model_shape_v1(self):
+    self._test_cifar10model_shape(version=1)
+
+  def test_cifar10model_shape_v2(self):
+    self._test_cifar10model_shape(version=2)

   def test_cifar10_end_to_end_synthetic_v1(self):
     integration.run_synthetic(
......
@@ -203,7 +203,8 @@ class ImagenetModel(resnet_model.Model):
   """Model class with appropriate defaults for Imagenet data."""

   def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
-               version=resnet_model.DEFAULT_VERSION):
+               version=resnet_model.DEFAULT_VERSION,
+               dtype=resnet_model.DEFAULT_DTYPE):
     """These are the parameters that work for Imagenet data.

     Args:
@@ -214,6 +215,7 @@ class ImagenetModel(resnet_model.Model):
         enables users to extend the same model to their own datasets.
       version: Integer representing which version of the ResNet network to use.
         See README for details. Valid values: [1, 2]
+      dtype: The TensorFlow dtype to use for calculations.
     """

     # For bigger models, we want to use "bottleneck" layers
@@ -239,7 +241,9 @@ class ImagenetModel(resnet_model.Model):
         block_strides=[1, 2, 2, 2],
         final_size=final_size,
         version=version,
-        data_format=data_format)
+        data_format=data_format,
+        dtype=dtype
+    )


 def _get_block_sizes(resnet_size):
@@ -283,15 +287,22 @@ def imagenet_model_fn(features, labels, mode, params):
       num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
       decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])

-  return resnet_run_loop.resnet_model_fn(features, labels, mode, ImagenetModel,
-                                         resnet_size=params['resnet_size'],
-                                         weight_decay=1e-4,
-                                         learning_rate_fn=learning_rate_fn,
-                                         momentum=0.9,
-                                         data_format=params['data_format'],
-                                         version=params['version'],
-                                         loss_filter_fn=None,
-                                         multi_gpu=params['multi_gpu'])
+  return resnet_run_loop.resnet_model_fn(
+      features=features,
+      labels=labels,
+      mode=mode,
+      model_class=ImagenetModel,
+      resnet_size=params['resnet_size'],
+      weight_decay=1e-4,
+      learning_rate_fn=learning_rate_fn,
+      momentum=0.9,
+      data_format=params['data_format'],
+      version=params['version'],
+      loss_scale=params['loss_scale'],
+      loss_filter_fn=None,
+      multi_gpu=params['multi_gpu'],
+      dtype=params['dtype']
+  )


 def main(argv):
......
@@ -36,7 +36,7 @@ class BaseTest(tf.test.TestCase):
     super(BaseTest, self).tearDown()
     tf.gfile.DeleteRecursively(self.get_temp_dir())

-  def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
+  def _tensor_shapes_helper(self, resnet_size, version, dtype, with_gpu):
     """Checks the tensor shapes after each phase of the ResNet model."""

     def reshape(shape):
       """Returns the expected dimensions depending on if a GPU is being used."""
@@ -50,22 +50,24 @@ class BaseTest(tf.test.TestCase):
     graph = tf.Graph()

     with graph.as_default(), self.test_session(
-        use_gpu=with_gpu, force_gpu=with_gpu):
+        graph=graph, use_gpu=with_gpu, force_gpu=with_gpu):
       model = imagenet_main.ImagenetModel(
-          resnet_size,
+          resnet_size=resnet_size,
           data_format='channels_first' if with_gpu else 'channels_last',
-          version=version)
+          version=version,
+          dtype=dtype
+      )
       inputs = tf.random_uniform([1, 224, 224, 3])
       output = model(inputs, training=True)

-      initial_conv = graph.get_tensor_by_name('initial_conv:0')
-      max_pool = graph.get_tensor_by_name('initial_max_pool:0')
-      block_layer1 = graph.get_tensor_by_name('block_layer1:0')
-      block_layer2 = graph.get_tensor_by_name('block_layer2:0')
-      block_layer3 = graph.get_tensor_by_name('block_layer3:0')
-      block_layer4 = graph.get_tensor_by_name('block_layer4:0')
-      reduce_mean = graph.get_tensor_by_name('final_reduce_mean:0')
-      dense = graph.get_tensor_by_name('final_dense:0')
+      initial_conv = graph.get_tensor_by_name('resnet_model/initial_conv:0')
+      max_pool = graph.get_tensor_by_name('resnet_model/initial_max_pool:0')
+      block_layer1 = graph.get_tensor_by_name('resnet_model/block_layer1:0')
+      block_layer2 = graph.get_tensor_by_name('resnet_model/block_layer2:0')
+      block_layer3 = graph.get_tensor_by_name('resnet_model/block_layer3:0')
+      block_layer4 = graph.get_tensor_by_name('resnet_model/block_layer4:0')
+      reduce_mean = graph.get_tensor_by_name('resnet_model/final_reduce_mean:0')
+      dense = graph.get_tensor_by_name('resnet_model/final_dense:0')

       self.assertAllEqual(initial_conv.shape, reshape((1, 64, 112, 112)))
       self.assertAllEqual(max_pool.shape, reshape((1, 64, 56, 56)))
@@ -88,6 +90,12 @@ class BaseTest(tf.test.TestCase):
       self.assertAllEqual(dense.shape, (1, _LABEL_CLASSES))
       self.assertAllEqual(output.shape, (1, _LABEL_CLASSES))

+  def tensor_shapes_helper(self, resnet_size, version, with_gpu=False):
+    self._tensor_shapes_helper(resnet_size=resnet_size, version=version,
+                               dtype=tf.float32, with_gpu=with_gpu)
+    self._tensor_shapes_helper(resnet_size=resnet_size, version=version,
+                               dtype=tf.float16, with_gpu=with_gpu)
+
   def test_tensor_shapes_resnet_18_v1(self):
     self.tensor_shapes_helper(18, version=1)
@@ -172,41 +180,62 @@ class BaseTest(tf.test.TestCase):
   def test_tensor_shapes_resnet_200_with_gpu_v2(self):
     self.tensor_shapes_helper(200, version=2, with_gpu=True)

-  def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
+  def _resnet_model_fn_helper(self, mode, version, dtype, multi_gpu):
     """Tests that the EstimatorSpec is given the appropriate arguments."""
-    tf.train.create_global_step()
-
-    input_fn = imagenet_main.get_synth_input_fn()
-    dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
-    features, labels = iterator.get_next()
-    spec = imagenet_main.imagenet_model_fn(
-        features, labels, mode, {
-            'resnet_size': 50,
-            'data_format': 'channels_last',
-            'batch_size': _BATCH_SIZE,
-            'version': version,
-            'multi_gpu': multi_gpu,
-        })
-
-    predictions = spec.predictions
-    self.assertAllEqual(predictions['probabilities'].shape,
-                        (_BATCH_SIZE, _LABEL_CLASSES))
-    self.assertEqual(predictions['probabilities'].dtype, tf.float32)
-    self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
-    self.assertEqual(predictions['classes'].dtype, tf.int64)
-
-    if mode != tf.estimator.ModeKeys.PREDICT:
-      loss = spec.loss
-      self.assertAllEqual(loss.shape, ())
-      self.assertEqual(loss.dtype, tf.float32)
-
-    if mode == tf.estimator.ModeKeys.EVAL:
-      eval_metric_ops = spec.eval_metric_ops
-      self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
-      self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
-      self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
-      self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
+    with tf.Graph().as_default() as g:
+      tf.train.create_global_step()
+
+      input_fn = imagenet_main.get_synth_input_fn()
+      dataset = input_fn(True, '', _BATCH_SIZE)
+      iterator = dataset.make_one_shot_iterator()
+      features, labels = iterator.get_next()
+      spec = imagenet_main.imagenet_model_fn(
+          features, labels, mode, {
+              'dtype': dtype,
+              'resnet_size': 50,
+              'data_format': 'channels_last',
+              'batch_size': _BATCH_SIZE,
+              'version': version,
+              'loss_scale': 128 if dtype == tf.float16 else 1,
+              'multi_gpu': multi_gpu,
+          })
+
+      predictions = spec.predictions
+      self.assertAllEqual(predictions['probabilities'].shape,
+                          (_BATCH_SIZE, _LABEL_CLASSES))
+      self.assertEqual(predictions['probabilities'].dtype, tf.float32)
+      self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
+      self.assertEqual(predictions['classes'].dtype, tf.int64)
+
+      if mode != tf.estimator.ModeKeys.PREDICT:
+        loss = spec.loss
+        self.assertAllEqual(loss.shape, ())
+        self.assertEqual(loss.dtype, tf.float32)
+
+      if mode == tf.estimator.ModeKeys.EVAL:
+        eval_metric_ops = spec.eval_metric_ops
+        self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
+        self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
+        self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
+        self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
+
+      tensors_to_check = ('initial_conv:0', 'initial_max_pool:0',
+                          'block_layer1:0', 'block_layer2:0',
+                          'block_layer3:0', 'block_layer4:0',
+                          'final_reduce_mean:0', 'final_dense:0')
+
+      for tensor_name in tensors_to_check:
+        tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
+        self.assertEqual(tensor.dtype, dtype,
+                         'Tensor {} has dtype {}, while dtype {} was '
+                         'expected'.format(tensor, tensor.dtype,
+                                           dtype))
+
+  def resnet_model_fn_helper(self, mode, version, multi_gpu=False):
+    self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
+                                 multi_gpu=multi_gpu)
+    self._resnet_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
+                                 multi_gpu=multi_gpu)

   def test_resnet_model_fn_train_mode_v1(self):
     self.resnet_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
@@ -234,18 +263,24 @@ class BaseTest(tf.test.TestCase):
   def test_resnet_model_fn_predict_mode_v2(self):
     self.resnet_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)

-  def test_imagenetmodel_shape(self):
+  def _test_imagenetmodel_shape(self, version):
     batch_size = 135
     num_classes = 246

-    for version in (1, 2):
-      model = imagenet_main.ImagenetModel(
-          50, data_format='channels_last', num_classes=num_classes,
-          version=version)
-      fake_input = tf.random_uniform([batch_size, 224, 224, 3])
-      output = model(fake_input, training=True)
+    model = imagenet_main.ImagenetModel(
+        50, data_format='channels_last', num_classes=num_classes,
+        version=version)

-      self.assertAllEqual(output.shape, (batch_size, num_classes))
+    fake_input = tf.random_uniform([batch_size, 224, 224, 3])
+    output = model(fake_input, training=True)
+
+    self.assertAllEqual(output.shape, (batch_size, num_classes))
+
+  def test_imagenetmodel_shape_v1(self):
+    self._test_imagenetmodel_shape(version=1)
+
+  def test_imagenetmodel_shape_v2(self):
+    self._test_imagenetmodel_shape(version=2)

   def test_imagenet_end_to_end_synthetic_v1(self):
     integration.run_synthetic(
......
@@ -36,6 +36,9 @@ import tensorflow as tf
 _BATCH_NORM_DECAY = 0.997
 _BATCH_NORM_EPSILON = 1e-5
 DEFAULT_VERSION = 2
+DEFAULT_DTYPE = tf.float32
+CASTABLE_TYPES = (tf.float16,)
+ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES


 ################################################################################
@@ -351,7 +354,8 @@ class Model(object):
                kernel_size,
                conv_stride, first_pool_size, first_pool_stride,
                second_pool_size, second_pool_stride, block_sizes, block_strides,
-               final_size, version=DEFAULT_VERSION, data_format=None):
+               final_size, version=DEFAULT_VERSION, data_format=None,
+               dtype=DEFAULT_DTYPE):
     """Creates a model for classifying an image.

     Args:
@@ -379,6 +383,8 @@ class Model(object):
         See README for details. Valid values: [1, 2]
       data_format: Input format ('channels_last', 'channels_first', or None).
         If set to None, the format is dependent on whether a GPU is available.
+      dtype: The TensorFlow dtype to use for calculations. If not specified
+        tf.float32 is used.

     Raises:
       ValueError: if invalid version is selected.
@@ -406,6 +412,9 @@ class Model(object):
     else:
       self.block_fn = _building_block_v2

+    if dtype not in ALLOWED_TYPES:
+      raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES))
+
     self.data_format = data_format
     self.num_classes = num_classes
     self.num_filters = num_filters
@@ -418,6 +427,61 @@ class Model(object):
     self.block_sizes = block_sizes
     self.block_strides = block_strides
     self.final_size = final_size
+    self.dtype = dtype
+
+  def _custom_dtype_getter(self, getter, name, shape=None, dtype=DEFAULT_DTYPE,
+                           *args, **kwargs):
+    """Creates variables in fp32, then casts to fp16 if necessary.
+
+    This function is a custom getter. A custom getter is a function with the
+    same signature as tf.get_variable, except it has an additional getter
+    parameter. Custom getters can be passed as the `custom_getter` parameter of
+    tf.variable_scope. Then, tf.get_variable will call the custom getter,
+    instead of directly getting a variable itself. This can be used to change
+    the types of variables that are retrieved with tf.get_variable.
+    The `getter` parameter is the underlying variable getter that would have
+    been called if no custom getter was used. Custom getters typically get a
+    variable with `getter`, then modify it in some way.
+
+    This custom getter will create an fp32 variable. If a low precision
+    (e.g. float16) variable was requested, it will then cast the variable to
+    the requested dtype. The reason we do not directly create variables in low
+    precision dtypes is that applying small gradients to such variables may
+    cause the variable not to change.
+
+    Args:
+      getter: The underlying variable getter, that has the same signature as
+        tf.get_variable and returns a variable.
+      name: The name of the variable to get.
+      shape: The shape of the variable to get.
+      dtype: The dtype of the variable to get. Note that if this is a low
+        precision dtype, the variable will be created as a tf.float32 variable,
+        then cast to the appropriate dtype.
+      *args: Additional arguments to pass unmodified to getter.
+      **kwargs: Additional keyword arguments to pass unmodified to getter.
+
+    Returns:
+      A variable which is cast to fp16 if necessary.
+    """
+    if dtype in CASTABLE_TYPES:
+      var = getter(name, shape, tf.float32, *args, **kwargs)
+      return tf.cast(var, dtype=dtype, name=name + '_cast')
+    else:
+      return getter(name, shape, dtype, *args, **kwargs)
+
+  def _model_variable_scope(self):
+    """Returns a variable scope that the model should be created under.
+
+    If self.dtype is a castable type, model variables will be created in fp32
+    then cast to self.dtype before being used.
+
+    Returns:
+      A variable scope for the model.
+    """
+    return tf.variable_scope('resnet_model',
+                             custom_getter=self._custom_dtype_getter)

   def __call__(self, inputs, training):
     """Add operations to classify a batch of input images.
@@ -431,46 +495,46 @@ class Model(object):
       A logits Tensor with shape [<batch_size>, self.num_classes].
     """

-    if self.data_format == 'channels_first':
-      # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
-      # This provides a large performance boost on GPU. See
-      # https://www.tensorflow.org/performance/performance_guide#data_formats
-      inputs = tf.transpose(inputs, [0, 3, 1, 2])
-
-    inputs = conv2d_fixed_padding(
-        inputs=inputs, filters=self.num_filters, kernel_size=self.kernel_size,
-        strides=self.conv_stride, data_format=self.data_format)
-    inputs = tf.identity(inputs, 'initial_conv')
-
-    if self.first_pool_size:
-      inputs = tf.layers.max_pooling2d(
-          inputs=inputs, pool_size=self.first_pool_size,
-          strides=self.first_pool_stride, padding='SAME',
-          data_format=self.data_format)
-      inputs = tf.identity(inputs, 'initial_max_pool')
-
-    for i, num_blocks in enumerate(self.block_sizes):
-      num_filters = self.num_filters * (2**i)
-      inputs = block_layer(
-          inputs=inputs, filters=num_filters, bottleneck=self.bottleneck,
-          block_fn=self.block_fn, blocks=num_blocks,
-          strides=self.block_strides[i], training=training,
-          name='block_layer{}'.format(i + 1), data_format=self.data_format)
-
-    inputs = batch_norm(inputs, training, self.data_format)
-    inputs = tf.nn.relu(inputs)
-
-    # The current top layer has shape
-    # `batch_size x pool_size x pool_size x final_size`.
-    # ResNet does an Average Pooling layer over pool_size,
-    # but that is the same as doing a reduce_mean. We do a reduce_mean
-    # here because it performs better than AveragePooling2D.
-    axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
-    inputs = tf.reduce_mean(inputs, axes, keepdims=True)
-    inputs = tf.identity(inputs, 'final_reduce_mean')
-
-    inputs = tf.reshape(inputs, [-1, self.final_size])
-    inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
-    inputs = tf.identity(inputs, 'final_dense')
-    return inputs
+    with self._model_variable_scope():
+      if self.data_format == 'channels_first':
+        # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
+        # This provides a large performance boost on GPU. See
+        # https://www.tensorflow.org/performance/performance_guide#data_formats
+        inputs = tf.transpose(inputs, [0, 3, 1, 2])
+
+      inputs = conv2d_fixed_padding(
+          inputs=inputs, filters=self.num_filters, kernel_size=self.kernel_size,
+          strides=self.conv_stride, data_format=self.data_format)
+      inputs = tf.identity(inputs, 'initial_conv')
+
+      if self.first_pool_size:
+        inputs = tf.layers.max_pooling2d(
+            inputs=inputs, pool_size=self.first_pool_size,
+            strides=self.first_pool_stride, padding='SAME',
+            data_format=self.data_format)
+        inputs = tf.identity(inputs, 'initial_max_pool')
+
+      for i, num_blocks in enumerate(self.block_sizes):
+        num_filters = self.num_filters * (2**i)
+        inputs = block_layer(
+            inputs=inputs, filters=num_filters, bottleneck=self.bottleneck,
+            block_fn=self.block_fn, blocks=num_blocks,
+            strides=self.block_strides[i], training=training,
+            name='block_layer{}'.format(i + 1), data_format=self.data_format)
+
+      inputs = batch_norm(inputs, training, self.data_format)
+      inputs = tf.nn.relu(inputs)
+
+      # The current top layer has shape
+      # `batch_size x pool_size x pool_size x final_size`.
+      # ResNet does an Average Pooling layer over pool_size,
+      # but that is the same as doing a reduce_mean. We do a reduce_mean
+      # here because it performs better than AveragePooling2D.
+      axes = [2, 3] if self.data_format == 'channels_first' else [1, 2]
+      inputs = tf.reduce_mean(inputs, axes, keepdims=True)
+      inputs = tf.identity(inputs, 'final_reduce_mean')
+
+      inputs = tf.reshape(inputs, [-1, self.final_size])
+      inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
+      inputs = tf.identity(inputs, 'final_dense')
+      return inputs
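The custom getter above is the heart of the fp16 support: variables are always stored in float32 so that small weight updates are not rounded away, and only the values read inside the model's variable scope are cast to float16. A minimal, self-contained sketch of the same pattern follows; `fp16_getter` and the `demo` scope are illustrative names for this sketch, not identifiers from the commit.

```python
import tensorflow as tf


def fp16_getter(getter, name, shape=None, dtype=tf.float32, *args, **kwargs):
  """Create a float32 master variable and return a float16 view of it."""
  if dtype == tf.float16:
    var = getter(name, shape, tf.float32, *args, **kwargs)  # fp32 storage
    return tf.cast(var, tf.float16, name=name + '_cast')    # fp16 for compute
  return getter(name, shape, dtype, *args, **kwargs)


with tf.variable_scope('demo', custom_getter=fp16_getter):
  w = tf.get_variable('w', shape=[3, 3], dtype=tf.float16)

print(w.dtype)                             # float16: what the model computes with
print(tf.trainable_variables()[0].dtype)   # float32: what the optimizer updates
```

Because the optimizer only ever sees the float32 variables, the update step behaves exactly as it would in ordinary float32 training.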
@@ -170,7 +170,9 @@ def learning_rate_with_decay(

 def resnet_model_fn(features, labels, mode, model_class,
                     resnet_size, weight_decay, learning_rate_fn, momentum,
-                    data_format, version, loss_filter_fn=None, multi_gpu=False):
+                    data_format, version, loss_scale,
+                    loss_filter_fn=None, multi_gpu=False,
+                    dtype=resnet_model.DEFAULT_DTYPE):
   """Shared functionality for different resnet model_fns.

   Initializes the ResnetModel representing the model layers
@@ -196,12 +198,15 @@ def resnet_model_fn(features, labels, mode, model_class,
       If set to None, the format is dependent on whether a GPU is available.
     version: Integer representing which version of the ResNet network to use.
       See README for details. Valid values: [1, 2]
+    loss_scale: The factor to scale the loss for numerical stability. A detailed
+      summary is present in the arg parser help text.
     loss_filter_fn: function that takes a string variable name and returns
       True if the var should be included in loss calculation, and False
       otherwise. If None, batch_normalization variables will be excluded
       from the loss.
     multi_gpu: If True, wrap the optimizer in a TowerOptimizer suitable for
       data-parallel distribution across multiple GPUs.
+    dtype: the TensorFlow dtype to use for calculations.

   Returns:
     EstimatorSpec parameterized according to the input params and the
@@ -211,9 +216,17 @@ def resnet_model_fn(features, labels, mode, model_class,
   # Generate a summary node for the images
   tf.summary.image('images', features, max_outputs=6)

-  model = model_class(resnet_size, data_format, version=version)
+  features = tf.cast(features, dtype)
+
+  model = model_class(resnet_size, data_format, version=version, dtype=dtype)
+
   logits = model(features, mode == tf.estimator.ModeKeys.TRAIN)

+  # This acts as a no-op if the logits are already in fp32 (provided logits are
+  # not a SparseTensor). If dtype is low precision, logits must be cast to
+  # fp32 for numerical stability.
+  logits = tf.cast(logits, tf.float32)
+
   predictions = {
       'classes': tf.argmax(logits, axis=1),
       'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
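The casts above (features into the requested dtype, logits back to float32 before softmax and loss) exist because float16 has a narrow exponent range and only about three decimal digits of precision. A quick illustration of the failure modes being avoided, using numpy (already imported as `np` in the tests); the specific values are just examples:

```python
import numpy as np

print(np.finfo(np.float16).max)             # 65500.0 -> exp() overflows for logits above ~11
print(np.float16(1e-8))                     # 0.0     -> small values underflow to zero
print(np.float16(1.0) + np.float16(1e-4))   # 1.0     -> the 1e-4 update is rounded away
                                            #            (float16 spacing at 1.0 is ~1e-3)
```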
@@ -244,7 +257,8 @@ def resnet_model_fn(features, labels, mode, model_class,
   # Add weight decay to the loss.
   l2_loss = weight_decay * tf.add_n(
-      [tf.nn.l2_loss(v) for v in tf.trainable_variables()
+      # loss is computed using fp32 for numerical stability.
+      [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()
        if loss_filter_fn(v.name)])
   tf.summary.scalar('l2_loss', l2_loss)
   loss = cross_entropy + l2_loss
@@ -266,8 +280,22 @@ def resnet_model_fn(features, labels, mode, model_class,
     if multi_gpu:
       optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

+    if loss_scale != 1:
+      # When computing fp16 gradients, often intermediate tensor values are
+      # so small, they underflow to 0. To avoid this, we multiply the loss by
+      # loss_scale to make these tensor values loss_scale times bigger.
+      scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale)
+
+      # Once the gradient computation is complete we can scale the gradients
+      # back to the correct scale before passing them to the optimizer.
+      unscaled_grad_vars = [(grad / loss_scale, var)
+                            for grad, var in scaled_grad_vars]
+      minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step)
+    else:
+      minimize_op = optimizer.minimize(loss, global_step)
+
     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
-    train_op = tf.group(optimizer.minimize(loss, global_step), update_ops)
+    train_op = tf.group(minimize_op, update_ops)
   else:
     train_op = None
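Outside the Estimator plumbing, the static loss scaling added above reduces to a few lines around any TensorFlow 1.x optimizer. A stripped-down sketch follows; the variable, loss, and optimizer here are stand-ins for illustration, not code from this commit:

```python
import tensorflow as tf

loss_scale = 128  # the default chosen for fp16 in this change
w = tf.get_variable('w', shape=[], initializer=tf.constant_initializer(1.0))
loss = tf.square(w - 3.0)  # stand-in for the model's fp32 loss

optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)

# Scale the loss up so small fp16 gradients do not underflow to zero ...
scaled_grads_and_vars = optimizer.compute_gradients(loss * loss_scale)

# ... then scale the gradients back down before applying the update.
grads_and_vars = [(grad / loss_scale, var)
                  for grad, var in scaled_grads_and_vars]
train_op = optimizer.apply_gradients(grads_and_vars)
```

Scaling the loss multiplies every backpropagated gradient by the same constant, so dividing at the end undoes it exactly (up to rounding), which is why a fixed power of two such as 128 is mathematically equivalent to unscaled training.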
@@ -365,11 +393,13 @@ def resnet_main(flags, model_function, input_function, shape=None):
           'batch_size': flags.batch_size,
           'multi_gpu': flags.multi_gpu,
           'version': flags.version,
+          'loss_scale': flags.loss_scale,
+          'dtype': flags.dtype
       })

   if flags.benchmark_log_dir is not None:
     benchmark_logger = logger.BenchmarkLogger(flags.benchmark_log_dir)
-    benchmark_logger.log_run_info("resnet")
+    benchmark_logger.log_run_info('resnet')
   else:
     benchmark_logger = None
@@ -451,3 +481,12 @@ class ResnetArgParser(argparse.ArgumentParser):
         help='[default: %(default)s] The size of the ResNet model to use.',
         metavar='<RS>' if resnet_size_choices is None else None
     )
+
+  def parse_args(self, args=None, namespace=None):
+    args = super(ResnetArgParser, self).parse_args(
+        args=args, namespace=namespace)
+
+    # handle coupling between dtype and loss_scale
+    parsers.parse_dtype_info(args)
+
+    return args
@@ -58,9 +58,37 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import argparse

+import tensorflow as tf
+
+# Map string to (TensorFlow dtype, default loss scale)
+DTYPE_MAP = {
+    "fp16": (tf.float16, 128),
+    "fp32": (tf.float32, 1),
+}
+
+
+def parse_dtype_info(flags):
+  """Convert dtype string to tf dtype, and set loss_scale default as needed.
+
+  Args:
+    flags: namespace object returned by arg parser.
+
+  Raises:
+    ValueError: If an invalid dtype is provided.
+  """
+  if flags.dtype in (i[0] for i in DTYPE_MAP.values()):
+    return  # Make function idempotent
+
+  try:
+    flags.dtype, default_loss_scale = DTYPE_MAP[flags.dtype]
+  except KeyError:
+    raise ValueError("Invalid dtype: {}".format(flags.dtype))
+
+  flags.loss_scale = flags.loss_scale or default_loss_scale
+
+
 class BaseParser(argparse.ArgumentParser):
   """Parser to contain flags which will be nearly universal across models.
@@ -148,7 +176,8 @@ class PerformanceParser(argparse.ArgumentParser):
   """

   def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
-               intra_op=True, use_synthetic_data=True, max_train_steps=True):
+               intra_op=True, use_synthetic_data=True, max_train_steps=True,
+               dtype=True):
     super(PerformanceParser, self).__init__(add_help=add_help)

     if num_parallel_calls:
@@ -201,6 +230,31 @@ class PerformanceParser(argparse.ArgumentParser):
           metavar="<MTS>"
       )

+    if dtype:
+      self.add_argument(
+          "--dtype", "-dt",
+          default="fp32",
+          choices=list(DTYPE_MAP.keys()),
+          help="[default: %(default)s] {%(choices)s} The TensorFlow datatype "
+               "used for calculations. Variables may be cast to a higher "
+               "precision on a case-by-case basis for numerical stability.",
+          metavar="<DT>"
+      )
+
+      self.add_argument(
+          "--loss_scale", "-ls",
+          type=int,
+          help="[default: %(default)s] The amount to scale the loss by when "
+               "the model is run. Before gradients are computed, the loss is "
+               "multiplied by the loss scale, making all gradients loss_scale "
+               "times larger. To adjust for this, gradients are divided by the "
+               "loss scale before being applied to variables. This is "
+               "mathematically equivalent to training without a loss scale, "
+               "but the loss scale helps avoid some intermediate gradients "
+               "from underflowing to zero. If not provided the default for "
+               "fp16 is 128 and 1 for all other dtypes.",
+      )
+

 class ImageModelParser(argparse.ArgumentParser):
   """Default parser for specification image specific behavior.
......
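The interaction between `--dtype` and `--loss_scale` comes down to a small amount of flag post-processing. A condensed sketch using a bare `argparse` parser instead of the project's parser classes; the hard-coded argument list is only there to make the example self-contained:

```python
import argparse

import tensorflow as tf

# string flag -> (TensorFlow dtype, default loss scale), mirroring DTYPE_MAP
DTYPE_MAP = {"fp16": (tf.float16, 128), "fp32": (tf.float32, 1)}

parser = argparse.ArgumentParser()
parser.add_argument("--dtype", default="fp32", choices=list(DTYPE_MAP.keys()))
parser.add_argument("--loss_scale", type=int)

flags = parser.parse_args(["--dtype", "fp16"])
flags.dtype, default_loss_scale = DTYPE_MAP[flags.dtype]
flags.loss_scale = flags.loss_scale or default_loss_scale

print(flags.dtype, flags.loss_scale)  # <dtype: 'float16'> 128
```

With the real parsers, the same coupling means fp16 training is requested with just `--dtype fp16` (or `-dt fp16`); `--loss_scale` only needs to be passed explicitly to override the fp16 default of 128.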
@@ -16,6 +16,7 @@
 import argparse
 import unittest

+import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.arg_parsers import parsers
@@ -83,6 +84,24 @@ class BaseTester(unittest.TestCase):
     assert namespace.multi_gpu
     assert namespace.use_synthetic_data

+  def test_parse_dtype_info(self):
+    parser = TestParser()
+    for dtype_str, tf_dtype, loss_scale in [["fp16", tf.float16, 128],
+                                            ["fp32", tf.float32, 1]]:
+      args = parser.parse_args(["--dtype", dtype_str])
+      parsers.parse_dtype_info(args)
+      assert args.dtype == tf_dtype
+      assert args.loss_scale == loss_scale
+
+      args = parser.parse_args(["--dtype", dtype_str, "--loss_scale", "5"])
+      parsers.parse_dtype_info(args)
+      assert args.loss_scale == 5
+
+    with self.assertRaises(SystemExit):
+      parser.parse_args(["--dtype", "int8"])
+
+
 if __name__ == "__main__":
   unittest.main()