Unverified Commit 609a50b0 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

support building segmenation head with `panoptic_fpn_fusion`

parent c18fc1bb
...@@ -69,8 +69,10 @@ def build_panoptic_maskrcnn( ...@@ -69,8 +69,10 @@ def build_panoptic_maskrcnn(
input_specs=segmentation_decoder_input_specs, input_specs=segmentation_decoder_input_specs,
model_config=segmentation_config, model_config=segmentation_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
decoder_config = segmentation_decoder.get_config()
else: else:
segmentation_decoder = None segmentation_decoder = None
decoder_config = maskrcnn_model.decoder.get_config()
segmentation_head_config = segmentation_config.head segmentation_head_config = segmentation_config.head
detection_head_config = model_config.detection_head detection_head_config = model_config.detection_head
...@@ -84,12 +86,15 @@ def build_panoptic_maskrcnn( ...@@ -84,12 +86,15 @@ def build_panoptic_maskrcnn(
num_filters=segmentation_head_config.num_filters, num_filters=segmentation_head_config.num_filters,
upsample_factor=segmentation_head_config.upsample_factor, upsample_factor=segmentation_head_config.upsample_factor,
feature_fusion=segmentation_head_config.feature_fusion, feature_fusion=segmentation_head_config.feature_fusion,
decoder_min_level=segmentation_head_config.decoder_min_level,
decoder_max_level=segmentation_head_config.decoder_max_level,
low_level=segmentation_head_config.low_level, low_level=segmentation_head_config.low_level,
low_level_num_filters=segmentation_head_config.low_level_num_filters, low_level_num_filters=segmentation_head_config.low_level_num_filters,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
num_decoder_filters=decoder_config['num_filters'],
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
if model_config.generate_panoptic_masks: if model_config.generate_panoptic_masks:
......
...@@ -27,19 +27,18 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory ...@@ -27,19 +27,18 @@ from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory
class PanopticMaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase): class PanopticMaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
('resnet', (640, 640), 'dilated_resnet', 'fpn'), ('resnet', (640, 640), 'dilated_resnet', 'fpn', 'panoptic_fpn_fusion'),
('resnet', (640, 640), 'dilated_resnet', 'aspp'), ('resnet', (640, 640), 'dilated_resnet', 'aspp', 'deeplabv3plus'),
('resnet', (640, 640), None, 'fpn'), ('resnet', (640, 640), None, 'fpn', 'panoptic_fpn_fusion'),
('resnet', (640, 640), None, 'aspp'), ('resnet', (640, 640), None, 'aspp', 'deeplabv3plus'),
('resnet', (640, 640), None, None), ('resnet', (640, 640), None, None, 'panoptic_fpn_fusion'),
('resnet', (None, None), 'dilated_resnet', 'fpn'), ('resnet', (None, None), 'dilated_resnet', 'fpn', 'panoptic_fpn_fusion'),
('resnet', (None, None), 'dilated_resnet', 'aspp'), ('resnet', (None, None), 'dilated_resnet', 'aspp', 'deeplabv3plus'),
('resnet', (None, None), None, 'fpn'), ('resnet', (None, None), None, 'fpn', 'panoptic_fpn_fusion'),
('resnet', (None, None), None, 'aspp'), ('resnet', (None, None), None, 'aspp', 'deeplabv3plus'),
('resnet', (None, None), None, None) ('resnet', (None, None), None, None, 'deeplabv3plus'))
)
def test_builder(self, backbone_type, input_size, segmentation_backbone_type, def test_builder(self, backbone_type, input_size, segmentation_backbone_type,
segmentation_decoder_type): segmentation_decoder_type, fusion_type):
num_classes = 2 num_classes = 2
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3]) shape=[None, input_size[0], input_size[1], 3])
......
...@@ -53,17 +53,16 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -53,17 +53,16 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
shared_decoder, shared_decoder,
is_training=True): is_training=True):
num_classes = 3 num_classes = 3
min_level = 3 min_level = 2
max_level = 7 max_level = 6
num_scales = 3 num_scales = 3
aspect_ratios = [1.0] aspect_ratios = [1.0]
anchor_size = 3 anchor_size = 3
resnet_model_id = 50 resnet_model_id = 50
segmentation_resnet_model_id = 50 segmentation_resnet_model_id = 50
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride)) aspp_decoder_level = 2
fpn_decoder_level = 3 fpn_decoder_level = 2
num_anchors_per_location = num_scales * len(aspect_ratios) num_anchors_per_location = num_scales * len(aspect_ratios)
image_size = 128 image_size = 128
images = np.random.rand(2, image_size, image_size, 3) images = np.random.rand(2, image_size, image_size, 3)
...@@ -115,15 +114,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -115,15 +114,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone = resnet.ResNet( segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id) model_id=segmentation_resnet_model_id)
if not shared_decoder: if not shared_decoder:
feature_fusion = 'deeplabv3plus'
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
feature_fusion = 'panoptic_fpn_fusion'
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things, num_classes=2, # stuff and common class for things,
level=level, level=level,
feature_fusion=feature_fusion,
decoder_min_level=min_level,
decoder_max_level=max_level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
...@@ -179,16 +183,15 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -179,16 +183,15 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
shared_backbone, shared_decoder, shared_backbone, shared_decoder,
generate_panoptic_masks): generate_panoptic_masks):
num_classes = 3 num_classes = 3
min_level = 3 min_level = 2
max_level = 4 max_level = 6
num_scales = 3 num_scales = 3
aspect_ratios = [1.0] aspect_ratios = [1.0]
anchor_size = 3 anchor_size = 3
segmentation_resnet_model_id = 101 segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride)) aspp_decoder_level = 2
fpn_decoder_level = 3 fpn_decoder_level = 2
class_agnostic_bbox_pred = False class_agnostic_bbox_pred = False
cascade_class_ensemble = False cascade_class_ensemble = False
...@@ -250,15 +253,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -250,15 +253,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone = resnet.ResNet( segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id) model_id=segmentation_resnet_model_id)
if not shared_decoder: if not shared_decoder:
feature_fusion = 'deeplabv3plus'
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
feature_fusion = 'panoptic_fpn_fusion'
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things, num_classes=2, # stuff and common class for things,
level=level, level=level,
feature_fusion=feature_fusion,
decoder_min_level=min_level,
decoder_max_level=max_level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
...@@ -354,10 +362,11 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -354,10 +362,11 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
max_num_detections=100, max_num_detections=100,
stuff_classes_offset=90) stuff_classes_offset=90)
segmentation_resnet_model_id = 101 segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride)) min_level = 2
fpn_decoder_level = 3 max_level = 6
aspp_decoder_level = 2
fpn_decoder_level = 2
shared_decoder = shared_decoder and shared_backbone shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2) mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler( mask_sampler_obj = mask_sampler.MaskSampler(
...@@ -370,15 +379,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -370,15 +379,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone = resnet.ResNet( segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id) model_id=segmentation_resnet_model_id)
if not shared_decoder: if not shared_decoder:
feature_fusion = 'deeplabv3plus'
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
feature_fusion = 'panoptic_fpn_fusion'
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things, num_classes=2, # stuff and common class for things,
level=level, level=level,
feature_fusion=feature_fusion,
decoder_min_level=min_level,
decoder_max_level=max_level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
...@@ -397,8 +411,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -397,8 +411,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone=segmentation_backbone, segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder, segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head, segmentation_head=segmentation_head,
min_level=3, min_level=min_level,
max_level=7, max_level=max_level,
num_scales=3, num_scales=3,
aspect_ratios=[1.0], aspect_ratios=[1.0],
anchor_size=3) anchor_size=3)
...@@ -433,10 +447,11 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -433,10 +447,11 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
max_num_detections=100, max_num_detections=100,
stuff_classes_offset=90) stuff_classes_offset=90)
segmentation_resnet_model_id = 101 segmentation_resnet_model_id = 101
segmentation_output_stride = 16
aspp_dilation_rates = [6, 12, 18] aspp_dilation_rates = [6, 12, 18]
aspp_decoder_level = int(np.math.log2(segmentation_output_stride)) min_level = 2
fpn_decoder_level = 3 max_level = 6
aspp_decoder_level = 2
fpn_decoder_level = 2
shared_decoder = shared_decoder and shared_backbone shared_decoder = shared_decoder and shared_backbone
mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2) mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
mask_sampler_obj = mask_sampler.MaskSampler( mask_sampler_obj = mask_sampler.MaskSampler(
...@@ -449,15 +464,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -449,15 +464,20 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone = resnet.ResNet( segmentation_backbone = resnet.ResNet(
model_id=segmentation_resnet_model_id) model_id=segmentation_resnet_model_id)
if not shared_decoder: if not shared_decoder:
feature_fusion = 'deeplabv3plus'
level = aspp_decoder_level level = aspp_decoder_level
segmentation_decoder = aspp.ASPP( segmentation_decoder = aspp.ASPP(
level=level, dilation_rates=aspp_dilation_rates) level=level, dilation_rates=aspp_dilation_rates)
else: else:
feature_fusion = 'panoptic_fpn_fusion'
level = fpn_decoder_level level = fpn_decoder_level
segmentation_decoder = None segmentation_decoder = None
segmentation_head = segmentation_heads.SegmentationHead( segmentation_head = segmentation_heads.SegmentationHead(
num_classes=2, # stuff and common class for things, num_classes=2, # stuff and common class for things,
level=level, level=level,
feature_fusion=feature_fusion,
decoder_min_level=min_level,
decoder_max_level=max_level,
num_convs=2) num_convs=2)
model = panoptic_maskrcnn_model.PanopticMaskRCNNModel( model = panoptic_maskrcnn_model.PanopticMaskRCNNModel(
...@@ -476,8 +496,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase): ...@@ -476,8 +496,8 @@ class PanopticMaskRCNNModelTest(parameterized.TestCase, tf.test.TestCase):
segmentation_backbone=segmentation_backbone, segmentation_backbone=segmentation_backbone,
segmentation_decoder=segmentation_decoder, segmentation_decoder=segmentation_decoder,
segmentation_head=segmentation_head, segmentation_head=segmentation_head,
min_level=3, min_level=max_level,
max_level=7, max_level=max_level,
num_scales=3, num_scales=3,
aspect_ratios=[1.0], aspect_ratios=[1.0],
anchor_size=3) anchor_size=3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment