"vscode:/vscode.git/clone" did not exist on "09b43b1dfefcd987b1cc886c5e4c3d717587b74b"
Commit d2f47b86 authored by Fan Yang, committed by A. Unique TensorFlower

Add ExportConfig to the SemanticSegmentationTask config, adding an option to rescale the predicted mask to the original image size.

PiperOrigin-RevId: 445061309
parent 84303a2b
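
For context, a minimal sketch of how the new option is enabled when building an export module, based on the test changes below (the experiment name and constructor arguments are taken from this diff; everything else is illustrative):

```python
# Sketch only: enable the new rescale_output flag before exporting.
# Mirrors the test code in this commit, not a canonical export recipe.
from official.core import exp_factory
from official.vision.serving import semantic_segmentation

params = exp_factory.get_exp_config('mnv2_deeplabv3_pascal')
# New flag added by this commit; serving requires batch size 1 when set.
params.task.export_config.rescale_output = True

module = semantic_segmentation.SegmentationModule(
    params,
    batch_size=1,
    input_image_size=[112, 112],
    input_type='image_tensor')
```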
@@ -116,6 +116,12 @@ class Evaluation(hyperparams.Config):
   report_train_mean_iou: bool = True  # Turning this off can speed up training.
 
 
+@dataclasses.dataclass
+class ExportConfig(hyperparams.Config):
+  # Whether to rescale the predicted mask to the original image size.
+  rescale_output: bool = False
+
+
 @dataclasses.dataclass
 class SemanticSegmentationTask(cfg.TaskConfig):
   """The model config."""
@@ -131,6 +137,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
   init_checkpoint: Optional[str] = None
   init_checkpoint_modules: Union[
       str, List[str]] = 'all'  # all, backbone, and/or decoder
+  export_config: ExportConfig = ExportConfig()
 
 
 @exp_factory.register_config_factory('semantic_segmentation')
...
@@ -45,12 +45,16 @@ class SegmentationModule(export_base.ExportModule):
         offset=MEAN_RGB,
         scale=STDDEV_RGB)
 
-    image, image_info = preprocess_ops.resize_and_crop_image(
-        image,
-        self._input_image_size,
-        padded_size=self._input_image_size,
-        aug_scale_min=1.0,
-        aug_scale_max=1.0)
+    if self.params.task.train_data.preserve_aspect_ratio:
+      image, image_info = preprocess_ops.resize_and_crop_image(
+          image,
+          self._input_image_size,
+          padded_size=self._input_image_size,
+          aug_scale_min=1.0,
+          aug_scale_max=1.0)
+    else:
+      image, image_info = preprocess_ops.resize_image(image,
+                                                      self._input_image_size)
     return image, image_info
 
   def serve(self, images):
@@ -80,8 +84,27 @@ class SegmentationModule(export_base.ExportModule):
             parallel_iterations=32))
 
     outputs = self.inference_step(images)
 
-    outputs['logits'] = tf.image.resize(
-        outputs['logits'], self._input_image_size, method='bilinear')
+    # Optionally resize prediction to the input image size.
+    if self.params.task.export_config.rescale_output:
+      logits = outputs['logits']
+      if logits.shape[0] != 1:
+        raise ValueError('Batch size cannot be more than 1.')
+      image_shape = tf.cast(image_info[0, 0, :], tf.int32)
+      if self.params.task.train_data.preserve_aspect_ratio:
+        rescale_size = tf.cast(
+            tf.math.ceil(image_info[0, 1, :] / image_info[0, 2, :]), tf.int32)
+        offsets = tf.cast(image_info[0, 3, :], tf.int32)
+        logits = tf.image.resize(logits, rescale_size, method='bilinear')
+        outputs['logits'] = tf.image.crop_to_bounding_box(
+            logits, offsets[0], offsets[1], image_shape[0], image_shape[1])
+      else:
+        outputs['logits'] = tf.image.resize(
+            logits, [image_shape[0], image_shape[1]], method='bilinear')
+    else:
+      outputs['logits'] = tf.image.resize(
+          outputs['logits'], self._input_image_size, method='bilinear')
 
     if image_info is not None:
       outputs.update({'image_info': image_info})
...
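The serving change above inverts the preprocessing: when `preserve_aspect_ratio` is set, the logits are resized by the inverse of the preprocessing scale and the padding is cropped away using the offsets recorded in `image_info`; otherwise they are resized directly back to the original height and width. A standalone sketch of that logic, assuming `image_info` rows follow the `[original size, desired size, scale, offset]` convention of `preprocess_ops.resize_and_crop_image`:

```python
import tensorflow as tf


def rescale_logits(logits, image_info, preserve_aspect_ratio=True):
  """Illustrative inverse of resize_and_crop_image for a batch of one.

  Assumes image_info rows are [original size, desired size, scale,
  offset]; mirrors the serving change in this commit.
  """
  if logits.shape[0] != 1:
    raise ValueError('Batch size cannot be more than 1.')
  image_shape = tf.cast(image_info[0, 0, :], tf.int32)  # original H, W
  if preserve_aspect_ratio:
    # Undo the preprocessing scale, then crop off the padded border.
    rescale_size = tf.cast(
        tf.math.ceil(image_info[0, 1, :] / image_info[0, 2, :]), tf.int32)
    offsets = tf.cast(image_info[0, 3, :], tf.int32)
    logits = tf.image.resize(logits, rescale_size, method='bilinear')
    return tf.image.crop_to_bounding_box(
        logits, offsets[0], offsets[1], image_shape[0], image_shape[1])
  # Aspect ratio was not preserved: a plain resize restores the original size.
  return tf.image.resize(
      logits, [image_shape[0], image_shape[1]], method='bilinear')
```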
@@ -29,11 +29,17 @@ from official.vision.serving import semantic_segmentation
 class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
 
-  def _get_segmentation_module(self, input_type):
+  def _get_segmentation_module(self,
+                               input_type,
+                               rescale_output,
+                               preserve_aspect_ratio,
+                               batch_size=1):
     params = exp_factory.get_exp_config('mnv2_deeplabv3_pascal')
+    params.task.export_config.rescale_output = rescale_output
+    params.task.train_data.preserve_aspect_ratio = preserve_aspect_ratio
     segmentation_module = semantic_segmentation.SegmentationModule(
         params,
-        batch_size=1,
+        batch_size=batch_size,
         input_image_size=[112, 112],
         input_type=input_type)
     return segmentation_module
@@ -43,18 +49,20 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
                             {input_type: 'serving_default'})
     tf.saved_model.save(module, save_directory, signatures=signatures)
 
-  def _get_dummy_input(self, input_type):
+  def _get_dummy_input(self, input_type, input_image_size):
     """Get dummy input for the given input type."""
+    height = input_image_size[0]
+    width = input_image_size[1]
     if input_type == 'image_tensor':
-      return tf.zeros((1, 112, 112, 3), dtype=np.uint8)
+      return tf.zeros((1, height, width, 3), dtype=np.uint8)
     elif input_type == 'image_bytes':
-      image = Image.fromarray(np.zeros((112, 112, 3), dtype=np.uint8))
+      image = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))
       byte_io = io.BytesIO()
       image.save(byte_io, 'PNG')
       return [byte_io.getvalue()]
     elif input_type == 'tf_example':
-      image_tensor = tf.zeros((112, 112, 3), dtype=tf.uint8)
+      image_tensor = tf.zeros((height, width, 3), dtype=tf.uint8)
       encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
       example = tf.train.Example(
           features=tf.train.Features(
@@ -65,17 +73,24 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
               })).SerializeToString()
       return [example]
     elif input_type == 'tflite':
-      return tf.zeros((1, 112, 112, 3), dtype=np.float32)
+      return tf.zeros((1, height, width, 3), dtype=np.float32)
 
   @parameterized.parameters(
-      {'input_type': 'image_tensor'},
-      {'input_type': 'image_bytes'},
-      {'input_type': 'tf_example'},
-      {'input_type': 'tflite'},
+      ('image_tensor', False, [112, 112], False),
+      ('image_bytes', False, [112, 112], False),
+      ('tf_example', False, [112, 112], True),
+      ('tflite', False, [112, 112], False),
+      ('image_tensor', True, [112, 56], True),
+      ('image_bytes', True, [112, 56], True),
+      ('tf_example', True, [56, 112], False),
   )
-  def test_export(self, input_type='image_tensor'):
+  def test_export(self, input_type, rescale_output, input_image_size,
+                  preserve_aspect_ratio):
     tmp_dir = self.get_temp_dir()
-    module = self._get_segmentation_module(input_type)
+    module = self._get_segmentation_module(
+        input_type=input_type,
+        rescale_output=rescale_output,
+        preserve_aspect_ratio=preserve_aspect_ratio)
     self._export_from_module(module, input_type, tmp_dir)
@@ -90,7 +105,7 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
     imported = tf.saved_model.load(tmp_dir)
     segmentation_fn = imported.signatures['serving_default']
 
-    images = self._get_dummy_input(input_type)
+    images = self._get_dummy_input(input_type, input_image_size)
 
     if input_type != 'tflite':
       processed_images, _ = tf.nest.map_structure(
           tf.stop_gradient,
@@ -103,12 +118,28 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
                   shape=[4, 2], dtype=tf.float32))))
     else:
       processed_images = images
 
-    expected_output = tf.image.resize(
-        module.model(processed_images, training=False)['logits'], [112, 112],
-        method='bilinear')
+    logits = module.model(processed_images, training=False)['logits']
+    if rescale_output:
+      expected_output = tf.image.resize(
+          logits, input_image_size, method='bilinear')
+    else:
+      expected_output = tf.image.resize(logits, [112, 112], method='bilinear')
+
     out = segmentation_fn(tf.constant(images))
     self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
 
+  def test_export_invalid_batch_size(self):
+    batch_size = 3
+    tmp_dir = self.get_temp_dir()
+    module = self._get_segmentation_module(
+        input_type='image_tensor',
+        rescale_output=True,
+        preserve_aspect_ratio=False,
+        batch_size=batch_size)
+
+    with self.assertRaisesRegex(ValueError,
+                                'Batch size cannot be more than 1.'):
+      self._export_from_module(module, 'image_tensor', tmp_dir)
+
 
 if __name__ == '__main__':
   tf.test.main()
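
Once exported, the rescaling happens inside the serving signature. A hedged sketch of calling the resulting SavedModel (the signature name, input dtype, and output key follow the test above; the export directory and image size are illustrative):

```python
import tensorflow as tf

# Hypothetical export directory; substitute wherever the module was saved.
imported = tf.saved_model.load('/tmp/seg_export')
segmentation_fn = imported.signatures['serving_default']

# A non-square input; with rescale_output=True the predicted mask comes
# back at this original 112x56 size rather than the model input size.
image = tf.zeros((1, 112, 56, 3), dtype=tf.uint8)
outputs = segmentation_fn(tf.constant(image))
print(outputs['logits'].shape)  # (1, 112, 56, num_classes)
```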