Unverified Commit e46a2497 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

add option to generate panoptic masks

parent 3d09f146
......@@ -222,15 +222,20 @@ def build_panoptic_deeplab(
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
post_processing_config = model_config.post_processor
post_processor = panoptic_deeplab_merge.PostProcessor(
center_score_threshold=post_processing_config.center_score_threshold,
thing_class_ids=post_processing_config.thing_class_ids,
label_divisor=post_processing_config.label_divisor,
stuff_area_limit=post_processing_config.stuff_area_limit,
ignore_label=post_processing_config.ignore_label,
nms_kernel=post_processing_config.nms_kernel,
keep_k_centers=post_processing_config.keep_k_centers)
if model_config.generate_panoptic_masks:
post_processing_config = model_config.post_processor
post_processor = panoptic_deeplab_merge.PostProcessor(
output_size=post_processing_config.output_size,
center_score_threshold=post_processing_config.center_score_threshold,
thing_class_ids=post_processing_config.thing_class_ids,
label_divisor=post_processing_config.label_divisor,
stuff_area_limit=post_processing_config.stuff_area_limit,
ignore_label=post_processing_config.ignore_label,
nms_kernel=post_processing_config.nms_kernel,
keep_k_centers=post_processing_config.keep_k_centers,
rescale_predictions=post_processing_config.rescale_predictions)
else:
post_processor = None
model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone,
......
......@@ -73,9 +73,11 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
decoder_type=['aspp', 'fpn'],
level=[2, 3, 4],
low_level=[(4, 3), (3, 2)],
shared_decoder=[True, False]))
def test_builder(self, input_size, backbone_type, level,
low_level, decoder_type, shared_decoder):
shared_decoder=[True, False],
generate_panoptic_masks=[True, False]))
def test_builder(self, input_size, backbone_type,
level, low_level, decoder_type,
shared_decoder, generate_panoptic_masks):
num_classes = 10
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3])
......@@ -97,7 +99,8 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
kernel_size=5,
prediction_kernel_size=1,
low_level=low_level),
shared_decoder=shared_decoder)
shared_decoder=shared_decoder,
generate_panoptic_masks=generate_panoptic_masks)
l2_regularizer = tf.keras.regularizers.l2(5e-5)
_ = factory.build_panoptic_deeplab(
......
......@@ -57,7 +57,10 @@ class PanopticDeeplabModel(tf.keras.Model):
self.instance_head = instance_head
self.post_processor = post_processor
def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
def call(
self, inputs: tf.Tensor,
image_info: tf.Tensor,
training: bool = None) -> tf.Tensor:
if training is None:
training = tf.keras.backend.learning_phase()
......@@ -81,15 +84,18 @@ class PanopticDeeplabModel(tf.keras.Model):
outputs = {
'segmentation_outputs': segmentation_outputs,
'instance_center_prediction':
instance_outputs['instance_center_prediction'],
'instance_center_regression':
instance_outputs['instance_center_regression'],
'instance_centers_heatmap':
instance_outputs['instance_centers_heatmap'],
'instance_centers_offset':
instance_outputs['instance_centers_offset'],
}
if training:
return outputs
outputs = self.post_processor(outputs)
if self.post_processor is not None:
panoptic_masks = self.post_processor(outputs, image_info)
outputs.update(panoptic_masks)
return outputs
@property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment