Unverified Commit c8e0233b authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added tests for panoptic_deeplab_fusion

parent 78949f92
......@@ -26,14 +26,17 @@ from official.vision.beta.modeling.heads import segmentation_heads
class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(2, 'pyramid_fusion', None, None),
(3, 'pyramid_fusion', None, None),
(2, 'panoptic_fpn_fusion', 2, 5),
(2, 'panoptic_fpn_fusion', 2, 6),
(3, 'panoptic_fpn_fusion', 3, 5),
(3, 'panoptic_fpn_fusion', 3, 6))
(2, 'pyramid_fusion', None, None, 2, 48),
(3, 'pyramid_fusion', None, None, 2, 48),
(2, 'panoptic_fpn_fusion', 2, 5, 2, 48),
(2, 'panoptic_fpn_fusion', 2, 6, 2, 48),
(3, 'panoptic_fpn_fusion', 3, 5, 2, 48),
(3, 'panoptic_fpn_fusion', 3, 6, 2, 48),
(4, 'panoptic_deeplab_fusion', None, None, (4, 3), (64, 32)),
(4, 'panoptic_deeplab_fusion', None, None, (3, 2), (64, 32)))
def test_forward(self, level, feature_fusion,
decoder_min_level, decoder_max_level):
decoder_min_level, decoder_max_level,
low_level, low_level_num_filters):
backbone_features = {
'3': np.random.rand(2, 128, 128, 16),
'4': np.random.rand(2, 64, 64, 16),
......@@ -45,14 +48,16 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
'5': np.random.rand(2, 32, 32, 64),
'6': np.random.rand(2, 16, 16, 64),
}
if feature_fusion == 'panoptic_fpn_fusion':
num_classes = 10
if 'panoptic' in feature_fusion:
backbone_features['2'] = np.random.rand(2, 256, 256, 16)
decoder_features['2'] = np.random.rand(2, 256, 256, 64)
head = segmentation_heads.SegmentationHead(
num_classes=10,
num_classes=num_classes,
level=level,
low_level=low_level,
low_level_num_filters=low_level_num_filters,
feature_fusion=feature_fusion,
decoder_min_level=decoder_min_level,
decoder_max_level=decoder_max_level,
......@@ -60,14 +65,18 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
logits = head((backbone_features, decoder_features))
if level in decoder_features:
self.assertAllEqual(logits.numpy().shape, [
2, decoder_features[str(level)].shape[1],
decoder_features[str(level)].shape[2], 10
])
if str(level) in decoder_features:
if feature_fusion == 'panoptic_deeplab_fusion':
h, w = decoder_features[str(low_level[-1])].shape[1:3]
else:
h, w = decoder_features[str(level)].shape[1:3]
self.assertAllEqual(
logits.numpy().shape,
[2, h, w, num_classes])
def test_serialize_deserialize(self):
head = segmentation_heads.SegmentationHead(num_classes=10, level=3)
head = segmentation_heads.SegmentationHead(
num_classes=10, level=3)
config = head.get_config()
new_head = segmentation_heads.SegmentationHead.from_config(config)
self.assertAllEqual(head.get_config(), new_head.get_config())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment