Unverified Commit 31a8e466 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

use `SemanticHead` and `InstanceHead` from panoptic_deeplab_heads

parent abee356d
......@@ -22,8 +22,8 @@ from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import instance_center_head
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
......@@ -85,7 +85,6 @@ def build_panoptic_maskrcnn(
num_classes=segmentation_config.num_classes,
level=segmentation_head_config.level,
num_convs=segmentation_head_config.num_convs,
kernel_size=segmentation_head_config.kernel_size,
prediction_kernel_size=segmentation_head_config.prediction_kernel_size,
num_filters=segmentation_head_config.num_filters,
upsample_factor=segmentation_head_config.upsample_factor,
......@@ -185,9 +184,9 @@ def build_panoptic_deeplab(
l2_regularizer=l2_regularizer)
semantic_head_config = model_config.semantic_head
instnace_head_config = model_config.instance_head
instance_head_config = model_config.instance_head
semantic_head = segmentation_heads.SegmentationHead(
semantic_head = panoptic_deeplab_heads.SemanticHead(
num_classes=model_config.num_classes,
level=semantic_head_config.level,
num_convs=semantic_head_config.num_convs,
......@@ -196,7 +195,6 @@ def build_panoptic_deeplab(
num_filters=semantic_head_config.num_filters,
use_depthwise_convolution=semantic_head_config.use_depthwise_convolution,
upsample_factor=semantic_head_config.upsample_factor,
feature_fusion=semantic_head_config.feature_fusion,
low_level=semantic_head_config.low_level,
low_level_num_filters=semantic_head_config.low_level_num_filters,
activation=norm_activation_config.activation,
......@@ -205,17 +203,16 @@ def build_panoptic_deeplab(
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
instance_head = instance_center_head.InstanceCenterHead(
level=instnace_head_config.level,
num_convs=instnace_head_config.num_convs,
kernel_size=instnace_head_config.kernel_size,
prediction_kernel_size=instnace_head_config.prediction_kernel_size,
num_filters=instnace_head_config.num_filters,
use_depthwise_convolution=instnace_head_config.use_depthwise_convolution,
upsample_factor=instnace_head_config.upsample_factor,
feature_fusion=instnace_head_config.feature_fusion,
low_level=instnace_head_config.low_level,
low_level_num_filters=instnace_head_config.low_level_num_filters,
instance_head = panoptic_deeplab_heads.InstanceHead(
level=instance_head_config.level,
num_convs=instance_head_config.num_convs,
kernel_size=instance_head_config.kernel_size,
prediction_kernel_size=instance_head_config.prediction_kernel_size,
num_filters=instance_head_config.num_filters,
use_depthwise_convolution=instance_head_config.use_depthwise_convolution,
upsample_factor=instance_head_config.upsample_factor,
low_level=instance_head_config.low_level,
low_level_num_filters=instance_head_config.low_level_num_filters,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
......
......@@ -15,7 +15,6 @@
"""Tests for factory.py."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
......@@ -74,13 +73,9 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
decoder_type=['aspp', 'fpn'],
level=[2, 3, 4],
low_level=[(4, 3), (3, 2)],
shared_decoder=[True, False],
fusion_type=[
'pyramid_fusion',
'panoptic_fpn_fusion',
'panoptic_deeplab_fusion']))
shared_decoder=[True, False]))
def test_builder(self, input_size, backbone_type, level,
low_level, decoder_type, shared_decoder, fusion_type):
low_level, decoder_type, shared_decoder):
num_classes = 10
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3])
......@@ -90,20 +85,18 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
input_size=input_size,
backbone=backbones.Backbone(type=backbone_type),
decoder=decoders.Decoder(type=decoder_type),
semantic_head=semantic_segmentation.SegmentationHead(
semantic_head=panoptic_deeplab_cfg.SemanticHead(
level=level,
num_convs=1,
kernel_size=5,
prediction_kernel_size=1,
low_level=low_level,
feature_fusion=fusion_type),
instance_head=panoptic_deeplab_cfg.InstanceCenterHead(
low_level=low_level),
instance_head=panoptic_deeplab_cfg.InstanceHead(
level=level,
num_convs=1,
kernel_size=5,
prediction_kernel_size=1,
low_level=low_level,
feature_fusion=fusion_type),
low_level=low_level),
shared_decoder=shared_decoder)
l2_regularizer = tf.keras.regularizers.l2(5e-5)
......
......@@ -22,8 +22,7 @@ from tensorflow.python.distribute import combinations
from official.vision.beta.modeling import backbones
from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import instance_center_head
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
......@@ -52,18 +51,16 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
instance_decoder = aspp.ASPP(
level=level, dilation_rates=[6, 12, 18])
semantic_head = segmentation_heads.SegmentationHead(
semantic_head = panoptic_deeplab_heads.SemanticHead(
num_classes,
level=level,
low_level=low_level,
low_level_num_filters=[64, 32],
feature_fusion='panoptic_deeplab_fusion')
low_level_num_filters=(64, 32))
instance_head = instance_center_head.InstanceCenterHead(
instance_head = panoptic_deeplab_heads.InstanceHead(
level=level,
low_level=low_level,
low_level_num_filters=[64, 32],
feature_fusion='panoptic_deeplab_fusion')
low_level_num_filters=(64, 32))
model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone,
......@@ -114,18 +111,16 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
instance_decoder = aspp.ASPP(
level=level, dilation_rates=[6, 12, 18])
semantic_head = segmentation_heads.SegmentationHead(
semantic_head = panoptic_deeplab_heads.SemanticHead(
num_classes,
level=level,
low_level=low_level,
low_level_num_filters=[64, 32],
feature_fusion='panoptic_deeplab_fusion')
low_level_num_filters=(64, 32))
instance_head = instance_center_head.InstanceCenterHead(
instance_head = panoptic_deeplab_heads.InstanceHead(
level=level,
low_level=low_level,
low_level_num_filters=[64, 32],
feature_fusion='panoptic_deeplab_fusion')
low_level_num_filters=(64, 32))
model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment