Commit 82f93573 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 358901400
parent 59e25c9d
@@ -70,8 +70,6 @@ class DetectionModule(export_base.ExportModule):
         aug_scale_min=1.0,
         aug_scale_max=1.0)
-    image_shape = image_info[1, :]  # Shape of original image.
-
     input_anchor = anchor.build_anchor_generator(
         min_level=model_params.min_level,
         max_level=model_params.max_level,
@@ -81,7 +79,7 @@ class DetectionModule(export_base.ExportModule):
     anchor_boxes = input_anchor(image_size=(self._input_image_size[0],
                                             self._input_image_size[1]))
-    return image, anchor_boxes, image_shape
+    return image, anchor_boxes, image_info

   def _run_inference_on_image_tensors(self, images: tf.Tensor):
     """Cast image to float and run inference.
@@ -111,20 +109,22 @@ class DetectionModule(export_base.ExportModule):
           dtype=tf.float32)
       anchor_shapes.append((str(level), anchor_level_spec))

-    image_shape_spec = tf.TensorSpec(shape=[2,], dtype=tf.float32)
+    image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

-    images, anchor_boxes, image_shape = tf.nest.map_structure(
+    images, anchor_boxes, image_info = tf.nest.map_structure(
         tf.identity,
         tf.map_fn(
             self._build_inputs,
             elems=images,
             fn_output_signature=(images_spec, dict(anchor_shapes),
-                                 image_shape_spec),
+                                 image_info_spec),
             parallel_iterations=32))

+    input_image_shape = image_info[:, 1, :]
+
     detections = self._model.call(
         images=images,
-        image_shape=image_shape,
+        image_shape=input_image_shape,
         anchor_boxes=anchor_boxes,
         training=False)
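For clarity (not from the diff): after `tf.map_fn`, `image_info` carries a leading batch dimension, so `image_info[:, 1, :]` selects row 1 of every element and yields the per-image `[height, width]` tensor passed to the model. A shape-only sketch:

```python
# image_info batched by tf.map_fn is assumed to have shape [batch_size, 4, 2].
import tensorflow as tf

batch_size = 2
image_info = tf.zeros([batch_size, 4, 2], tf.float32)    # stand-in for the real tensor
input_image_shape = image_info[:, 1, :]                   # row 1 of each batch element
assert input_image_shape.shape.as_list() == [batch_size, 2]
```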
@@ -132,7 +132,8 @@ class DetectionModule(export_base.ExportModule):
         'detection_boxes': detections['detection_boxes'],
         'detection_scores': detections['detection_scores'],
         'detection_classes': detections['detection_classes'],
-        'num_detections': detections['num_detections']
+        'num_detections': detections['num_detections'],
+        'image_info': image_info
     }
     if 'detection_masks' in detections.keys():
       final_outputs['detection_masks'] = detections['detection_masks']
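The practical effect of adding `'image_info'` to `final_outputs` is that serving clients receive the preprocessing parameters (sizes, scale, offset) alongside the detections, e.g. to map boxes back to original image coordinates. Below is a hypothetical smoke test against an exported SavedModel; the path, input shape/dtype, and the `serving_default` signature name are assumptions, not taken from this change:

```python
# Hypothetical check of the exported signature (all names/paths are placeholders).
import tensorflow as tf

loaded = tf.saved_model.load('/tmp/exported_detection_model')   # placeholder path
infer = loaded.signatures['serving_default']                     # assumed signature name
outputs = infer(tf.zeros([1, 224, 224, 3], dtype=tf.uint8))      # assumed input spec
assert 'image_info' in outputs        # newly exported by this change
assert 'num_detections' in outputs    # existing detection outputs are unchanged
```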
@@ -125,10 +125,11 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
     images = self._get_dummy_input(input_type, batch_size, image_size)

-    processed_images, anchor_boxes, image_shape = module._build_inputs(
+    processed_images, anchor_boxes, image_info = module._build_inputs(
         tf.zeros((224, 224, 3), dtype=tf.uint8))
-    processed_images = tf.expand_dims(processed_images, 0)
+    image_shape = image_info[1, :]
     image_shape = tf.expand_dims(image_shape, 0)
+    processed_images = tf.expand_dims(processed_images, 0)
     for l, l_boxes in anchor_boxes.items():
       anchor_boxes[l] = tf.expand_dims(l_boxes, 0)