Unverified Commit 4ace44be authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added post processing layer

parent 2d739bb8
......@@ -26,6 +26,7 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_de
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_deeplab_merge
def build_panoptic_maskrcnn(
......@@ -220,11 +221,22 @@ def build_panoptic_deeplab(
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
post_processing_config = model_config.post_processor
post_processor = panoptic_deeplab_merge.PostProcessor(
center_score_threshold=post_processing_config.center_score_threshold,
thing_class_ids=post_processing_config.thing_class_ids,
label_divisor=post_processing_config.label_divisor,
stuff_area_limit=post_processing_config.stuff_area_limit,
ignore_label=post_processing_config.ignore_label,
nms_kernel=post_processing_config.nms_kernel,
keep_k_centers=post_processing_config.keep_k_centers)
model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone,
semantic_decoder=semantic_decoder,
instance_decoder=instance_decoder,
semantic_head=semantic_head,
instance_head=instance_head)
instance_head=instance_head,
post_processor=post_processor)
return model
......@@ -16,7 +16,7 @@
from typing import Any, Mapping, Optional, Union
import tensorflow as tf
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_deeplab_merge
@tf.keras.utils.register_keras_serializable(package='Vision')
class PanopticDeeplabModel(tf.keras.Model):
......@@ -29,6 +29,7 @@ class PanopticDeeplabModel(tf.keras.Model):
semantic_head: tf.keras.layers.Layer,
instance_head: tf.keras.layers.Layer,
instance_decoder: Optional[tf.keras.Model] = None,
post_processor: Optional[panoptic_deeplab_merge.PostProcessor] = None,
**kwargs):
"""
Args:
......@@ -46,13 +47,15 @@ class PanopticDeeplabModel(tf.keras.Model):
'semantic_decoder': semantic_decoder,
'instance_decoder': instance_decoder,
'semantic_head': semantic_head,
'instance_head': instance_head
'instance_head': instance_head,
'post_processor': post_processor
}
self.backbone = backbone
self.semantic_decoder = semantic_decoder
self.instance_decoder = instance_decoder
self.semantic_head = semantic_head
self.instance_head = instance_head
self.post_processor = post_processor
def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
if training is None:
......@@ -83,6 +86,10 @@ class PanopticDeeplabModel(tf.keras.Model):
'instance_center_regression':
instance_outputs['instance_center_regression'],
}
if training:
return outputs
outputs = self.post_processor(outputs)
return outputs
@property
......
......@@ -24,6 +24,7 @@ from official.vision.beta.modeling import backbones
from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_deeplab_merge
class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
......@@ -37,8 +38,9 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
def test_panoptic_deeplab_network_creation(
self, input_size, level, low_level, shared_decoder, training):
"""Test for creation of a panoptic deep lab network."""
batch_size = 2 if training else 1
num_classes = 10
inputs = np.random.rand(2, input_size, input_size, 3)
inputs = np.random.rand(batch_size, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.ResNet(model_id=50)
......@@ -62,16 +64,26 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
low_level=low_level,
low_level_num_filters=(64, 32))
post_processor = panoptic_deeplab_merge.PostProcessor(
center_score_threshold=0.1,
thing_class_ids=[1, 2, 3, 4],
label_divisor=[256],
stuff_area_limit=4096,
ignore_label=0,
nms_kernel=41,
keep_k_centers=41)
model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone,
semantic_decoder=semantic_decoder,
instance_decoder=instance_decoder,
semantic_head=semantic_head,
instance_head=instance_head)
instance_head=instance_head,
post_processor=post_processor)
outputs = model(inputs, training=training)
if training:
self.assertIn('segmentation_outputs', outputs)
self.assertIn('instance_center_prediction', outputs)
self.assertIn('instance_center_regression', outputs)
......@@ -92,6 +104,13 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
2],
outputs['instance_center_regression'].numpy().shape)
else:
self.assertIn('panoptic_outputs', outputs)
self.assertIn('category_mask', outputs)
self.assertIn('instance_mask', outputs)
self.assertIn('instance_centers', outputs)
self.assertIn('instance_scores', outputs)
@combinations.generate(
combinations.combine(
level=[2, 3, 4],
......@@ -122,12 +141,22 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
low_level=low_level,
low_level_num_filters=(64, 32))
post_processor = panoptic_deeplab_merge.PostProcessor(
center_score_threshold=0.1,
thing_class_ids=[1, 2, 3, 4],
label_divisor=[256],
stuff_area_limit=4096,
ignore_label=0,
nms_kernel=41,
keep_k_centers=41)
model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone,
semantic_decoder=semantic_decoder,
instance_decoder=instance_decoder,
semantic_head=semantic_head,
instance_head=instance_head)
instance_head=instance_head,
post_processor=post_processor)
config = model.get_config()
new_model = panoptic_deeplab_model.PanopticDeeplabModel.from_config(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment