Unverified Commit eb73a850 authored by Karmel Allison's avatar Karmel Allison Committed by GitHub
Browse files

Add SavedModel export to Resnet (#3759)

* Adding export_dir and model saving for Resnet

* Moving to utils for tests

* Adding batch_size

* Adding multi-gpu export warning

* Responding to CR

* Py3 compliance
parent 1bfe1df1
......@@ -246,14 +246,9 @@ class MNISTArgParser(argparse.ArgumentParser):
def __init__(self):
super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.ImageModelParser()])
self.add_argument(
'--export_dir',
type=str,
help='[default: %(default)s] If set, a SavedModel serialization of the '
'model will be exported to this directory at the end of training. '
'See the README for more details and relevant links.')
parsers.ImageModelParser(),
parsers.ExportParser(),
])
self.set_defaults(
data_dir='/tmp/mnist_data',
......
......@@ -228,7 +228,10 @@ def main(argv):
flags = parser.parse_args(args=argv[1:])
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(flags, cifar10_model_fn, input_function)
resnet_run_loop.resnet_main(
flags, cifar10_model_fn, input_function,
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
if __name__ == '__main__':
......
......@@ -305,7 +305,10 @@ def main(argv):
flags = parser.parse_args(args=argv[1:])
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(flags, imagenet_model_fn, input_function)
resnet_run_loop.resnet_main(
flags, imagenet_model_fn, input_function,
shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
if __name__ == '__main__':
......
......@@ -30,6 +30,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import resnet_model
from official.utils.arg_parsers import parsers
from official.utils.export import export
from official.utils.logging import hooks_helper
from official.utils.logging import logger
......@@ -219,7 +220,13 @@ def resnet_model_fn(features, labels, mode, model_class,
}
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
# Return the predictions and the specification for serving a SavedModel
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
export_outputs={
'predict': tf.estimator.export.PredictOutput(predictions)
})
# Calculate loss, which includes softmax cross entropy and L2 regularization.
cross_entropy = tf.losses.softmax_cross_entropy(
......@@ -310,8 +317,20 @@ def validate_batch_size_for_multi_gpu(batch_size):
raise ValueError(err)
def resnet_main(flags, model_function, input_function):
"""Shared main loop for ResNet Models."""
def resnet_main(flags, model_function, input_function, shape=None):
"""Shared main loop for ResNet Models.
Args:
flags: FLAGS object that contains the params for running. See
ResnetArgParser for created flags.
model_function: the function that instantiates the Model and builds the
ops for train/eval. This will be passed directly into the estimator.
input_function: the function that processes the dataset and returns a
dataset that the estimator can train on. This will be wrapped with
all the relevant flags for running and passed to estimator.
shape: list of ints representing the shape of the images used for training.
This is only used if flags.export_dir is passed.
"""
# Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
......@@ -389,16 +408,34 @@ def resnet_main(flags, model_function, input_function):
if benchmark_logger:
benchmark_logger.log_estimator_evaluation_result(eval_results)
if flags.export_dir is not None:
warn_on_multi_gpu_export(flags.multi_gpu)
# Exports a saved model for the given classifier.
input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
shape, batch_size=flags.batch_size)
classifier.export_savedmodel(flags.export_dir, input_receiver_fn)
def warn_on_multi_gpu_export(multi_gpu=False):
  """Log a warning when exporting a SavedModel trained in multi-GPU mode.

  For the time being, multi-GPU mode does not play nicely with exporting:
  the resulting SavedModel will require the same GPU configuration at serving
  time, so we warn rather than fail.

  Args:
    multi_gpu: Whether the model was run with the multi_gpu flag enabled.
  """
  if multi_gpu:
    # Note the trailing space after "available." — these adjacent string
    # literals are concatenated, and omitting it ran the sentences together.
    tf.logging.warning(
        'You are exporting a SavedModel while in multi-GPU mode. Note that '
        'the resulting SavedModel will require the same GPUs be available. '
        'If you wish to serve the SavedModel from a different device, '
        'try exporting the SavedModel with multi-GPU mode turned off.')
class ResnetArgParser(argparse.ArgumentParser):
"""Arguments for configuring and running a Resnet Model.
"""
"""Arguments for configuring and running a Resnet Model."""
def __init__(self, resnet_size_choices=None):
super(ResnetArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.PerformanceParser(),
parsers.ImageModelParser(),
parsers.ExportParser(),
parsers.BenchmarkParser(),
])
......
......@@ -226,6 +226,29 @@ class ImageModelParser(argparse.ArgumentParser):
)
class ExportParser(argparse.ArgumentParser):
  """Argument parser for SavedModel / graph-def export options.

  Kept as a standalone parser for now; the intent is to fold it into
  BaseParser once every model supports exporting.

  Args:
    add_help: Whether to create the "--help" flag. Pass False when this
      instance is used as a parent parser.
    export_dir: Whether to register the flag that names the directory a
      SavedModel should be exported to.
  """

  def __init__(self, add_help=False, export_dir=True):
    super(ExportParser, self).__init__(add_help=add_help)
    if not export_dir:
      return
    # Help text built separately so the add_argument call stays compact.
    export_dir_help = (
        "[default: %(default)s] If set, a SavedModel serialization of "
        "the model will be exported to this directory at the end of "
        "training. See the README for more details and relevant links.")
    self.add_argument(
        "--export_dir", "-ed", help=export_dir_help, metavar="<ED>")
