Commit 811f3940 authored by Kaushik Shivakumar

exporter final version

parent 6c903183
@@ -54,8 +54,11 @@ class FakeModel(model.DetectionModel):
true_image_shapes = [] # Doesn't matter for the fake model.
return tf.identity(inputs), true_image_shapes
def predict(self, preprocessed_inputs, true_image_shapes):
return {'image': self._conv(preprocessed_inputs)}
def predict(self, preprocessed_inputs, true_image_shapes, **side_inputs):
return_dict = {'image': self._conv(preprocessed_inputs)}
if 'side_inp_1' in side_inputs:
return_dict['image'] += side_inputs['side_inp_1']
return return_dict
def postprocess(self, prediction_dict, true_image_shapes):
predict_tensor_sum = tf.reduce_sum(prediction_dict['image'])
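Editor's illustration (not part of this commit): with the change above, side inputs reach predict() as keyword arguments keyed by the exported side-input names. A minimal sketch using the FakeModel from this test file:

fake_model = FakeModel()
images, shapes = fake_model.preprocess(tf.zeros([1, 20, 20, 3], tf.float32))
prediction_dict = fake_model.predict(
    images, shapes, side_inp_1=tf.constant([1.0], tf.float32))
# prediction_dict['image'] is the conv output plus the broadcast side_inp_1.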
@@ -142,9 +145,9 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
return np.zeros(shape=(1, 20, 20, 3), dtype=np.uint8)
return np.zeros((1, 20, 20, 3), dtype=np.uint8)
if input_type == 'float_image_tensor':
return np.zeros(shape=(1, 20, 20, 3), dtype=np.float32)
return np.zeros((1, 20, 20, 3), dtype=np.float32)
elif input_type == 'encoded_image_string_tensor':
image = Image.new('RGB', (20, 20))
byte_io = io.BytesIO()
@@ -189,7 +192,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
saved_model_path = os.path.join(output_directory, 'saved_model')
detect_fn = tf.saved_model.load(saved_model_path)
image = self.get_dummy_input(input_type)
detections = detect_fn(image)
detections = detect_fn(tf.constant(image))
detection_fields = fields.DetectionResultFields
self.assertAllClose(detections[detection_fields.detection_boxes],
@@ -203,6 +206,54 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
[[1, 2], [2, 1]])
self.assertAllClose(detections[detection_fields.num_detections], [2, 1])
@parameterized.parameters(
{'use_default_serving': True},
{'use_default_serving': False}
)
def test_export_saved_model_and_run_inference_with_side_inputs(
self, input_type='image_tensor', use_default_serving=True):
tmp_dir = self.get_temp_dir()
self._save_checkpoint_from_mock_model(tmp_dir)
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
output_directory = os.path.join(tmp_dir, 'output')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
exporter_lib_v2.export_inference_graph(
input_type=input_type,
pipeline_config=pipeline_config,
trained_checkpoint_dir=tmp_dir,
output_directory=output_directory,
use_side_inputs=True,
side_input_shapes="1/2,2",
side_input_names="side_inp_1,side_inp_2",
side_input_types="tf.float32,tf.uint8")
saved_model_path = os.path.join(output_directory, 'saved_model')
detect_fn = tf.saved_model.load(saved_model_path)
detect_fn_sig = detect_fn.signatures['serving_default']
image = tf.constant(self.get_dummy_input(input_type))
side_input_1 = np.ones((1,), dtype=np.float32)
side_input_2 = np.ones((2, 2), dtype=np.uint8)
if use_default_serving:
detections = detect_fn_sig(input_tensor=image,
side_inp_1=tf.constant(side_input_1),
side_inp_2=tf.constant(side_input_2))
else:
detections = detect_fn(image, tf.constant(side_input_1), tf.constant(side_input_2))
detection_fields = fields.DetectionResultFields
self.assertAllClose(detections[detection_fields.detection_boxes],
[[[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.8, 0.8]],
[[0.5, 0.5, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]]])
self.assertAllClose(detections[detection_fields.detection_scores],
[[400.7, 400.6], [400.9, 400.0]])
self.assertAllClose(detections[detection_fields.detection_classes],
[[1, 2], [2, 1]])
self.assertAllClose(detections[detection_fields.num_detections], [2, 1])
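Editor's illustration (not part of this commit): outside the test harness, the exported serving signature can be inspected to confirm the side inputs were added; the path below is a placeholder.

detect_fn = tf.saved_model.load('/path/to/output/saved_model')
serving_fn = detect_fn.signatures['serving_default']
# With the flags used above, this lists TensorSpecs named 'input_tensor',
# 'side_inp_1' (shape [1], tf.float32) and 'side_inp_2' (shape [2, 2], tf.uint8).
print(serving_fn.structured_input_signature)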
def test_export_checkpoint_and_run_inference_with_image(self):
tmp_dir = self.get_temp_dir()
self._save_checkpoint_from_mock_model(tmp_dir, conv_weight_scalar=2.0)
......
@@ -16,13 +16,14 @@
"""Functions to export object detection inference graph."""
import ast
import os
import tensorflow.compat.v2 as tf
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util
def _decode_image(encoded_image_string_tensor):
image_tensor = tf.image.decode_image(encoded_image_string_tensor,
channels=3)
@@ -37,18 +38,56 @@ def _decode_tf_example(tf_example_string_tensor):
return image_tensor
def _combine_side_inputs(side_input_shapes="",
side_input_types="",
side_input_names=""):
"""Zips the side inputs together.
Args:
side_input_shapes: forward-slash-separated list of comma-separated lists
describing input shapes.
side_input_types: comma-separated list of the types of the inputs.
side_input_names: comma-separated list of the names of the inputs.
Returns:
a zip of (shape, type, name) tuples, one per side input.
"""
side_input_shapes = list(map(lambda x: ast.literal_eval('[' + x + ']'),
side_input_shapes.split('/')))
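# Note (editor): `eval` resolves dtype expressions such as 'tf.float32' against
# this module's globals, where `tf` is imported above.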
side_input_types = eval('[' + side_input_types + ']')
side_input_names = side_input_names.split(',')
return zip(side_input_shapes, side_input_types, side_input_names)
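Editor's illustration (not part of this commit): with the Context R-CNN strings used in exporter_main_v2.py's docstring, the helper yields one (shape, type, name) tuple per side input.

list(_combine_side_inputs(side_input_shapes='1,2000,2057/1',
                          side_input_types='tf.float32,tf.int32',
                          side_input_names='context_features,valid_context_size'))
# -> [([1, 2000, 2057], tf.float32, 'context_features'),
#     ([1], tf.int32, 'valid_context_size')]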
class DetectionInferenceModule(tf.Module):
"""Detection Inference Module."""
def __init__(self, detection_model):
def __init__(self, detection_model,
use_side_inputs=False,
zipped_side_inputs=None):
"""Initializes a module for detection.
Args:
detection_model: The detection model to use for inference.
detection_model: the detection model to use for inference.
use_side_inputs: whether to use side inputs.
zipped_side_inputs: the zipped side inputs.
"""
self._model = detection_model
def _run_inference_on_images(self, image):
def _get_side_input_signature(self, zipped_side_inputs):
sig = []
for info in zipped_side_inputs:
sig.append(tf.TensorSpec(shape=info[0],
dtype=info[1],
name=info[2]))
return sig
def _get_side_names_from_zip(self, zipped_side_inputs):
return [side[2] for side in zipped_side_inputs]
def _run_inference_on_images(self, image, **kwargs):
"""Cast image to float and run inference.
Args:
@@ -60,7 +99,7 @@ class DetectionInferenceModule(tf.Module):
image = tf.cast(image, tf.float32)
image, shapes = self._model.preprocess(image)
prediction_dict = self._model.predict(image, shapes)
prediction_dict = self._model.predict(image, shapes, **kwargs)
detections = self._model.postprocess(prediction_dict, shapes)
classes_field = fields.DetectionResultFields.detection_classes
detections[classes_field] = (
@@ -75,11 +114,34 @@ class DetectionInferenceModule(tf.Module):
class DetectionFromImageModule(DetectionInferenceModule):
"""Detection Inference Module for image inputs."""
@tf.function(
input_signature=[
tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)])
def __call__(self, input_tensor):
return self._run_inference_on_images(input_tensor)
def __init__(self, detection_model,
use_side_inputs=False,
zipped_side_inputs=None):
"""Initializes a module for detection.
Args:
detection_model: the detection model to use for inference.
use_side_inputs: whether to use side inputs.
zipped_side_inputs: the zipped side inputs.
"""
if zipped_side_inputs is None:
zipped_side_inputs = []
sig = [tf.TensorSpec(shape=[1, None, None, 3],
dtype=tf.uint8,
name='input_tensor')]
if use_side_inputs:
sig.extend(self._get_side_input_signature(zipped_side_inputs))
self._side_input_names = self._get_side_names_from_zip(zipped_side_inputs)
def call_func(input_tensor, *side_inputs):
kwargs = dict(zip(self._side_input_names, side_inputs))
return self._run_inference_on_images(input_tensor, **kwargs)
self.__call__ = tf.function(call_func, input_signature=sig)
super(DetectionFromImageModule, self).__init__(detection_model,
use_side_inputs,
zipped_side_inputs)
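Editor's illustration (not part of this commit): building the image module with side inputs gives its traced __call__ a positional signature of (input_tensor, followed by the side inputs in the order their names were supplied); detection_model below stands for any built DetectionModel.

zipped = list(_combine_side_inputs('1,2000,2057/1',
                                   'tf.float32,tf.int32',
                                   'context_features,valid_context_size'))
module = DetectionFromImageModule(detection_model,
                                  use_side_inputs=True,
                                  zipped_side_inputs=zipped)
# Tracing the function (as export_inference_graph does) yields a concrete
# function expecting (input_tensor, context_features, valid_context_size).
concrete_fn = module.__call__.get_concrete_function()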
class DetectionFromFloatImageModule(DetectionInferenceModule):
@@ -133,7 +195,11 @@ DETECTION_MODULE_MAP = {
def export_inference_graph(input_type,
pipeline_config,
trained_checkpoint_dir,
output_directory):
output_directory,
use_side_inputs=False,
side_input_shapes="",
side_input_types="",
side_input_names=""):
"""Exports inference graph for the model specified in the pipeline config.
This function creates `output_directory` if it does not already exist,
@@ -147,6 +213,12 @@ def export_inference_graph(input_type,
pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto.
trained_checkpoint_dir: Path to the trained checkpoint directory.
output_directory: Path to write outputs.
use_side_inputs: boolean that determines whether side inputs should be
included in the input signature.
side_input_shapes: forward-slash-separated list of comma-separated lists
describing input shapes.
side_input_types: comma-separated list of the types of the inputs.
side_input_names: comma-separated list of the names of the inputs.
Raises:
ValueError: if input_type is invalid.
"""
@@ -164,7 +236,18 @@ def export_inference_graph(input_type,
if input_type not in DETECTION_MODULE_MAP:
raise ValueError('Unrecognized `input_type`')
detection_module = DETECTION_MODULE_MAP[input_type](detection_model)
if use_side_inputs and input_type != 'image_tensor':
raise ValueError('Side inputs supported for image_tensor input type only.')
zipped_side_inputs = []
if use_side_inputs:
zipped_side_inputs = _combine_side_inputs(side_input_shapes,
side_input_types,
side_input_names)
detection_module = DETECTION_MODULE_MAP[input_type](detection_model,
use_side_inputs,
list(zipped_side_inputs))
# Getting the concrete function traces the graph and forces variables to
# be constructed --- only after this can we save the checkpoint and
# saved model.
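Editor's illustration (not part of this commit): the programmatic equivalent of the Context R-CNN command shown in exporter_main_v2.py; paths are placeholders and pipeline_config is a TrainEvalPipelineConfig proto.

exporter_lib_v2.export_inference_graph(
    input_type='image_tensor',
    pipeline_config=pipeline_config,
    trained_checkpoint_dir='/path/to/checkpoint',
    output_directory='/path/to/exported_model_directory',
    use_side_inputs=True,
    side_input_shapes='1,2000,2057/1',
    side_input_types='tf.float32,tf.int32',
    side_input_names='context_features,valid_context_size')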
......
@@ -50,6 +50,10 @@ python exporter_main_v2.py \
--pipeline_config_path path/to/ssd_inception_v2.config \
--trained_checkpoint_dir path/to/checkpoint \
--output_directory path/to/exported_model_directory \
--use_side_inputs True/False \
--side_input_shapes dim_0,dim_1,...,dim_a/.../dim_0,dim_1,...,dim_z \
--side_input_names name_a,name_b,...,name_c \
--side_input_types type_1,type_2
The expected output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
@@ -80,6 +84,13 @@ python exporter_main_v2.py \
} \
} \
}"
If side inputs are desired, the following arguments could be appended
(the example below is for Context R-CNN).
--use_side_inputs True \
--side_input_shapes 1,2000,2057/1 \
--side_input_names context_features,valid_context_size \
--side_input_types tf.float32,tf.int32
"""
from absl import app
from absl import flags
@@ -106,6 +117,27 @@ flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('config_override', '',
'pipeline_pb2.TrainEvalPipelineConfig '
'text proto to override pipeline_config_path.')
flags.DEFINE_boolean('use_side_inputs', False,
'If True, uses side inputs as well as image inputs.')
flags.DEFINE_string('side_input_shapes', "",
'If use_side_inputs is True, this explicitly sets '
'the shape of the side input tensors to a fixed size. The '
'dimensions are to be provided as a comma-separated list '
'of integers. A value of -1 can be used for unknown '
'dimensions. A `/` denotes a break, starting the shape of '
'the next side input tensor. This flag is required if '
'using side inputs.')
flags.DEFINE_string('side_input_types', '',
'If use_side_inputs is True, this explicitly sets '
'the types of the side input tensors, provided as a '
'comma-separated list of TensorFlow dtypes such as '
'`tf.float32` or `tf.int32`. This flag is required if '
'using side inputs.')
flags.DEFINE_string('side_input_names', '',
'If use_side_inputs is True, this explicitly sets '
'the names of the side input tensors required by the '
'model, provided as a comma-separated list of strings. '
'This flag is required if using side inputs.')
flags.mark_flag_as_required('pipeline_config_path')
flags.mark_flag_as_required('trained_checkpoint_dir')
@@ -119,7 +151,8 @@ def main(_):
text_format.Merge(FLAGS.config_override, pipeline_config)
exporter_lib_v2.export_inference_graph(
FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_dir,
FLAGS.output_directory)
FLAGS.output_directory, FLAGS.use_side_inputs, FLAGS.side_input_shapes,
FLAGS.side_input_types, FLAGS.side_input_names)
if __name__ == '__main__':
......