Commit 024ebd81 authored by Fan Yang, committed by A. Unique TensorFlower

Internal change.

PiperOrigin-RevId: 414045177
parent 002b4ec4
@@ -52,6 +52,18 @@ class DetectionModule(export_base.ExportModule):
     return model
 
+  def _build_anchor_boxes(self):
+    """Builds and returns anchor boxes."""
+    model_params = self.params.task.model
+    input_anchor = anchor.build_anchor_generator(
+        min_level=model_params.min_level,
+        max_level=model_params.max_level,
+        num_scales=model_params.anchor.num_scales,
+        aspect_ratios=model_params.anchor.aspect_ratios,
+        anchor_size=model_params.anchor.anchor_size)
+    return input_anchor(
+        image_size=(self._input_image_size[0], self._input_image_size[1]))
+
   def _build_inputs(self, image):
     """Builds detection model inputs for serving."""
     model_params = self.params.task.model
@@ -67,15 +79,7 @@ class DetectionModule(export_base.ExportModule):
             self._input_image_size, 2**model_params.max_level),
         aug_scale_min=1.0,
         aug_scale_max=1.0)
-    input_anchor = anchor.build_anchor_generator(
-        min_level=model_params.min_level,
-        max_level=model_params.max_level,
-        num_scales=model_params.anchor.num_scales,
-        aspect_ratios=model_params.anchor.aspect_ratios,
-        anchor_size=model_params.anchor.anchor_size)
-    anchor_boxes = input_anchor(image_size=(self._input_image_size[0],
-                                            self._input_image_size[1]))
+    anchor_boxes = self._build_anchor_boxes()
     return image, anchor_boxes, image_info
@@ -133,7 +137,22 @@ class DetectionModule(export_base.ExportModule):
       Tensor holding detection output logits.
     """
+    # Skip image preprocessing when input_type is tflite so it is compatible
+    # with TFLite quantization.
+    if self._input_type != 'tflite':
       images, anchor_boxes, image_info = self.preprocess(images)
+    else:
+      with tf.device('cpu:0'):
+        anchor_boxes = self._build_anchor_boxes()
+        # image_info is a 3D tensor of shape [batch_size, 4, 2]. It is in the
+        # format of [[original_height, original_width],
+        # [desired_height, desired_width], [y_scale, x_scale],
+        # [y_offset, x_offset]]. When input_type is tflite, input image is
+        # supposed to be preprocessed already.
+        image_info = tf.convert_to_tensor([[
+            self._input_image_size, self._input_image_size, [1.0, 1.0], [0, 0]
+        ]], dtype=tf.float32)
     input_image_shape = image_info[:, 1, :]
 
     # To overcome keras.Model extra limitation to save a model with layers that
......
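For reference, a minimal standalone sketch of the `image_info` tensor the new `tflite` branch builds. The layout follows the comment in the diff; the 640x640 size is illustrative only, assuming the caller feeds an already-preprocessed image.

```python
import tensorflow as tf

# image_info layout per the diff comment: [batch_size, 4, 2] holding
# [[original_height, original_width], [desired_height, desired_width],
#  [y_scale, x_scale], [y_offset, x_offset]].
input_image_size = [640, 640]  # illustrative size
image_info = tf.convert_to_tensor(
    [[input_image_size,  # original size: input is already preprocessed
      input_image_size,  # desired size: identical, no resize performed
      [1.0, 1.0],        # scales: identity, nothing was rescaled
      [0, 0]]],          # offsets: zero, nothing was cropped
    dtype=tf.float32)
assert image_info.shape == (1, 4, 2)
```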
@@ -30,12 +30,15 @@ from official.vision.beta.serving import detection
 class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
 
-  def _get_detection_module(self, experiment_name):
+  def _get_detection_module(self, experiment_name, input_type):
     params = exp_factory.get_exp_config(experiment_name)
     params.task.model.backbone.resnet.model_id = 18
     params.task.model.detection_generator.nms_version = 'batched'
     detection_module = detection.DetectionModule(
-        params, batch_size=1, input_image_size=[640, 640])
+        params,
+        batch_size=1,
+        input_image_size=[640, 640],
+        input_type=input_type)
     return detection_module
 
   def _export_from_module(self, module, input_type, save_directory):
@@ -65,24 +68,30 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
                   bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
           })).SerializeToString()
       return [example for b in range(batch_size)]
+    elif input_type == 'tflite':
+      return tf.zeros((batch_size, h, w, 3), dtype=np.float32)
 
   @parameterized.parameters(
       ('image_tensor', 'fasterrcnn_resnetfpn_coco', [384, 384]),
       ('image_bytes', 'fasterrcnn_resnetfpn_coco', [640, 640]),
       ('tf_example', 'fasterrcnn_resnetfpn_coco', [640, 640]),
+      ('tflite', 'fasterrcnn_resnetfpn_coco', [640, 640]),
       ('image_tensor', 'maskrcnn_resnetfpn_coco', [640, 640]),
       ('image_bytes', 'maskrcnn_resnetfpn_coco', [640, 384]),
       ('tf_example', 'maskrcnn_resnetfpn_coco', [640, 640]),
+      ('tflite', 'maskrcnn_resnetfpn_coco', [640, 640]),
       ('image_tensor', 'retinanet_resnetfpn_coco', [640, 640]),
       ('image_bytes', 'retinanet_resnetfpn_coco', [640, 640]),
       ('tf_example', 'retinanet_resnetfpn_coco', [384, 640]),
+      ('tflite', 'retinanet_resnetfpn_coco', [640, 640]),
      ('image_tensor', 'retinanet_resnetfpn_coco', [384, 384]),
      ('image_bytes', 'retinanet_spinenet_coco', [640, 640]),
      ('tf_example', 'retinanet_spinenet_coco', [640, 384]),
+      ('tflite', 'retinanet_spinenet_coco', [640, 640]),
   )
   def test_export(self, input_type, experiment_name, image_size):
     tmp_dir = self.get_temp_dir()
-    module = self._get_detection_module(experiment_name)
+    module = self._get_detection_module(experiment_name, input_type)
     self._export_from_module(module, input_type, tmp_dir)
@@ -100,6 +109,12 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
     images = self._get_dummy_input(
         input_type, batch_size=1, image_size=image_size)
+    if input_type == 'tflite':
+      processed_images = tf.zeros(image_size + [3], dtype=tf.float32)
+      anchor_boxes = module._build_anchor_boxes()
+      image_info = tf.convert_to_tensor(
+          [image_size, image_size, [1.0, 1.0], [0, 0]], dtype=tf.float32)
+    else:
       processed_images, anchor_boxes, image_info = module._build_inputs(
           tf.zeros((224, 224, 3), dtype=tf.uint8))
     image_shape = image_info[1, :]
......
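Pulled out of the test above, this is how a caller constructs the detection serving module with the new argument, assuming the usual model-garden import paths:

```python
import tensorflow as tf
from official.core import exp_factory
from official.vision.beta.serving import detection

# Same construction the test uses; 'retinanet_resnetfpn_coco' is one of the
# experiments exercised above.
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
params.task.model.backbone.resnet.model_id = 18
module = detection.DetectionModule(
    params, batch_size=1, input_image_size=[640, 640], input_type='tflite')

# On the tflite path, the serving function expects a float32 batch that has
# already been resized and normalized, mirroring _get_dummy_input above.
dummy_images = tf.zeros((1, 640, 640, 3), dtype=tf.float32)
```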
@@ -31,6 +31,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
                *,
                batch_size: int,
                input_image_size: List[int],
+               input_type: str = 'image_tensor',
                num_channels: int = 3,
                model: Optional[tf.keras.Model] = None):
     """Initializes a module for export.
@@ -40,6 +41,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
       batch_size: The batch size of the model input. Can be `int` or None.
       input_image_size: List or Tuple of size of the input image. For 2D image,
         it is [height, width].
+      input_type: The input signature type.
       num_channels: The number of the image channels.
       model: A tf.keras.Model instance to be exported.
     """
@@ -47,6 +49,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
     self._batch_size = batch_size
     self._input_image_size = input_image_size
     self._num_channels = num_channels
+    self._input_type = input_type
     if model is None:
       model = self._build_model()  # pylint: disable=assignment-from-none
     super().__init__(params=params, model=model)
......
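Since the new argument defaults to `'image_tensor'`, existing callers that never pass `input_type` keep their previous behavior; a minimal sketch using the classification module from this change:

```python
from official.core import exp_factory
from official.vision.beta.serving import image_classification

params = exp_factory.get_exp_config('resnet_imagenet')
# No input_type given: defaults to 'image_tensor', so preprocessing still
# runs inside the serving function exactly as before this change.
module = image_classification.ClassificationModule(
    params, batch_size=1, input_image_size=[224, 224])
```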
@@ -89,6 +89,7 @@ def export_inference_graph(
         params=params,
         batch_size=batch_size,
         input_image_size=input_image_size,
+        input_type=input_type,
         num_channels=num_channels)
   elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
       params.task, configs.maskrcnn.MaskRCNNTask):
@@ -96,6 +97,7 @@ def export_inference_graph(
         params=params,
         batch_size=batch_size,
         input_image_size=input_image_size,
+        input_type=input_type,
         num_channels=num_channels)
   elif isinstance(params.task,
                   configs.semantic_segmentation.SemanticSegmentationTask):
@@ -103,6 +105,7 @@ def export_inference_graph(
         params=params,
         batch_size=batch_size,
         input_image_size=input_image_size,
+        input_type=input_type,
         num_channels=num_channels)
   elif isinstance(params.task,
                   configs.video_classification.VideoClassificationTask):
@@ -110,6 +113,7 @@ def export_inference_graph(
         params=params,
         batch_size=batch_size,
         input_image_size=input_image_size,
+        input_type=input_type,
         num_channels=num_channels)
   else:
     raise ValueError('Export module not implemented for {} task.'.format(
......
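With `input_type` now forwarded to every module, exporting for TFLite becomes a one-argument change at the call site. A hedged sketch; the remaining argument names (`checkpoint_path`, `export_dir`) are assumptions, not shown in these hunks:

```python
from official.core import exp_factory
from official.vision.beta.serving import export_saved_model_lib  # assumed path

params = exp_factory.get_exp_config('retinanet_mobile_coco')
export_saved_model_lib.export_inference_graph(
    input_type='tflite',  # forwarded to the export module by this change
    batch_size=1,
    input_image_size=[384, 384],
    params=params,
    checkpoint_path='/tmp/ckpt',         # assumed argument name
    export_dir='/tmp/retinanet_export')  # assumed argument name
```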
@@ -30,18 +30,10 @@ from official.vision.beta.serving import semantic_segmentation as semantic_segmentation_serving
 class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
 
-  def setUp(self):
-    super().setUp()
-    self._test_tfrecord_file = os.path.join(self.get_temp_dir(),
-                                            'test.tfrecord')
-    self._create_test_tfrecord(num_samples=50)
-
-  def _create_test_tfrecord(self, num_samples):
-    tfexample_utils.dump_to_tfrecord(self._test_tfrecord_file, [
-        tf.train.Example.FromString(
-            tfexample_utils.create_classification_example(
-                image_height=256, image_width=256)) for _ in range(num_samples)
-    ])
+  def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
+    examples = [example] * num_samples
+    tfexample_utils.dump_to_tfrecord(
+        record_file=tfrecord_file, tf_examples=examples)
 
   def _export_from_module(self, module, input_type, saved_model_dir):
     signatures = module.get_inference_signatures(
@@ -51,16 +43,25 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
           experiment=['mobilenet_imagenet'],
-          quant_type=[None, 'default', 'fp16', 'int8'],
+          quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
           input_image_size=[[224, 224]]))
   def test_export_tflite_image_classification(self, experiment, quant_type,
                                               input_image_size):
+    test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
+    example = tf.train.Example.FromString(
+        tfexample_utils.create_classification_example(
+            image_height=input_image_size[0], image_width=input_image_size[1]))
+    self._create_test_tfrecord(
+        tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
     params = exp_factory.get_exp_config(experiment)
-    params.task.validation_data.input_path = self._test_tfrecord_file
-    params.task.train_data.input_path = self._test_tfrecord_file
+    params.task.validation_data.input_path = test_tfrecord_file
+    params.task.train_data.input_path = test_tfrecord_file
     temp_dir = self.get_temp_dir()
     module = image_classification_serving.ClassificationModule(
-        params=params, batch_size=1, input_image_size=input_image_size)
+        params=params,
+        batch_size=1,
+        input_image_size=input_image_size,
+        input_type='tflite')
     self._export_from_module(
         module=module,
         input_type='tflite',
@@ -78,13 +79,26 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
       combinations.combine(
           experiment=['retinanet_mobile_coco'],
           quant_type=[None, 'default', 'fp16'],
-          input_image_size=[[256, 256]]))
+          input_image_size=[[384, 384]]))
   def test_export_tflite_detection(self, experiment, quant_type,
                                    input_image_size):
+    test_tfrecord_file = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
+    example = tfexample_utils.create_detection_test_example(
+        image_height=input_image_size[0],
+        image_width=input_image_size[1],
+        image_channel=3,
+        num_instances=10)
+    self._create_test_tfrecord(
+        tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
     params = exp_factory.get_exp_config(experiment)
+    params.task.validation_data.input_path = test_tfrecord_file
+    params.task.train_data.input_path = test_tfrecord_file
     temp_dir = self.get_temp_dir()
     module = detection_serving.DetectionModule(
-        params=params, batch_size=1, input_image_size=input_image_size)
+        params=params,
+        batch_size=1,
+        input_image_size=input_image_size,
+        input_type='tflite')
     self._export_from_module(
         module=module,
         input_type='tflite',
@@ -100,15 +114,27 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
   @combinations.generate(
       combinations.combine(
-          experiment=['seg_deeplabv3_pascal'],
-          quant_type=[None, 'default', 'fp16'],
+          experiment=['mnv2_deeplabv3_pascal'],
+          quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
           input_image_size=[[512, 512]]))
   def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
                                                input_image_size):
+    test_tfrecord_file = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
+    example = tfexample_utils.create_segmentation_test_example(
+        image_height=input_image_size[0],
+        image_width=input_image_size[1],
+        image_channel=3)
+    self._create_test_tfrecord(
+        tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
     params = exp_factory.get_exp_config(experiment)
+    params.task.validation_data.input_path = test_tfrecord_file
+    params.task.train_data.input_path = test_tfrecord_file
     temp_dir = self.get_temp_dir()
     module = semantic_segmentation_serving.SegmentationModule(
-        params=params, batch_size=1, input_image_size=input_image_size)
+        params=params,
+        batch_size=1,
+        input_image_size=input_image_size,
+        input_type='tflite')
     self._export_from_module(
         module=module,
         input_type='tflite',
......
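The new `'int8'` and `'int8_full'` values point at post-training integer quantization, which calibrates from sample inputs; that plausibly explains why the tests above now write small per-test tfrecords and point the data configs at them. A sketch of how such options typically map onto the stock TFLite converter; the repo's own helper may differ in detail:

```python
import tensorflow as tf

def convert(saved_model_dir, quant_type=None, representative_dataset=None):
  """Sketch: typical mapping from quant_type to TFLite converter settings."""
  converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  if quant_type == 'default':
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
  elif quant_type == 'fp16':
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
  elif quant_type in ('int8', 'int8_full'):
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Integer quantization needs representative inputs for calibration.
    converter.representative_dataset = representative_dataset
    if quant_type == 'int8_full':
      # Full-integer: restrict ops and the model I/O to int8.
      converter.target_spec.supported_ops = [
          tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
      converter.inference_input_type = tf.int8
      converter.inference_output_type = tf.int8
  return converter.convert()
```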
@@ -63,18 +63,20 @@ class ClassificationModule(export_base.ExportModule):
     Returns:
       Tensor holding classification output logits.
     """
+    # Skip image preprocessing when input_type is tflite so it is compatible
+    # with TFLite quantization.
+    if self._input_type != 'tflite':
       with tf.device('cpu:0'):
         images = tf.cast(images, dtype=tf.float32)
         images = tf.nest.map_structure(
             tf.identity,
             tf.map_fn(
-                self._build_inputs, elems=images,
+                self._build_inputs,
+                elems=images,
                 fn_output_signature=tf.TensorSpec(
                     shape=self._input_image_size + [3], dtype=tf.float32),
-                parallel_iterations=32
-            )
-        )
+                parallel_iterations=32))
 
     logits = self.inference_step(images)
     probs = tf.nn.softmax(logits)
......
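Standalone, the per-image preprocessing pattern that the non-tflite branch keeps (and the tflite branch now skips) looks like this; `_build_inputs` here is a stand-in for the module's real preprocessing:

```python
import tensorflow as tf

def _build_inputs(image):
  # Stand-in preprocessing: resize and scale to [0, 1].
  image = tf.image.resize(image, [224, 224])
  return tf.cast(image, tf.float32) / 255.0

images = tf.cast(tf.zeros((2, 320, 320, 3), dtype=tf.uint8), tf.float32)
processed = tf.map_fn(
    _build_inputs,
    elems=images,
    fn_output_signature=tf.TensorSpec(shape=[224, 224, 3], dtype=tf.float32),
    parallel_iterations=32)
print(processed.shape)  # (2, 224, 224, 3)
```

Skipping this step for `'tflite'` keeps these ops out of the exported graph and, per the diff comments, assumes the client has already resized and normalized the input.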
@@ -30,11 +30,14 @@ from official.vision.beta.serving import image_classification
 class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
 
-  def _get_classification_module(self):
+  def _get_classification_module(self, input_type):
     params = exp_factory.get_exp_config('resnet_imagenet')
     params.task.model.backbone.resnet.model_id = 18
     classification_module = image_classification.ClassificationModule(
-        params, batch_size=1, input_image_size=[224, 224])
+        params,
+        batch_size=1,
+        input_image_size=[224, 224],
+        input_type=input_type)
     return classification_module
 
   def _export_from_module(self, module, input_type, save_directory):
@@ -65,15 +68,18 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
                   bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
           })).SerializeToString()
       return [example]
+    elif input_type == 'tflite':
+      return tf.zeros((1, 224, 224, 3), dtype=np.float32)
 
   @parameterized.parameters(
       {'input_type': 'image_tensor'},
       {'input_type': 'image_bytes'},
       {'input_type': 'tf_example'},
+      {'input_type': 'tflite'},
   )
   def test_export(self, input_type='image_tensor'):
     tmp_dir = self.get_temp_dir()
-    module = self._get_classification_module()
+    module = self._get_classification_module(input_type)
     # Test that the model restores any attrs that are trackable objects
     # (eg: tables, resource variables, keras models/layers, tf.hub modules).
     module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))
@@ -90,6 +96,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
     classification_fn = imported.signatures['serving_default']
     images = self._get_dummy_input(input_type)
+    if input_type != 'tflite':
       processed_images = tf.nest.map_structure(
           tf.stop_gradient,
           tf.map_fn(
@@ -97,6 +104,8 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
               elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
               fn_output_signature=tf.TensorSpec(
                   shape=[224, 224, 3], dtype=tf.float32)))
+    else:
+      processed_images = images
     expected_logits = module.model(processed_images, training=False)
     expected_prob = tf.nn.softmax(expected_logits)
     out = classification_fn(tf.constant(images))
......
@@ -62,18 +62,20 @@ class SegmentationModule(export_base.ExportModule):
     Returns:
       Tensor holding classification output logits.
     """
+    # Skip image preprocessing when input_type is tflite so it is compatible
+    # with TFLite quantization.
+    if self._input_type != 'tflite':
       with tf.device('cpu:0'):
         images = tf.cast(images, dtype=tf.float32)
         images = tf.nest.map_structure(
             tf.identity,
             tf.map_fn(
-                self._build_inputs, elems=images,
+                self._build_inputs,
+                elems=images,
                 fn_output_signature=tf.TensorSpec(
                     shape=self._input_image_size + [3], dtype=tf.float32),
-                parallel_iterations=32
-            )
-        )
+                parallel_iterations=32))
 
     masks = self.inference_step(images)
     masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
......
@@ -30,10 +30,13 @@ from official.vision.beta.serving import semantic_segmentation
 class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
 
-  def _get_segmentation_module(self):
+  def _get_segmentation_module(self, input_type):
     params = exp_factory.get_exp_config('mnv2_deeplabv3_pascal')
     segmentation_module = semantic_segmentation.SegmentationModule(
-        params, batch_size=1, input_image_size=[112, 112])
+        params,
+        batch_size=1,
+        input_image_size=[112, 112],
+        input_type=input_type)
     return segmentation_module
 
   def _export_from_module(self, module, input_type, save_directory):
@@ -62,15 +65,18 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
                   bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
           })).SerializeToString()
       return [example]
+    elif input_type == 'tflite':
+      return tf.zeros((1, 112, 112, 3), dtype=np.float32)
 
   @parameterized.parameters(
       {'input_type': 'image_tensor'},
       {'input_type': 'image_bytes'},
       {'input_type': 'tf_example'},
+      {'input_type': 'tflite'},
   )
   def test_export(self, input_type='image_tensor'):
     tmp_dir = self.get_temp_dir()
-    module = self._get_segmentation_module()
+    module = self._get_segmentation_module(input_type)
     self._export_from_module(module, input_type, tmp_dir)
@@ -86,6 +92,7 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
     segmentation_fn = imported.signatures['serving_default']
     images = self._get_dummy_input(input_type)
+    if input_type != 'tflite':
       processed_images = tf.nest.map_structure(
           tf.stop_gradient,
           tf.map_fn(
@@ -93,6 +100,8 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
              elems=tf.zeros((1, 112, 112, 3), dtype=tf.uint8),
              fn_output_signature=tf.TensorSpec(
                  shape=[112, 112, 3], dtype=tf.float32)))
+    else:
+      processed_images = images
     expected_output = tf.image.resize(
         module.model(processed_images, training=False), [112, 112],
         method='bilinear')
......
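Finally, a sketch of exercising an exported signature the way these tests do, assuming the module was already saved to a local directory:

```python
import tensorflow as tf

# Load the exported SavedModel and call its serving signature with a
# preprocessed float32 batch, as the tflite test paths above expect.
export_dir = '/tmp/seg_export'  # placeholder path
imported = tf.saved_model.load(export_dir)
segmentation_fn = imported.signatures['serving_default']
images = tf.zeros((1, 112, 112, 3), dtype=tf.float32)
outputs = segmentation_fn(tf.constant(images))
```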