Commit f9ff935a authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

exporter changes

parent 26ba72bb
...@@ -39,30 +39,30 @@ def _decode_tf_example(tf_example_string_tensor): ...@@ -39,30 +39,30 @@ def _decode_tf_example(tf_example_string_tensor):
def _zip_side_inputs(side_input_shapes="", def _zip_side_inputs(side_input_shapes="",
side_input_types="", side_input_types="",
side_input_names=""): side_input_names=""):
"""Zips the side inputs together. """Zips the side inputs together.
Args: Args:
side_input_shapes: forward-slash-separated list of comma-separated lists side_input_shapes: forward-slash-separated list of comma-separated lists
describing input shapes. describing input shapes.
side_input_types: comma-separated list of the types of the inputs. side_input_types: comma-separated list of the types of the inputs.
side_input_names: comma-separated list of the names of the inputs. side_input_names: comma-separated list of the names of the inputs.
Returns: Returns:
a zipped list of side input tuples. a zipped list of side input tuples.
""" """
side_input_shapes = list(map(lambda x: eval('[' + x + ']'), side_input_shapes = list(map(lambda x: eval('[' + x + ']'),
side_input_shapes.split("/"))) side_input_shapes.split("/")))
side_input_types = list(map(eval, side_input_types.split(","))) side_input_types = list(map(eval, side_input_types.split(",")))
return zip(side_input_shapes, return zip(side_input_shapes,
side_input_types, side_input_types,
side_input_names.split(",")) side_input_names.split(","))
class DetectionInferenceModule(tf.Module): class DetectionInferenceModule(tf.Module):
"""Detection Inference Module.""" """Detection Inference Module."""
def __init__(self, detection_model, def __init__(self, detection_model,
use_side_inputs=False, use_side_inputs=False,
zipped_side_inputs=None: zipped_side_inputs=None):
"""Initializes a module for detection. """Initializes a module for detection.
Args: Args:
...@@ -116,7 +116,7 @@ class DetectionFromImageModule(DetectionInferenceModule): ...@@ -116,7 +116,7 @@ class DetectionFromImageModule(DetectionInferenceModule):
name=info[2])) name=info[2]))
def __call__(input_tensor, *side_inputs): def __call__(input_tensor, *side_inputs):
kwargs = dict(zip(self.side_input_names.split(","), side_inputs)) kwargs = dict(zip(self.side_input_names, side_inputs))
return self._run_inference_on_images(input_tensor, **kwargs) return self._run_inference_on_images(input_tensor, **kwargs)
self.__call__ = tf.function(__call__, input_signature=sig) self.__call__ = tf.function(__call__, input_signature=sig)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment