Unverified Commit ac671306 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added tests for `build_panoptic_deeplab` in panoptic factory

parent c3282abe
......@@ -17,10 +17,14 @@
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.vision.beta.configs import backbones
from official.vision.beta.configs import decoders
from official.vision.beta.configs import semantic_segmentation
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory
......@@ -61,5 +65,53 @@ class PanopticMaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):
model_config=model_config,
l2_regularizer=l2_regularizer)
class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
@combinations.generate(
combinations.combine(
input_size=[(640, 640), (512, 512)],
backbone_type=['resnet', 'dilated_resnet'],
decoder_type=['aspp', 'fpn'],
level=[2, 3, 4],
low_level=[(4, 3), (3, 2)],
shared_decoder=[True, False],
fusion_type=[
'pyramid_fusion',
'panoptic_fpn_fusion',
'panoptic_deeplab_fusion']))
def test_builder(self, input_size, backbone_type, level,
low_level, decoder_type, shared_decoder, fusion_type):
num_classes = 10
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3])
model_config = panoptic_deeplab_cfg.PanopticDeeplab(
num_classes=num_classes,
input_size=input_size,
backbone=backbones.Backbone(type=backbone_type),
decoder=decoders.Decoder(type=decoder_type),
semantic_head=semantic_segmentation.SegmentationHead(
level=level,
num_convs=1,
kernel_size=5,
prediction_kernel_size=1,
low_level=low_level,
feature_fusion=fusion_type),
instance_head=panoptic_deeplab_cfg.InstanceCenterHead(
level=level,
num_convs=1,
kernel_size=5,
prediction_kernel_size=1,
low_level=low_level,
feature_fusion=fusion_type),
shared_decoder=shared_decoder)
l2_regularizer = tf.keras.regularizers.l2(5e-5)
_ = factory.build_panoptic_deeplab(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment