Unverified commit 09d9656f authored by Srihari Humbarwadi, committed by GitHub

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents ac671306 49a5706c
@@ -27,7 +27,6 @@ from official.vision.beta.projects.video_ssl.configs import video_ssl as video_s
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class VideoSSLModel(tf.keras.Model):
"""A video ssl model class builder."""
......
@@ -21,6 +21,7 @@ from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.projects.vit.modeling import nn_blocks
layers = tf.keras.layers
VIT_SPECS = {
@@ -121,6 +122,7 @@ class Encoder(tf.keras.layers.Layer):
inputs_positions=None,
init_stochastic_depth_rate=0.0,
kernel_initializer='glorot_uniform',
add_pos_embed=True,
**kwargs):
super().__init__(**kwargs)
self._num_layers = num_layers
@@ -132,8 +134,10 @@ class Encoder(tf.keras.layers.Layer):
self._inputs_positions = inputs_positions
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._kernel_initializer = kernel_initializer
self._add_pos_embed = add_pos_embed
def build(self, input_shape):
if self._add_pos_embed:
self._pos_embed = AddPositionEmbs(
posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
name='posembed_input')
@@ -160,7 +164,9 @@ class Encoder(tf.keras.layers.Layer):
super().build(input_shape)
def call(self, inputs, training=None):
x = self._pos_embed(inputs, inputs_positions=self._inputs_positions)
x = inputs
if self._add_pos_embed:
x = self._pos_embed(x, inputs_positions=self._inputs_positions)
x = self._dropout(x, training=training)
for encoder_layer in self._encoder_layers:
......
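The new `add_pos_embed` flag makes the position embedding optional, so models that inject positional information elsewhere can skip it. A minimal sketch of the toggle pattern, assuming a plain learned per-token embedding rather than the actual `AddPositionEmbs` layer:

```python
import tensorflow as tf

class ToyEncoder(tf.keras.layers.Layer):
  """Encoder sketch whose position embedding can be disabled."""

  def __init__(self, add_pos_embed=True, **kwargs):
    super().__init__(**kwargs)
    self._add_pos_embed = add_pos_embed

  def build(self, input_shape):
    if self._add_pos_embed:
      # One learned embedding per token position, broadcast over the batch.
      self._pos_embed = self.add_weight(
          name='pos_embed',
          shape=(1, input_shape[1], input_shape[2]),
          initializer=tf.keras.initializers.RandomNormal(stddev=0.02))
    super().build(input_shape)

  def call(self, inputs):
    x = inputs
    if self._add_pos_embed:
      x = x + self._pos_embed
    return x
```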
@@ -371,7 +371,6 @@ BACKBONES = {
}
@tf.keras.utils.register_keras_serializable(package='yolo')
class Darknet(tf.keras.Model):
"""The Darknet backbone architecture."""
......
@@ -84,14 +84,12 @@ YOLO_MODELS = {
}
@tf.keras.utils.register_keras_serializable(package='yolo')
class _IdentityRoute(tf.keras.layers.Layer):
def call(self, inputs): # pylint: disable=arguments-differ
return None, inputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloFPN(tf.keras.layers.Layer):
"""YOLO Feature pyramid network."""
@@ -248,7 +246,6 @@ class YoloFPN(tf.keras.layers.Layer):
return outputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloPAN(tf.keras.layers.Layer):
"""YOLO Path Aggregation Network."""
@@ -441,7 +438,6 @@ class YoloPAN(tf.keras.layers.Layer):
return outputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloDecoder(tf.keras.Model):
"""Darknet Backbone Decoder."""
......
@@ -21,7 +21,6 @@ from official.vision.beta.projects.yolo.ops import box_ops
from official.vision.beta.projects.yolo.ops import loss_utils
@tf.keras.utils.register_keras_serializable(package='yolo')
class YoloLayer(tf.keras.Model):
"""Yolo layer (detection generator)."""
......
@@ -21,14 +21,12 @@ from official.modeling import tf_utils
from official.vision.beta.ops import spatial_transform_ops
@tf.keras.utils.register_keras_serializable(package='yolo')
class Identity(tf.keras.layers.Layer):
def call(self, inputs):
return inputs
@tf.keras.utils.register_keras_serializable(package='yolo')
class ConvBN(tf.keras.layers.Layer):
"""ConvBN block.
@@ -241,7 +239,6 @@ class ConvBN(tf.keras.layers.Layer):
return layer_config
@tf.keras.utils.register_keras_serializable(package='yolo')
class DarkResidual(tf.keras.layers.Layer):
"""Darknet block with Residual connection for Yolo v3 Backbone."""
@@ -406,7 +403,6 @@ class DarkResidual(tf.keras.layers.Layer):
return layer_config
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPTiny(tf.keras.layers.Layer):
"""CSP Tiny layer.
@@ -556,7 +552,6 @@ class CSPTiny(tf.keras.layers.Layer):
return x, x5
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPRoute(tf.keras.layers.Layer):
"""CSPRoute block.
@@ -696,7 +691,6 @@ class CSPRoute(tf.keras.layers.Layer):
return (x, y)
@tf.keras.utils.register_keras_serializable(package='yolo')
class CSPConnect(tf.keras.layers.Layer):
"""CSPConnect block.
@@ -941,7 +935,6 @@ class CSPStack(tf.keras.layers.Layer):
return x
@tf.keras.utils.register_keras_serializable(package='yolo')
class PathAggregationBlock(tf.keras.layers.Layer):
"""Path Aggregation block."""
@@ -1132,7 +1125,6 @@ class PathAggregationBlock(tf.keras.layers.Layer):
return self._call_regular(inputs, training=training)
@tf.keras.utils.register_keras_serializable(package='yolo')
class SPP(tf.keras.layers.Layer):
"""Spatial Pyramid Pooling.
@@ -1411,7 +1403,6 @@ class CBAM(tf.keras.layers.Layer):
return self._sam(self._cam(inputs))
@tf.keras.utils.register_keras_serializable(package='yolo')
class DarkRouteProcess(tf.keras.layers.Layer):
"""Dark Route Process block.
......
@@ -401,7 +401,6 @@ class YoloTask(base_task.Task):
use_float16 = runtime_config.mixed_precision_dtype == 'float16'
optimizer = performance.configure_optimizer(
optimizer,
use_graph_rewrite=False,
use_float16=use_float16,
loss_scale=runtime_config.loss_scale)
......
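With `use_graph_rewrite` dropped, mixed precision is driven by `use_float16` and `loss_scale` alone. A rough sketch of the wrapping that `performance.configure_optimizer` performs in this mode, using stock Keras APIs (the real helper may do more):

```python
import tensorflow as tf

def configure_optimizer_sketch(optimizer, use_float16, loss_scale=None):
  # Wrap the optimizer in a LossScaleOptimizer so float16 training scales
  # the loss and keeps small gradients from underflowing.
  if use_float16:
    if loss_scale in (None, 'dynamic'):
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
    else:
      # A fixed loss scale skips the dynamic adjustment logic.
      optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
          optimizer, dynamic=False, initial_scale=loss_scale)
  return optimizer
```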
@@ -52,6 +52,18 @@ class DetectionModule(export_base.ExportModule):
return model
def _build_anchor_boxes(self):
"""Builds and returns anchor boxes."""
model_params = self.params.task.model
input_anchor = anchor.build_anchor_generator(
min_level=model_params.min_level,
max_level=model_params.max_level,
num_scales=model_params.anchor.num_scales,
aspect_ratios=model_params.anchor.aspect_ratios,
anchor_size=model_params.anchor.anchor_size)
return input_anchor(
image_size=(self._input_image_size[0], self._input_image_size[1]))
def _build_inputs(self, image):
"""Builds detection model inputs for serving."""
model_params = self.params.task.model
@@ -67,15 +79,7 @@ class DetectionModule(export_base.ExportModule):
self._input_image_size, 2**model_params.max_level),
aug_scale_min=1.0,
aug_scale_max=1.0)
input_anchor = anchor.build_anchor_generator(
min_level=model_params.min_level,
max_level=model_params.max_level,
num_scales=model_params.anchor.num_scales,
aspect_ratios=model_params.anchor.aspect_ratios,
anchor_size=model_params.anchor.anchor_size)
anchor_boxes = input_anchor(image_size=(self._input_image_size[0],
self._input_image_size[1]))
anchor_boxes = self._build_anchor_boxes()
return image, anchor_boxes, image_info
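Factoring `_build_anchor_boxes` out lets the regular preprocessing path and the new tflite path share one anchor definition. For orientation, the generator places `num_scales * len(aspect_ratios)` anchors per feature-map location; with common RetinaNet settings (values hypothetical):

```python
num_scales = 3
aspect_ratios = [0.5, 1.0, 2.0]
anchors_per_location = num_scales * len(aspect_ratios)  # 3 * 3 = 9
```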
@@ -133,7 +137,22 @@ class DetectionModule(export_base.ExportModule):
Tensor holding detection output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if self._input_type != 'tflite':
images, anchor_boxes, image_info = self.preprocess(images)
else:
with tf.device('cpu:0'):
anchor_boxes = self._build_anchor_boxes()
# image_info is a 3D tensor of shape [batch_size, 4, 2]. It is in the
# format of [[original_height, original_width],
# [desired_height, desired_width], [y_scale, x_scale],
# [y_offset, x_offset]]. When input_type is tflite, input image is
# supposed to be preprocessed already.
image_info = tf.convert_to_tensor([[
self._input_image_size, self._input_image_size, [1.0, 1.0], [0, 0]
]],
dtype=tf.float32)
input_image_shape = image_info[:, 1, :]
# To overcome keras.Model extra limitation to save a model with layers that
......
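The comment above fixes the `image_info` layout; building the tensor by hand for an already-preprocessed input (sizes hypothetical) yields the expected `[batch_size, 4, 2]` shape:

```python
import tensorflow as tf

input_image_size = [640, 640]  # hypothetical [height, width]
image_info = tf.convert_to_tensor(
    [[input_image_size,   # [original_height, original_width]
      input_image_size,   # [desired_height, desired_width]
      [1.0, 1.0],         # [y_scale, x_scale]
      [0, 0]]],           # [y_offset, x_offset]
    dtype=tf.float32)
print(image_info.shape)  # (1, 4, 2)
```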
@@ -30,12 +30,15 @@ from official.vision.beta.serving import detection
class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_detection_module(self, experiment_name):
def _get_detection_module(self, experiment_name, input_type):
params = exp_factory.get_exp_config(experiment_name)
params.task.model.backbone.resnet.model_id = 18
params.task.model.detection_generator.nms_version = 'batched'
detection_module = detection.DetectionModule(
params, batch_size=1, input_image_size=[640, 640])
params,
batch_size=1,
input_image_size=[640, 640],
input_type=input_type)
return detection_module
def _export_from_module(self, module, input_type, save_directory):
@@ -65,24 +68,30 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example for _ in range(batch_size)]
elif input_type == 'tflite':
return tf.zeros((batch_size, h, w, 3), dtype=np.float32)
@parameterized.parameters(
('image_tensor', 'fasterrcnn_resnetfpn_coco', [384, 384]),
('image_bytes', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('tf_example', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('tflite', 'fasterrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_bytes', 'maskrcnn_resnetfpn_coco', [640, 384]),
('tf_example', 'maskrcnn_resnetfpn_coco', [640, 640]),
('tflite', 'maskrcnn_resnetfpn_coco', [640, 640]),
('image_tensor', 'retinanet_resnetfpn_coco', [640, 640]),
('image_bytes', 'retinanet_resnetfpn_coco', [640, 640]),
('tf_example', 'retinanet_resnetfpn_coco', [384, 640]),
('tflite', 'retinanet_resnetfpn_coco', [640, 640]),
('image_tensor', 'retinanet_resnetfpn_coco', [384, 384]),
('image_bytes', 'retinanet_spinenet_coco', [640, 640]),
('tf_example', 'retinanet_spinenet_coco', [640, 384]),
('tflite', 'retinanet_spinenet_coco', [640, 640]),
)
def test_export(self, input_type, experiment_name, image_size):
tmp_dir = self.get_temp_dir()
module = self._get_detection_module(experiment_name)
module = self._get_detection_module(experiment_name, input_type)
self._export_from_module(module, input_type, tmp_dir)
@@ -100,6 +109,12 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
images = self._get_dummy_input(
input_type, batch_size=1, image_size=image_size)
if input_type == 'tflite':
processed_images = tf.zeros(image_size + [3], dtype=tf.float32)
anchor_boxes = module._build_anchor_boxes()
image_info = tf.convert_to_tensor(
[image_size, image_size, [1.0, 1.0], [0, 0]], dtype=tf.float32)
else:
processed_images, anchor_boxes, image_info = module._build_inputs(
tf.zeros((224, 224, 3), dtype=tf.uint8))
image_shape = image_info[1, :]
......
@@ -31,6 +31,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
*,
batch_size: int,
input_image_size: List[int],
input_type: str = 'image_tensor',
num_channels: int = 3,
model: Optional[tf.keras.Model] = None):
"""Initializes a module for export.
@@ -40,6 +41,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
batch_size: The batch size of the model input. Can be `int` or None.
input_image_size: List or Tuple of size of the input image. For 2D image,
it is [height, width].
input_type: The input signature type.
num_channels: The number of the image channels.
model: A tf.keras.Model instance to be exported.
"""
@@ -47,6 +49,7 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
self._batch_size = batch_size
self._input_image_size = input_image_size
self._num_channels = num_channels
self._input_type = input_type
if model is None:
model = self._build_model() # pylint: disable=assignment-from-none
super().__init__(params=params, model=model)
......
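With `input_type` stored on the base class, every subclass can branch on it in `serve()`. Constructing a module for TFLite-style inputs then looks like the detection test above; a sketch reusing the same `exp_factory` and `detection` modules the tests import (experiment name and sizes hypothetical):

```python
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
module = detection.DetectionModule(
    params=params,
    batch_size=1,
    input_image_size=[640, 640],
    input_type='tflite')
```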
@@ -89,6 +89,7 @@ def export_inference_graph(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
params.task, configs.maskrcnn.MaskRCNNTask):
@@ -96,6 +97,7 @@ def export_inference_graph(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
@@ -103,6 +105,7 @@ def export_inference_graph(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
elif isinstance(params.task,
configs.video_classification.VideoClassificationTask):
@@ -110,6 +113,7 @@ def export_inference_graph(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
input_type=input_type,
num_channels=num_channels)
else:
raise ValueError('Export module not implemented for {} task.'.format(
@@ -130,7 +134,9 @@ def export_inference_graph(
if log_model_flops_and_params:
inputs_kwargs = None
if isinstance(params.task, configs.retinanet.RetinaNetTask):
if isinstance(
params.task,
(configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
      # We need to create the inputs_kwargs argument to specify the input
      # shapes for a subclassed model that overrides model.call to take
      # multiple inputs, e.g., the RetinaNet model.
......
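An end-to-end call mirroring the tests below, with the new `input_type` plumbed through (paths hypothetical):

```python
export_saved_model_lib.export_inference_graph(
    input_type='tflite',
    batch_size=1,
    input_image_size=[640, 640],
    params=params,
    checkpoint_path='/path/to/ckpt',  # hypothetical
    export_dir='/tmp/export',         # hypothetical
    log_model_flops_and_params=True)
```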
@@ -26,26 +26,43 @@ from official.vision.beta.serving import export_saved_model_lib
class WriteModelFlopsAndParamsTest(tf.test.TestCase):
@mock.patch.object(export_base, 'export', autospec=True, spec_set=True)
def test_retinanet_task(self, unused_export):
tempdir = self.create_tempdir()
params = configs.retinanet.retinanet_resnetfpn_coco()
print(params.task.model.backbone)
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
def setUp(self):
super().setUp()
self.tempdir = self.create_tempdir()
self.enter_context(
mock.patch.object(export_base, 'export', autospec=True, spec_set=True))
def _export_model_with_log_model_flops_and_params(self, params):
export_saved_model_lib.export_inference_graph(
input_type='image_tensor',
batch_size=1,
input_image_size=[64, 64],
params=params,
checkpoint_path=os.path.join(tempdir, 'unused-ckpt'),
export_dir=tempdir,
checkpoint_path=os.path.join(self.tempdir, 'unused-ckpt'),
export_dir=self.tempdir,
log_model_flops_and_params=True)
def assertModelAnalysisFilesExist(self):
self.assertTrue(
tf.io.gfile.exists(os.path.join(tempdir, 'model_params.txt')))
tf.io.gfile.exists(os.path.join(self.tempdir, 'model_params.txt')))
self.assertTrue(
tf.io.gfile.exists(os.path.join(tempdir, 'model_flops.txt')))
tf.io.gfile.exists(os.path.join(self.tempdir, 'model_flops.txt')))
def test_retinanet_task(self):
params = configs.retinanet.retinanet_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
def test_maskrcnn_task(self):
params = configs.maskrcnn.maskrcnn_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
if __name__ == '__main__':
......
@@ -43,6 +43,8 @@ def create_representative_dataset(
task = tasks.image_classification.ImageClassificationTask(params.task)
elif isinstance(params.task, configs.retinanet.RetinaNetTask):
task = tasks.retinanet.RetinaNetTask(params.task)
elif isinstance(params.task, configs.maskrcnn.MaskRCNNTask):
task = tasks.maskrcnn.MaskRCNNTask(params.task)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
......
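A representative dataset supplies sample inputs for int8 calibration. A minimal sketch of how such a dataset plugs into the TFLite converter, assuming a hypothetical `calibration_data` tf.data pipeline and export path:

```python
import tensorflow as tf

def representative_dataset():
  # Yield a modest number of preprocessed float32 batches for calibration.
  for images, _ in calibration_data.take(100):  # hypothetical pipeline
    yield [tf.cast(images, tf.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/export')  # hypothetical
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
tflite_model = converter.convert()
```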
@@ -30,18 +30,10 @@ from official.vision.beta.serving import semantic_segmentation as semantic_segme
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self._test_tfrecord_file = os.path.join(self.get_temp_dir(),
'test.tfrecord')
self._create_test_tfrecord(num_samples=50)
def _create_test_tfrecord(self, num_samples):
tfexample_utils.dump_to_tfrecord(self._test_tfrecord_file, [
tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=256, image_width=256)) for _ in range(num_samples)
])
def _create_test_tfrecord(self, tfrecord_file, example, num_samples):
examples = [example] * num_samples
tfexample_utils.dump_to_tfrecord(
record_file=tfrecord_file, tf_examples=examples)
def _export_from_module(self, module, input_type, saved_model_dir):
signatures = module.get_inference_signatures(
@@ -51,16 +43,25 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
input_image_size=[[224, 224]]))
def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
example = tf.train.Example.FromString(
tfexample_utils.create_classification_example(
image_height=input_image_size[0], image_width=input_image_size[1]))
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = self._test_tfrecord_file
params.task.train_data.input_path = self._test_tfrecord_file
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = image_classification_serving.ClassificationModule(
params=params, batch_size=1, input_image_size=input_image_size)
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
@@ -78,13 +79,26 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
combinations.combine(
experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[256, 256]]))
input_image_size=[[384, 384]]))
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'det_test.tfrecord')
example = tfexample_utils.create_detection_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3,
num_instances=10)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule(
params=params, batch_size=1, input_image_size=input_image_size)
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
@@ -100,15 +114,27 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
experiment=['seg_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16'],
experiment=['mnv2_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16', 'int8', 'int8_full'],
input_image_size=[[512, 512]]))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'seg_test.tfrecord')
example = tfexample_utils.create_segmentation_test_example(
image_height=input_image_size[0],
image_width=input_image_size[1],
image_channel=3)
self._create_test_tfrecord(
tfrecord_file=test_tfrecord_file, example=example, num_samples=10)
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = test_tfrecord_file
params.task.train_data.input_path = test_tfrecord_file
temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule(
params=params, batch_size=1, input_image_size=input_image_size)
params=params,
batch_size=1,
input_image_size=input_image_size,
input_type='tflite')
self._export_from_module(
module=module,
input_type='tflite',
......
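The matrix now exercises `int8_full` as well. A hedged guess at how the quantization modes map onto `TFLiteConverter` settings (the actual `export_tflite_lib` helper is not shown in this diff; `saved_model_dir`, `quant_type`, and `representative_dataset` are assumed names):

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
if quant_type == 'default':
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
elif quant_type == 'fp16':
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converter.target_spec.supported_types = [tf.float16]
elif quant_type in ('int8', 'int8_full'):
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  converter.representative_dataset = representative_dataset
  if quant_type == 'int8_full':
    # Integer-only kernels with integer input/output tensors.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
tflite_model = converter.convert()
```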
@@ -63,18 +63,20 @@ class ClassificationModule(export_base.ExportModule):
Returns:
Tensor holding classification output logits.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if self._input_type != 'tflite':
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs, elems=images,
self._build_inputs,
elems=images,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32),
parallel_iterations=32
)
)
parallel_iterations=32))
logits = self.inference_step(images)
probs = tf.nn.softmax(logits)
......
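The reflowed preprocessing leans on `tf.map_fn` with an explicit `fn_output_signature`, which is required here because the mapped function changes dtype (uint8 in, float32 out). A standalone illustration of the pattern:

```python
import tensorflow as tf

images = tf.zeros((2, 224, 224, 3), dtype=tf.uint8)
processed = tf.map_fn(
    lambda img: tf.image.resize(tf.cast(img, tf.float32), (224, 224)),
    elems=images,
    fn_output_signature=tf.TensorSpec(shape=(224, 224, 3), dtype=tf.float32),
    parallel_iterations=32)
print(processed.shape)  # (2, 224, 224, 3)
```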
@@ -30,11 +30,14 @@ from official.vision.beta.serving import image_classification
class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_classification_module(self):
def _get_classification_module(self, input_type):
params = exp_factory.get_exp_config('resnet_imagenet')
params.task.model.backbone.resnet.model_id = 18
classification_module = image_classification.ClassificationModule(
params, batch_size=1, input_image_size=[224, 224])
params,
batch_size=1,
input_image_size=[224, 224],
input_type=input_type)
return classification_module
def _export_from_module(self, module, input_type, save_directory):
@@ -65,15 +68,18 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example]
elif input_type == 'tflite':
return tf.zeros((1, 224, 224, 3), dtype=np.float32)
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
{'input_type': 'tflite'},
)
def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir()
module = self._get_classification_module()
module = self._get_classification_module(input_type)
# Test that the model restores any attrs that are trackable objects
# (eg: tables, resource variables, keras models/layers, tf.hub modules).
module.model.test_trackable = tf.keras.layers.InputLayer(input_shape=(4,))
@@ -90,6 +96,7 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
classification_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
if input_type != 'tflite':
processed_images = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
@@ -97,6 +104,8 @@ class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=[224, 224, 3], dtype=tf.float32)))
else:
processed_images = images
expected_logits = module.model(processed_images, training=False)
expected_prob = tf.nn.softmax(expected_logits)
out = classification_fn(tf.constant(images))
......
@@ -62,20 +62,23 @@ class SegmentationModule(export_base.ExportModule):
Returns:
      Dictionary holding the segmentation outputs, with logits resized to the input image size.
"""
# Skip image preprocessing when input_type is tflite so it is compatible
# with TFLite quantization.
if self._input_type != 'tflite':
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs, elems=images,
self._build_inputs,
elems=images,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32),
parallel_iterations=32
)
)
parallel_iterations=32))
masks = self.inference_step(images)
masks = tf.image.resize(masks, self._input_image_size, method='bilinear')
outputs = self.inference_step(images)
outputs['logits'] = tf.image.resize(
outputs['logits'], self._input_image_size, method='bilinear')
return dict(predicted_masks=masks)
return outputs
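Returning the raw `outputs` dict (with `outputs['logits']` upsampled to the input resolution) leaves mask extraction to the caller; a sketch of the typical follow-up:

```python
logits = outputs['logits']          # [batch, height, width, num_classes]
masks = tf.argmax(logits, axis=-1)  # [batch, height, width] class ids
```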
@@ -30,11 +30,13 @@ from official.vision.beta.serving import semantic_segmentation
class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_segmentation_module(self):
params = exp_factory.get_exp_config('seg_deeplabv3_pascal')
params.task.model.backbone.dilated_resnet.model_id = 50
def _get_segmentation_module(self, input_type):
params = exp_factory.get_exp_config('mnv2_deeplabv3_pascal')
segmentation_module = semantic_segmentation.SegmentationModule(
params, batch_size=1, input_image_size=[112, 112])
params,
batch_size=1,
input_image_size=[112, 112],
input_type=input_type)
return segmentation_module
def _export_from_module(self, module, input_type, save_directory):
@@ -63,15 +65,18 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example]
elif input_type == 'tflite':
return tf.zeros((1, 112, 112, 3), dtype=np.float32)
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
{'input_type': 'tflite'},
)
def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir()
module = self._get_segmentation_module()
module = self._get_segmentation_module(input_type)
self._export_from_module(module, input_type, tmp_dir)
@@ -87,6 +92,7 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
segmentation_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
if input_type != 'tflite':
processed_images = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
@@ -94,11 +100,13 @@ class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
elems=tf.zeros((1, 112, 112, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=[112, 112, 3], dtype=tf.float32)))
else:
processed_images = images
expected_output = tf.image.resize(
module.model(processed_images, training=False), [112, 112],
module.model(processed_images, training=False)['logits'], [112, 112],
method='bilinear')
out = segmentation_fn(tf.constant(images))
self.assertAllClose(out['predicted_masks'].numpy(), expected_output.numpy())
self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
if __name__ == '__main__':
......
@@ -275,7 +275,9 @@ class MaskRCNNTask(base_task.Task):
self._task_config.validation_data.input_path,
self._task_config.validation_data.file_type,
self._task_config.validation_data.num_examples,
self.task_config.model.include_mask, annotation_path)
self.task_config.model.include_mask, annotation_path,
regenerate_source_id=self._task_config.validation_data.decoder
.simple_decoder.regenerate_source_id)
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=annotation_path,
include_mask=self._task_config.model.include_mask,
......
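The eval-data helper now receives `regenerate_source_id` from the decoder config, keeping COCO evaluation consistent with a decoder that synthesizes ids when records lack a stable `image/source_id`. The flag lives on the config, along the lines of (override hypothetical, field path taken from the call above):

```python
params.task.validation_data.decoder.simple_decoder.regenerate_source_id = True
```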
@@ -119,6 +119,7 @@ class RetinaNetTask(base_task.Task):
dtype=params.dtype,
match_threshold=params.parser.match_threshold,
unmatched_threshold=params.parser.unmatched_threshold,
aug_type=params.parser.aug_type,
aug_rand_hflip=params.parser.aug_rand_hflip,
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
......
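Passing `params.parser.aug_type` through lets the RetinaNet input pipeline pick an augmentation policy from config. A hedged sketch of such an override (the exact schema of `aug_type` is not shown in this diff):

```python
# Hypothetical: select a named augmentation policy on the parser config.
params.task.train_data.parser.aug_type.type = 'randaug'
```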