Commit 295259c3 authored by netfs's avatar netfs Committed by Taylor Robie
Browse files

Add `--image_bytes_as_serving_input` flag to export SavedModel (#5393)

with serving signature that accepts JPEG image bytes instead
of a fixed size [HxWxC] image tensor.

Passing JPEG image bytes is easier for inference/serving use
cases. The model internally resizes/crops the JPEG image to
required [HxWxC] tensor before passing it on for actual model
inference.

This change aligns with Cloud TPU/ResNet-50 model that offers a
similar interface (jpeg bytes) for inferencing here:

https://github.com/tensorflow/tpu/tree/master/models/official/resnet

NOTE: This flag is set to `True` by default for ImageNet, and is
disallowed for CIFAR (as it does not apply to CIFAR).
parent 38385b0a
......@@ -238,7 +238,8 @@ def define_cifar_flags():
resnet_size='56',
train_epochs=182,
epochs_between_evals=10,
batch_size=128)
batch_size=128,
image_bytes_as_serving_input=False)
def run_cifar(flags_obj):
......@@ -247,6 +248,11 @@ def run_cifar(flags_obj):
Args:
flags_obj: An object containing parsed flag values.
"""
if flags_obj.image_bytes_as_serving_input:
tf.logging.fatal('--image_bytes_as_serving_input cannot be set to True '
'for CIFAR. This flag is only applicable to ImageNet.')
return
input_function = (flags_obj.use_synthetic_data and
get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or
input_fn)
......
......@@ -23,6 +23,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import math
import os
......@@ -35,6 +36,7 @@ from official.utils.flags import core as flags_core
from official.utils.export import export
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.resnet import imagenet_preprocessing
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
# pylint: enable=g-bad-import-order
......@@ -154,6 +156,26 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
return input_fn
def image_bytes_serving_input_fn(image_shape):
"""Serving input fn for raw jpeg images."""
def _preprocess_image(image_bytes):
"""Preprocess a single raw image."""
# Bounding box around the whole image.
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
height, width, num_channels = image_shape
image = imagenet_preprocessing.preprocess_image(
image_bytes, bbox, height, width, num_channels, is_training=False)
return image
image_bytes_list = tf.placeholder(
shape=[None], dtype=tf.string, name='input_tensor')
images = tf.map_fn(
_preprocess_image, image_bytes_list, back_prop=False, dtype=tf.float32)
return tf.estimator.export.TensorServingInputReceiver(
images, {'image_bytes': image_bytes_list})
################################################################################
# Functions for running training/eval/validation loops for the model.
################################################################################
......@@ -508,6 +530,9 @@ def resnet_main(
if flags_obj.export_dir is not None:
# Exports a saved model for the given classifier.
if flags_obj.image_bytes_as_serving_input:
input_receiver_fn = functools.partial(image_bytes_serving_input_fn, shape)
else:
input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
shape, batch_size=flags_obj.batch_size)
classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
......@@ -539,6 +564,15 @@ def define_resnet_flags(resnet_size_choices=None):
name='eval_only', default=False,
help=flags_core.help_wrap('Skip training and only perform evaluation on '
'the latest checkpoint.'))
flags.DEFINE_boolean(
name="image_bytes_as_serving_input", default=True,
help=flags_core.help_wrap(
'If True exports savedmodel with serving signature that accepts '
'JPEG image bytes instead of a fixed size [HxWxC] tensor that '
'represents the image. The former is easier to use for serving at '
'the expense of image resize/cropping being done as part of model '
'inference. Note, this flag only applies to ImageNet and cannot '
'be used for CIFAR.'))
choice_kwargs = dict(
name='resnet_size', short_name='rs', default='50',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment