"...composable_kernel.git" did not exist on "4e075420b9d9d3e3cf4e53012cabcb42b93e9186"
Unverified Commit ac671306 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added tests for `build_panoptic_deeplab` in panoptic factory

parent c3282abe
...@@ -17,10 +17,14 @@ ...@@ -17,10 +17,14 @@
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.vision.beta.configs import backbones from official.vision.beta.configs import backbones
from official.vision.beta.configs import decoders from official.vision.beta.configs import decoders
from official.vision.beta.configs import semantic_segmentation from official.vision.beta.configs import semantic_segmentation
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory from official.vision.beta.projects.panoptic_maskrcnn.modeling import factory
...@@ -61,5 +65,53 @@ class PanopticMaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -61,5 +65,53 @@ class PanopticMaskRCNNBuilderTest(parameterized.TestCase, tf.test.TestCase):
model_config=model_config, model_config=model_config,
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
@combinations.generate(
combinations.combine(
input_size=[(640, 640), (512, 512)],
backbone_type=['resnet', 'dilated_resnet'],
decoder_type=['aspp', 'fpn'],
level=[2, 3, 4],
low_level=[(4, 3), (3, 2)],
shared_decoder=[True, False],
fusion_type=[
'pyramid_fusion',
'panoptic_fpn_fusion',
'panoptic_deeplab_fusion']))
def test_builder(self, input_size, backbone_type, level,
low_level, decoder_type, shared_decoder, fusion_type):
num_classes = 10
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3])
model_config = panoptic_deeplab_cfg.PanopticDeeplab(
num_classes=num_classes,
input_size=input_size,
backbone=backbones.Backbone(type=backbone_type),
decoder=decoders.Decoder(type=decoder_type),
semantic_head=semantic_segmentation.SegmentationHead(
level=level,
num_convs=1,
kernel_size=5,
prediction_kernel_size=1,
low_level=low_level,
feature_fusion=fusion_type),
instance_head=panoptic_deeplab_cfg.InstanceCenterHead(
level=level,
num_convs=1,
kernel_size=5,
prediction_kernel_size=1,
low_level=low_level,
feature_fusion=fusion_type),
shared_decoder=shared_decoder)
l2_regularizer = tf.keras.regularizers.l2(5e-5)
_ = factory.build_panoptic_deeplab(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment