Commit a3ae1258 authored by Kaushik Shivakumar

add files for exporter

parent c9da7a38
......@@ -49,13 +49,18 @@ class FakeModel(model.DetectionModel):
filters=1, kernel_size=1, strides=(1, 1), padding='valid',
kernel_initializer=tf.keras.initializers.Constant(
value=conv_weight_scalar))
def preprocess(self, inputs):
true_image_shapes = [] # Doesn't matter for the fake model.
return tf.identity(inputs), true_image_shapes
def predict(self, preprocessed_inputs, true_image_shapes):
return {'image': self._conv(preprocessed_inputs)}
def predict(self, preprocessed_inputs, true_image_shapes, **side_inputs):
return_dict = {'image': self._conv(preprocessed_inputs)}
print("SIDE INPUTS: ", side_inputs)
if 'side_inp' in side_inputs:
return_dict['image'] += side_inputs['side_inp']
return return_dict
def postprocess(self, prediction_dict, true_image_shapes):
predict_tensor_sum = tf.reduce_sum(prediction_dict['image'])
......@@ -189,7 +194,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
saved_model_path = os.path.join(output_directory, 'saved_model')
detect_fn = tf.saved_model.load(saved_model_path)
image = self.get_dummy_input(input_type)
detections = detect_fn(image)
detections = detect_fn.signatures['serving_default'](tf.constant(image))
detection_fields = fields.DetectionResultFields
self.assertAllClose(detections[detection_fields.detection_boxes],
......@@ -203,6 +208,45 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
[[1, 2], [2, 1]])
self.assertAllClose(detections[detection_fields.num_detections], [2, 1])
def test_export_saved_model_and_run_inference_with_side_inputs(
self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir()
self._save_checkpoint_from_mock_model(tmp_dir)
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
output_directory = os.path.join(tmp_dir, 'output')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
exporter_lib_v2.export_inference_graph(
input_type=input_type,
pipeline_config=pipeline_config,
trained_checkpoint_dir=tmp_dir,
output_directory=output_directory,
use_side_inputs=True,
side_input_shapes="1",
side_input_names="side_inp",
side_input_types="tf.float32")
saved_model_path = os.path.join(output_directory, 'saved_model')
detect_fn = tf.saved_model.load(saved_model_path)
detect_fn_sig = detect_fn.signatures['serving_default']
image = tf.constant(self.get_dummy_input(input_type))
side_input = np.ones((1,), dtype=np.float32)
detections = detect_fn_sig(
    input_tensor=image, side_inp=tf.constant(side_input))
detection_fields = fields.DetectionResultFields
self.assertAllClose(detections[detection_fields.detection_boxes],
[[[0.0, 0.0, 0.5, 0.5],
[0.5, 0.5, 0.8, 0.8]],
[[0.5, 0.5, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]]])
self.assertAllClose(detections[detection_fields.detection_scores],
[[400.7, 400.6], [400.9, 400.0]])
self.assertAllClose(detections[detection_fields.detection_classes],
[[1, 2], [2, 1]])
self.assertAllClose(detections[detection_fields.num_detections], [2, 1])
def test_export_checkpoint_and_run_inference_with_image(self):
tmp_dir = self.get_temp_dir()
self._save_checkpoint_from_mock_model(tmp_dir, conv_weight_scalar=2.0)
......
......@@ -40,7 +40,11 @@ def _decode_tf_example(tf_example_string_tensor):
class DetectionInferenceModule(tf.Module):
"""Detection Inference Module."""
def __init__(self, detection_model):
def __init__(self, detection_model,
use_side_inputs=False,
side_input_shapes=None,
side_input_types=None,
side_input_names=None):
"""Initializes a module for detection.
Args:
......@@ -48,7 +52,7 @@ class DetectionInferenceModule(tf.Module):
"""
self._model = detection_model
def _run_inference_on_images(self, image):
def _run_inference_on_images(self, image, **kwargs):
"""Cast image to float and run inference.
Args:
......@@ -60,7 +64,7 @@ class DetectionInferenceModule(tf.Module):
image = tf.cast(image, tf.float32)
image, shapes = self._model.preprocess(image)
prediction_dict = self._model.predict(image, shapes)
prediction_dict = self._model.predict(image, shapes, **kwargs)
detections = self._model.postprocess(prediction_dict, shapes)
classes_field = fields.DetectionResultFields.detection_classes
detections[classes_field] = (
......@@ -71,15 +75,39 @@ class DetectionInferenceModule(tf.Module):
return detections
class DetectionFromImageModule(DetectionInferenceModule):
"""Detection Inference Module for image inputs."""
@tf.function(
input_signature=[
tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)])
def __call__(self, input_tensor):
return self._run_inference_on_images(input_tensor)
  def __init__(self, detection_model,
               use_side_inputs=False,
               side_input_shapes='',
               side_input_types='',
               side_input_names=''):
    """Initializes a module for detection from image inputs.

    Args:
      detection_model: the detection model to use for inference.
      use_side_inputs: whether the exported signature accepts side inputs in
        addition to the image tensor.
      side_input_shapes: `/`-separated list of comma-separated side input
        shapes; a value of -1 denotes an unknown dimension.
      side_input_types: comma-separated list of side input dtype strings,
        e.g. `tf.float32`.
      side_input_names: comma-separated list of side input names.
    """
    self.side_input_names = side_input_names
    sig = [tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)]
    if use_side_inputs:
      for info in zip(side_input_shapes.split('/'),
                      side_input_types.split(','),
                      side_input_names.split(',')):
        # Parse the shape string explicitly rather than eval()-ing it, and
        # map -1 to None (unknown dimension) for the TensorSpec.
        shape = [None if int(dim) < 0 else int(dim)
                 for dim in info[0].split(',')]
        dtype = tf.dtypes.as_dtype(info[1].strip().replace('tf.', ''))
        sig.append(tf.TensorSpec(shape=shape, dtype=dtype, name=info[2]))

    def __call__(input_tensor, *side_inputs):
      kwargs = dict(zip(self.side_input_names.split(','), side_inputs))
      return self._run_inference_on_images(input_tensor, **kwargs)

    self.__call__ = tf.function(__call__, input_signature=sig)
    super(DetectionFromImageModule, self).__init__(
        detection_model,
        use_side_inputs=use_side_inputs,
        side_input_shapes=side_input_shapes,
        side_input_types=side_input_types,
        side_input_names=side_input_names)
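# Illustrative example (values are hypothetical): with
# side_input_shapes='1,-1/2', side_input_types='tf.float32,tf.uint8' and
# side_input_names='side_inp,mask', the parsing above adds
# TensorSpec([1, None], tf.float32, name='side_inp') and
# TensorSpec([2], tf.uint8, name='mask') to the serving signature, and the
# traced __call__ forwards each side input to the model's predict() as a
# keyword argument.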
class DetectionFromFloatImageModule(DetectionInferenceModule):
......@@ -133,7 +161,11 @@ DETECTION_MODULE_MAP = {
def export_inference_graph(input_type,
pipeline_config,
trained_checkpoint_dir,
output_directory):
output_directory,
use_side_inputs=False,
side_input_shapes='',
side_input_types='',
side_input_names=''):
"""Exports inference graph for the model specified in the pipeline config.
This function creates `output_directory` if it does not already exist,
......@@ -164,7 +196,13 @@ def export_inference_graph(input_type,
if input_type not in DETECTION_MODULE_MAP:
raise ValueError('Unrecognized `input_type`')
detection_module = DETECTION_MODULE_MAP[input_type](detection_model)
if use_side_inputs and input_type != 'image_tensor':
raise ValueError('Side inputs supported for image_tensor input type only.')
detection_module = DETECTION_MODULE_MAP[input_type](detection_model,
use_side_inputs,
side_input_shapes,
side_input_types,
side_input_names)
# Getting the concrete function traces the graph and forces variables to
# be constructed --- only after this can we save the checkpoint and
# saved model.
......
......@@ -106,6 +106,27 @@ flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('config_override', '',
'pipeline_pb2.TrainEvalPipelineConfig '
'text proto to override pipeline_config_path.')
flags.DEFINE_boolean('use_side_inputs', False,
'If True, uses side inputs as well as image inputs.')
flags.DEFINE_string('side_input_shapes', '',
'If use_side_inputs is True, this explicitly sets '
'the shape of the side input tensors to a fixed size. The '
'dimensions are to be provided as a comma-separated list '
'of integers. A value of -1 can be used for unknown '
'dimensions. A `/` denotes a break, starting the shape of '
'the next side input tensor. This flag is required if '
'using side inputs.')
flags.DEFINE_string('side_input_types', '',
                    'If use_side_inputs is True, this explicitly sets '
                    'the type of the side input tensors. The types are to '
                    'be provided as a comma-separated list of TensorFlow '
                    'dtype strings (e.g. `tf.float32`), one per side input. '
                    'This flag is required if using side inputs.')
flags.DEFINE_string('side_input_names', '',
                    'If use_side_inputs is True, this explicitly sets '
                    'the names of the side input tensors required by the '
                    'model, provided as a comma-separated list of strings. '
                    'This flag is required if using side inputs.')
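# A usage sketch, assuming this script is invoked as exporter_main_v2.py and
# that the paths below are placeholders: exporting with a single float side
# input of shape [1] named `side_inp` would look roughly like:
#
#   python exporter_main_v2.py \
#     --input_type=image_tensor \
#     --pipeline_config_path=/path/to/pipeline.config \
#     --trained_checkpoint_dir=/path/to/checkpoint_dir \
#     --output_directory=/path/to/output \
#     --use_side_inputs=True \
#     --side_input_shapes=1 \
#     --side_input_types=tf.float32 \
#     --side_input_names=side_inp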
flags.mark_flag_as_required('pipeline_config_path')
flags.mark_flag_as_required('trained_checkpoint_dir')
......@@ -119,7 +140,8 @@ def main(_):
text_format.Merge(FLAGS.config_override, pipeline_config)
exporter_lib_v2.export_inference_graph(
FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_dir,
FLAGS.output_directory)
FLAGS.output_directory, FLAGS.use_side_inputs, FLAGS.side_input_shapes,
FLAGS.side_input_types, FLAGS.side_input_names)
if __name__ == '__main__':
......