Commit 5fdf878d authored by Fan Yang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 422863802
parent 4ca209ff
@@ -46,13 +46,13 @@ class SegmentationModule(export_base.ExportModule):
         offset=MEAN_RGB,
         scale=STDDEV_RGB)
 
-    image, _ = preprocess_ops.resize_and_crop_image(
+    image, image_info = preprocess_ops.resize_and_crop_image(
         image,
         self._input_image_size,
         padded_size=self._input_image_size,
         aug_scale_min=1.0,
         aug_scale_max=1.0)
-    return image
+    return image, image_info
 
   def serve(self, images):
     """Cast image to float and run inference.
@@ -64,21 +64,27 @@ class SegmentationModule(export_base.ExportModule):
     """
     # Skip image preprocessing when input_type is tflite so it is compatible
     # with TFLite quantization.
+    image_info = None
     if self._input_type != 'tflite':
       with tf.device('cpu:0'):
         images = tf.cast(images, dtype=tf.float32)
 
-        images = tf.nest.map_structure(
+        images_spec = tf.TensorSpec(
+            shape=self._input_image_size + [3], dtype=tf.float32)
+        image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
+
+        images, image_info = tf.nest.map_structure(
             tf.identity,
             tf.map_fn(
                 self._build_inputs,
                 elems=images,
-                fn_output_signature=tf.TensorSpec(
-                    shape=self._input_image_size + [3], dtype=tf.float32),
+                fn_output_signature=(images_spec, image_info_spec),
                 parallel_iterations=32))
 
     outputs = self.inference_step(images)
 
     outputs['logits'] = tf.image.resize(
         outputs['logits'], self._input_image_size, method='bilinear')
 
+    if image_info is not None:
+      outputs.update({'image_info': image_info})
+
     return outputs
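The `serve()` change works because `tf.map_fn` accepts a nested structure of `tf.TensorSpec`s as `fn_output_signature` when the mapped function returns a matching structure. A self-contained sketch of that mechanism (the sizes and the `_per_image` helper are illustrative, not from the commit):

```python
import tensorflow as tf

def _per_image(image):
  # Stand-in for SegmentationModule._build_inputs: returns a tuple per image.
  resized = tf.image.resize(tf.cast(image, tf.float32), [112, 112])
  info = tf.zeros([4, 2], tf.float32)  # placeholder for the real image_info
  return resized, info

images = tf.zeros([8, 224, 224, 3], dtype=tf.uint8)
resized_batch, info_batch = tf.map_fn(
    _per_image,
    elems=images,
    fn_output_signature=(tf.TensorSpec([112, 112, 3], tf.float32),
                         tf.TensorSpec([4, 2], tf.float32)),
    parallel_iterations=32)
# resized_batch: [8, 112, 112, 3] float32; info_batch: [8, 4, 2] float32.
```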
@@ -93,13 +93,15 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
     images = self._get_dummy_input(input_type)
 
     if input_type != 'tflite':
-      processed_images = tf.nest.map_structure(
+      processed_images, _ = tf.nest.map_structure(
          tf.stop_gradient,
          tf.map_fn(
              module._build_inputs,
              elems=tf.zeros((1, 112, 112, 3), dtype=tf.uint8),
-              fn_output_signature=tf.TensorSpec(
-                  shape=[112, 112, 3], dtype=tf.float32)))
+              fn_output_signature=(tf.TensorSpec(
+                  shape=[112, 112, 3], dtype=tf.float32),
+                                   tf.TensorSpec(
+                                       shape=[4, 2], dtype=tf.float32))))
     else:
       processed_images = images
     expected_output = tf.image.resize(
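Since the exported outputs now carry `image_info`, a client of the SavedModel could use it to undo the preprocessing resize, e.g. to map the bilinearly resized logits back to the original image resolution. A hypothetical sketch, assuming row 0 of `image_info` holds the original [height, width] (the helper name and the row assumption are not from the commit):

```python
import tensorflow as tf

def logits_at_original_resolution(outputs):
  """Hypothetical client-side post-processing of the serving outputs."""
  logits = outputs['logits']          # [batch, height, width, num_classes]
  image_info = outputs['image_info']  # [batch, 4, 2]
  # Assumption: row 0 of image_info is the original [height, width].
  original_hw = tf.cast(image_info[0, 0, :], tf.int32)
  return tf.image.resize(logits[:1], original_hw, method='bilinear')
```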