class BenchmarkParser(argparse.ArgumentParser):
"""Default parser for benchmark logging.
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for exporting models as SavedModels or other types."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def build_tensor_serving_input_receiver_fn(shape, dtype=tf.float32,
                                           batch_size=1):
  """Build a serving_input_receiver_fn for a single dense input tensor.

  Incoming examples are expected to arrive as plain tensors of the given
  dtype, and are wrapped directly in a TensorServingInputReceiver.
  Arguably, this should live in tf.estimator.export. Testing here first.

  Args:
    shape: list of ints describing the dimensions of one example
      (without the batch dimension).
    dtype: the expected datatype for the input examples.
    batch_size: number of input tensors that will be passed for prediction.

  Returns:
    A zero-argument function that returns a TensorServingInputReceiver.
  """
  def serving_input_receiver_fn():
    # The serving system feeds a batch of raw examples into this placeholder.
    batched_shape = [batch_size] + shape
    input_tensor = tf.placeholder(
        dtype=dtype, shape=batched_shape, name='input_tensor')
    return tf.estimator.export.TensorServingInputReceiver(
        features=input_tensor, receiver_tensors=input_tensor)

  return serving_input_receiver_fn
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exporting utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.export import export
class ExportUtilsTest(tf.test.TestCase):
  """Tests for the ExportUtils."""

  def _check_receiver_fn(self, receiver_fn, expected_shape, expected_dtype):
    """Builds the receiver in a fresh graph and verifies its tensors."""
    with tf.Graph().as_default():
      receiver = receiver_fn()
      self.assertIsInstance(
          receiver, tf.estimator.export.TensorServingInputReceiver)
      self.assertIsInstance(receiver.features, tf.Tensor)
      self.assertEqual(receiver.features.shape,
                       tf.TensorShape(expected_shape))
      self.assertEqual(receiver.features.dtype, expected_dtype)
      self.assertIsInstance(receiver.receiver_tensors, dict)
      # Python 3 dict views are not indexable; materialize as a list first.
      receiver_tensor = list(receiver.receiver_tensors.values())[0]
      self.assertEqual(receiver_tensor.shape, tf.TensorShape(expected_shape))

  def test_build_tensor_serving_input_receiver_fn(self):
    receiver_fn = export.build_tensor_serving_input_receiver_fn(shape=[4, 5])
    self._check_receiver_fn(receiver_fn, [1, 4, 5], tf.float32)

  def test_build_tensor_serving_input_receiver_fn_batch_dtype(self):
    receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape=[4, 5], dtype=tf.int8, batch_size=10)
    self._check_receiver_fn(receiver_fn, [10, 4, 5], tf.int8)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment