Commit d08a1c66 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

fix exporters

parent d73569c5
...@@ -21,7 +21,7 @@ from object_detection.builders import model_builder ...@@ -21,7 +21,7 @@ from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util from object_detection.utils import config_util
import ast
def _decode_image(encoded_image_string_tensor): def _decode_image(encoded_image_string_tensor):
image_tensor = tf.image.decode_image(encoded_image_string_tensor, image_tensor = tf.image.decode_image(encoded_image_string_tensor,
...@@ -50,12 +50,11 @@ def _zip_side_inputs(side_input_shapes="", ...@@ -50,12 +50,11 @@ def _zip_side_inputs(side_input_shapes="",
Returns: Returns:
a zipped list of side input tuples. a zipped list of side input tuples.
""" """
if (side_input_shapes) side_input_shapes = list(map(lambda x: ast.literal_eval('[' + x + ']'),
side_input_shapes = list(map(lambda x: eval('[' + x + ']'), side_input_shapes.split("/"))) side_input_shapes.split('/')))
side_input_types = map(eval, side_input_types.split(",")) side_input_types = eval('[' + side_input_types + ']')
print(list(side_input_types)) side_input_names = side_input_names.split(',')
#side_input_types = list(map(eval, side_input_types.split(","))) return zip(side_input_shapes, side_input_types, side_input_names)
return zip(side_input_shapes, side_input_types, side_input_names.split(","))
class DetectionInferenceModule(tf.Module): class DetectionInferenceModule(tf.Module):
"""Detection Inference Module.""" """Detection Inference Module."""
...@@ -220,9 +219,12 @@ def export_inference_graph(input_type, ...@@ -220,9 +219,12 @@ def export_inference_graph(input_type,
raise ValueError('Unrecognized `input_type`') raise ValueError('Unrecognized `input_type`')
if use_side_inputs and input_type != 'image_tensor': if use_side_inputs and input_type != 'image_tensor':
raise ValueError('Side inputs supported for image_tensor input type only.') raise ValueError('Side inputs supported for image_tensor input type only.')
zipped_side_inputs = _zip_side_inputs(side_input_shapes,
side_input_types, zipped_side_inputs = None
side_input_names) if use_side_inputs:
zipped_side_inputs = _zip_side_inputs(side_input_shapes,
side_input_types,
side_input_names)
detection_module = DETECTION_MODULE_MAP[input_type](detection_model, detection_module = DETECTION_MODULE_MAP[input_type](detection_model,
use_side_inputs, use_side_inputs,
zipped_side_inputs) zipped_side_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment