Unverified Commit 4ace44be authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added post processing layer

parent 2d739bb8
...@@ -26,6 +26,7 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_de ...@@ -26,6 +26,7 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_de
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_deeplab_merge
def build_panoptic_maskrcnn( def build_panoptic_maskrcnn(
...@@ -220,11 +221,22 @@ def build_panoptic_deeplab( ...@@ -220,11 +221,22 @@ def build_panoptic_deeplab(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
post_processing_config = model_config.post_processor
post_processor = panoptic_deeplab_merge.PostProcessor(
center_score_threshold=post_processing_config.center_score_threshold,
thing_class_ids=post_processing_config.thing_class_ids,
label_divisor=post_processing_config.label_divisor,
stuff_area_limit=post_processing_config.stuff_area_limit,
ignore_label=post_processing_config.ignore_label,
nms_kernel=post_processing_config.nms_kernel,
keep_k_centers=post_processing_config.keep_k_centers)
model = panoptic_deeplab_model.PanopticDeeplabModel( model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone, backbone=backbone,
semantic_decoder=semantic_decoder, semantic_decoder=semantic_decoder,
instance_decoder=instance_decoder, instance_decoder=instance_decoder,
semantic_head=semantic_head, semantic_head=semantic_head,
instance_head=instance_head) instance_head=instance_head,
post_processor=post_processor)
return model return model
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from typing import Any, Mapping, Optional, Union from typing import Any, Mapping, Optional, Union
import tensorflow as tf import tensorflow as tf
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_deeplab_merge
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
class PanopticDeeplabModel(tf.keras.Model): class PanopticDeeplabModel(tf.keras.Model):
...@@ -29,6 +29,7 @@ class PanopticDeeplabModel(tf.keras.Model): ...@@ -29,6 +29,7 @@ class PanopticDeeplabModel(tf.keras.Model):
semantic_head: tf.keras.layers.Layer, semantic_head: tf.keras.layers.Layer,
instance_head: tf.keras.layers.Layer, instance_head: tf.keras.layers.Layer,
instance_decoder: Optional[tf.keras.Model] = None, instance_decoder: Optional[tf.keras.Model] = None,
post_processor: Optional[panoptic_deeplab_merge.PostProcessor] = None,
**kwargs): **kwargs):
""" """
Args: Args:
...@@ -46,13 +47,15 @@ class PanopticDeeplabModel(tf.keras.Model): ...@@ -46,13 +47,15 @@ class PanopticDeeplabModel(tf.keras.Model):
'semantic_decoder': semantic_decoder, 'semantic_decoder': semantic_decoder,
'instance_decoder': instance_decoder, 'instance_decoder': instance_decoder,
'semantic_head': semantic_head, 'semantic_head': semantic_head,
'instance_head': instance_head 'instance_head': instance_head,
'post_processor': post_processor
} }
self.backbone = backbone self.backbone = backbone
self.semantic_decoder = semantic_decoder self.semantic_decoder = semantic_decoder
self.instance_decoder = instance_decoder self.instance_decoder = instance_decoder
self.semantic_head = semantic_head self.semantic_head = semantic_head
self.instance_head = instance_head self.instance_head = instance_head
self.post_processor = post_processor
def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor: def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
if training is None: if training is None:
...@@ -83,6 +86,10 @@ class PanopticDeeplabModel(tf.keras.Model): ...@@ -83,6 +86,10 @@ class PanopticDeeplabModel(tf.keras.Model):
'instance_center_regression': 'instance_center_regression':
instance_outputs['instance_center_regression'], instance_outputs['instance_center_regression'],
} }
if training:
return outputs
outputs = self.post_processor(outputs)
return outputs return outputs
@property @property
......
...@@ -24,6 +24,7 @@ from official.vision.beta.modeling import backbones ...@@ -24,6 +24,7 @@ from official.vision.beta.modeling import backbones
from official.vision.beta.modeling.decoders import aspp from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_deeplab_merge
class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase): class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
...@@ -37,8 +38,9 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -37,8 +38,9 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
def test_panoptic_deeplab_network_creation( def test_panoptic_deeplab_network_creation(
self, input_size, level, low_level, shared_decoder, training): self, input_size, level, low_level, shared_decoder, training):
"""Test for creation of a panoptic deep lab network.""" """Test for creation of a panoptic deep lab network."""
batch_size = 2 if training else 1
num_classes = 10 num_classes = 10
inputs = np.random.rand(2, input_size, input_size, 3) inputs = np.random.rand(batch_size, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.ResNet(model_id=50) backbone = backbones.ResNet(model_id=50)
...@@ -62,16 +64,26 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -62,16 +64,26 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
low_level=low_level, low_level=low_level,
low_level_num_filters=(64, 32)) low_level_num_filters=(64, 32))
post_processor = panoptic_deeplab_merge.PostProcessor(
center_score_threshold=0.1,
thing_class_ids=[1, 2, 3, 4],
label_divisor=[256],
stuff_area_limit=4096,
ignore_label=0,
nms_kernel=41,
keep_k_centers=41)
model = panoptic_deeplab_model.PanopticDeeplabModel( model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone, backbone=backbone,
semantic_decoder=semantic_decoder, semantic_decoder=semantic_decoder,
instance_decoder=instance_decoder, instance_decoder=instance_decoder,
semantic_head=semantic_head, semantic_head=semantic_head,
instance_head=instance_head) instance_head=instance_head,
post_processor=post_processor)
outputs = model(inputs, training=training) outputs = model(inputs, training=training)
if training:
self.assertIn('segmentation_outputs', outputs) self.assertIn('segmentation_outputs', outputs)
self.assertIn('instance_center_prediction', outputs) self.assertIn('instance_center_prediction', outputs)
self.assertIn('instance_center_regression', outputs) self.assertIn('instance_center_regression', outputs)
...@@ -92,6 +104,13 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -92,6 +104,13 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
2], 2],
outputs['instance_center_regression'].numpy().shape) outputs['instance_center_regression'].numpy().shape)
else:
self.assertIn('panoptic_outputs', outputs)
self.assertIn('category_mask', outputs)
self.assertIn('instance_mask', outputs)
self.assertIn('instance_centers', outputs)
self.assertIn('instance_scores', outputs)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
level=[2, 3, 4], level=[2, 3, 4],
...@@ -122,12 +141,22 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -122,12 +141,22 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
low_level=low_level, low_level=low_level,
low_level_num_filters=(64, 32)) low_level_num_filters=(64, 32))
post_processor = panoptic_deeplab_merge.PostProcessor(
center_score_threshold=0.1,
thing_class_ids=[1, 2, 3, 4],
label_divisor=[256],
stuff_area_limit=4096,
ignore_label=0,
nms_kernel=41,
keep_k_centers=41)
model = panoptic_deeplab_model.PanopticDeeplabModel( model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone, backbone=backbone,
semantic_decoder=semantic_decoder, semantic_decoder=semantic_decoder,
instance_decoder=instance_decoder, instance_decoder=instance_decoder,
semantic_head=semantic_head, semantic_head=semantic_head,
instance_head=instance_head) instance_head=instance_head,
post_processor=post_processor)
config = model.get_config() config = model.get_config()
new_model = panoptic_deeplab_model.PanopticDeeplabModel.from_config(config) new_model = panoptic_deeplab_model.PanopticDeeplabModel.from_config(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment