Commit 5fdf878d authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 422863802
parent 4ca209ff
......@@ -46,13 +46,13 @@ class SegmentationModule(export_base.ExportModule):
offset=MEAN_RGB,
scale=STDDEV_RGB)
image, _ = preprocess_ops.resize_and_crop_image(
image, image_info = preprocess_ops.resize_and_crop_image(
image,
self._input_image_size,
padded_size=self._input_image_size,
aug_scale_min=1.0,
aug_scale_max=1.0)
return image
return image, image_info
def serve(self, images):
"""Cast image to float and run inference.
......@@ -64,21 +64,27 @@ class SegmentationModule(export_base.ExportModule):
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
image_info = None
if self._input_type != 'tflite':
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images_spec = tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32)
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
images = tf.nest.map_structure(
images, image_info = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs,
elems=images,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32),
fn_output_signature=(images_spec, image_info_spec),
parallel_iterations=32))
outputs = self.inference_step(images)
outputs['logits'] = tf.image.resize(
outputs['logits'], self._input_image_size, method='bilinear')
if image_info is not None:
outputs.update({'image_info': image_info})
return outputs
......@@ -93,13 +93,15 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
images = self._get_dummy_input(input_type)
if input_type != 'tflite':
processed_images = tf.nest.map_structure(
processed_images, _ = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
module._build_inputs,
elems=tf.zeros((1, 112, 112, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=[112, 112, 3], dtype=tf.float32)))
fn_output_signature=(tf.TensorSpec(
shape=[112, 112, 3], dtype=tf.float32),
tf.TensorSpec(
shape=[4, 2], dtype=tf.float32))))
else:
processed_images = images
expected_output = tf.image.resize(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment