Unverified Commit e46a2497 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

add option to generate panoptic masks

parent 3d09f146
...@@ -222,15 +222,20 @@ def build_panoptic_deeplab( ...@@ -222,15 +222,20 @@ def build_panoptic_deeplab(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
if model_config.generate_panoptic_masks:
post_processing_config = model_config.post_processor post_processing_config = model_config.post_processor
post_processor = panoptic_deeplab_merge.PostProcessor( post_processor = panoptic_deeplab_merge.PostProcessor(
output_size=post_processing_config.output_size,
center_score_threshold=post_processing_config.center_score_threshold, center_score_threshold=post_processing_config.center_score_threshold,
thing_class_ids=post_processing_config.thing_class_ids, thing_class_ids=post_processing_config.thing_class_ids,
label_divisor=post_processing_config.label_divisor, label_divisor=post_processing_config.label_divisor,
stuff_area_limit=post_processing_config.stuff_area_limit, stuff_area_limit=post_processing_config.stuff_area_limit,
ignore_label=post_processing_config.ignore_label, ignore_label=post_processing_config.ignore_label,
nms_kernel=post_processing_config.nms_kernel, nms_kernel=post_processing_config.nms_kernel,
keep_k_centers=post_processing_config.keep_k_centers) keep_k_centers=post_processing_config.keep_k_centers,
rescale_predictions=post_processing_config.rescale_predictions)
else:
post_processor = None
model = panoptic_deeplab_model.PanopticDeeplabModel( model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone, backbone=backbone,
......
...@@ -73,9 +73,11 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -73,9 +73,11 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
decoder_type=['aspp', 'fpn'], decoder_type=['aspp', 'fpn'],
level=[2, 3, 4], level=[2, 3, 4],
low_level=[(4, 3), (3, 2)], low_level=[(4, 3), (3, 2)],
shared_decoder=[True, False])) shared_decoder=[True, False],
def test_builder(self, input_size, backbone_type, level, generate_panoptic_masks=[True, False]))
low_level, decoder_type, shared_decoder): def test_builder(self, input_size, backbone_type,
level, low_level, decoder_type,
shared_decoder, generate_panoptic_masks):
num_classes = 10 num_classes = 10
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3]) shape=[None, input_size[0], input_size[1], 3])
...@@ -97,7 +99,8 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -97,7 +99,8 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
kernel_size=5, kernel_size=5,
prediction_kernel_size=1, prediction_kernel_size=1,
low_level=low_level), low_level=low_level),
shared_decoder=shared_decoder) shared_decoder=shared_decoder,
generate_panoptic_masks=generate_panoptic_masks)
l2_regularizer = tf.keras.regularizers.l2(5e-5) l2_regularizer = tf.keras.regularizers.l2(5e-5)
_ = factory.build_panoptic_deeplab( _ = factory.build_panoptic_deeplab(
......
...@@ -57,7 +57,10 @@ class PanopticDeeplabModel(tf.keras.Model): ...@@ -57,7 +57,10 @@ class PanopticDeeplabModel(tf.keras.Model):
self.instance_head = instance_head self.instance_head = instance_head
self.post_processor = post_processor self.post_processor = post_processor
def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor: def call(
self, inputs: tf.Tensor,
image_info: tf.Tensor,
training: bool = None) -> tf.Tensor:
if training is None: if training is None:
training = tf.keras.backend.learning_phase() training = tf.keras.backend.learning_phase()
...@@ -81,15 +84,18 @@ class PanopticDeeplabModel(tf.keras.Model): ...@@ -81,15 +84,18 @@ class PanopticDeeplabModel(tf.keras.Model):
outputs = { outputs = {
'segmentation_outputs': segmentation_outputs, 'segmentation_outputs': segmentation_outputs,
'instance_center_prediction': 'instance_centers_heatmap':
instance_outputs['instance_center_prediction'], instance_outputs['instance_centers_heatmap'],
'instance_center_regression': 'instance_centers_offset':
instance_outputs['instance_center_regression'], instance_outputs['instance_centers_offset'],
} }
if training: if training:
return outputs return outputs
outputs = self.post_processor(outputs) if self.post_processor is not None:
panoptic_masks = self.post_processor(outputs, image_info)
outputs.update(panoptic_masks)
return outputs return outputs
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment