Commit f6bf56d7 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

clean and update exporter

parent a3ae1258
......@@ -194,7 +194,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
saved_model_path = os.path.join(output_directory, 'saved_model')
detect_fn = tf.saved_model.load(saved_model_path)
image = self.get_dummy_input(input_type)
detections = detect_fn.signatures['serving_default'](tf.constant(image))
detections = detect_fn(tf.constant(image))
detection_fields = fields.DetectionResultFields
self.assertAllClose(detections[detection_fields.detection_boxes],
......@@ -232,8 +232,8 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
detect_fn_sig = detect_fn.signatures['serving_default']
image = tf.constant(self.get_dummy_input(input_type))
side_input = np.ones((1,), dtype=np.float32)
#detections_one = tf.saved_model.load(saved_model_path)(image, side_input)
detections = detect_fn_sig(input_tensor=image, side_inp=tf.constant(side_input))
detections = detect_fn_sig(input_tensor=image,
side_inp=tf.constant(side_input))
detection_fields = fields.DetectionResultFields
self.assertAllClose(detections[detection_fields.detection_boxes],
......
......@@ -36,15 +36,33 @@ def _decode_tf_example(tf_example_string_tensor):
image_tensor = tensor_dict[fields.InputDataFields.image]
return image_tensor
def _zip_side_inputs(side_input_shapes="",
side_input_types="",
side_input_names=""):
"""Zips the side inputs together.
Args:
side_input_shapes: forward-slash-separated list of comma-separated lists
describing input shapes.
side_input_types: comma-separated list of the types of the inputs.
side_input_names: comma-separated list of the names of the inputs.
Returns:
a zipped list of side input tuples.
"""
side_input_shapes = list(map(lambda x: eval('[' + x + ']'),
side_input_shapes.split("/")))
side_input_types = list(map(eval, side_input_types.split(",")))
return zip(side_input_shapes,
side_input_types,
side_input_names.split(","))
class DetectionInferenceModule(tf.Module):
"""Detection Inference Module."""
def __init__(self, detection_model,
use_side_inputs=False,
side_input_shapes=None,
side_input_types=None,
side_input_names=None):
zipped_side_inputs=None:
"""Initializes a module for detection.
Args:
......@@ -75,27 +93,26 @@ class DetectionInferenceModule(tf.Module):
return detections
class DetectionFromImageModule(DetectionInferenceModule):
"""Detection Inference Module for image inputs."""
def __init__(self, detection_model,
use_side_inputs=False,
side_input_shapes="",
side_input_types="",
side_input_names=""):
zipped_side_inputs=None):
"""Initializes a module for detection.
Args:
detection_model: The detection model to use for inference.
"""
self.side_input_names = side_input_names
self.side_input_names = []
sig = [tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)]
if use_side_inputs:
for info in zip(side_input_shapes.split("/"),
side_input_types.split(","),
side_input_names.split(",")):
sig.append(tf.TensorSpec(shape=eval("[" + info[0] + "]"),
dtype=eval(info[1]),
for info in zipped_side_inputs:
self.side_input_names.append(info[2])
sig.append(tf.TensorSpec(shape=info[0],
dtype=info[1],
name=info[2]))
def __call__(input_tensor, *side_inputs):
......@@ -105,9 +122,8 @@ class DetectionFromImageModule(DetectionInferenceModule):
self.__call__ = tf.function(__call__, input_signature=sig)
super(DetectionFromImageModule, self).__init__(detection_model,
side_input_shapes,
side_input_types,
side_input_names)
use_side_inputs,
zipped_side_inputs)
class DetectionFromFloatImageModule(DetectionInferenceModule):
......@@ -179,6 +195,12 @@ def export_inference_graph(input_type,
pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
trained_checkpoint_dir: Path to the trained checkpoint file.
output_directory: Path to write outputs.
use_side_inputs: boolean that determines whether side inputs should be
included in the input signature.
side_input_shapes: forward-slash-separated list of comma-separated lists
describing input shapes.
side_input_types: comma-separated list of the types of the inputs.
side_input_names: comma-separated list of the names of the inputs.
Raises:
ValueError: if input_type is invalid.
"""
......@@ -198,11 +220,12 @@ def export_inference_graph(input_type,
raise ValueError('Unrecognized `input_type`')
if use_side_inputs and input_type != 'image_tensor':
raise ValueError('Side inputs supported for image_tensor input type only.')
detection_module = DETECTION_MODULE_MAP[input_type](detection_model,
use_side_inputs,
side_input_shapes,
zipped_side_inputs = _zip_side_inputs(side_input_shapes,
side_input_types,
side_input_names)
detection_module = DETECTION_MODULE_MAP[input_type](detection_model,
use_side_inputs,
zipped_side_inputs)
# Getting the concrete function traces the graph and forces variables to
# be constructed --- only after this can we save the checkpoint and
# saved model.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment