Unverified Commit 31a8e466 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

use `SemanticHead` and `InstanceHead` from panoptic_deeplab_heads

parent abee356d
...@@ -22,8 +22,8 @@ from official.vision.beta.modeling.decoders import factory as decoder_factory ...@@ -22,8 +22,8 @@ from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import instance_center_head
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
...@@ -85,7 +85,6 @@ def build_panoptic_maskrcnn( ...@@ -85,7 +85,6 @@ def build_panoptic_maskrcnn(
num_classes=segmentation_config.num_classes, num_classes=segmentation_config.num_classes,
level=segmentation_head_config.level, level=segmentation_head_config.level,
num_convs=segmentation_head_config.num_convs, num_convs=segmentation_head_config.num_convs,
kernel_size=segmentation_head_config.kernel_size,
prediction_kernel_size=segmentation_head_config.prediction_kernel_size, prediction_kernel_size=segmentation_head_config.prediction_kernel_size,
num_filters=segmentation_head_config.num_filters, num_filters=segmentation_head_config.num_filters,
upsample_factor=segmentation_head_config.upsample_factor, upsample_factor=segmentation_head_config.upsample_factor,
...@@ -185,9 +184,9 @@ def build_panoptic_deeplab( ...@@ -185,9 +184,9 @@ def build_panoptic_deeplab(
l2_regularizer=l2_regularizer) l2_regularizer=l2_regularizer)
semantic_head_config = model_config.semantic_head semantic_head_config = model_config.semantic_head
instnace_head_config = model_config.instance_head instance_head_config = model_config.instance_head
semantic_head = segmentation_heads.SegmentationHead( semantic_head = panoptic_deeplab_heads.SemanticHead(
num_classes=model_config.num_classes, num_classes=model_config.num_classes,
level=semantic_head_config.level, level=semantic_head_config.level,
num_convs=semantic_head_config.num_convs, num_convs=semantic_head_config.num_convs,
...@@ -196,7 +195,6 @@ def build_panoptic_deeplab( ...@@ -196,7 +195,6 @@ def build_panoptic_deeplab(
num_filters=semantic_head_config.num_filters, num_filters=semantic_head_config.num_filters,
use_depthwise_convolution=semantic_head_config.use_depthwise_convolution, use_depthwise_convolution=semantic_head_config.use_depthwise_convolution,
upsample_factor=semantic_head_config.upsample_factor, upsample_factor=semantic_head_config.upsample_factor,
feature_fusion=semantic_head_config.feature_fusion,
low_level=semantic_head_config.low_level, low_level=semantic_head_config.low_level,
low_level_num_filters=semantic_head_config.low_level_num_filters, low_level_num_filters=semantic_head_config.low_level_num_filters,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
...@@ -205,17 +203,16 @@ def build_panoptic_deeplab( ...@@ -205,17 +203,16 @@ def build_panoptic_deeplab(
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer) kernel_regularizer=l2_regularizer)
instance_head = instance_center_head.InstanceCenterHead( instance_head = panoptic_deeplab_heads.InstanceHead(
level=instnace_head_config.level, level=instance_head_config.level,
num_convs=instnace_head_config.num_convs, num_convs=instance_head_config.num_convs,
kernel_size=instnace_head_config.kernel_size, kernel_size=instance_head_config.kernel_size,
prediction_kernel_size=instnace_head_config.prediction_kernel_size, prediction_kernel_size=instance_head_config.prediction_kernel_size,
num_filters=instnace_head_config.num_filters, num_filters=instance_head_config.num_filters,
use_depthwise_convolution=instnace_head_config.use_depthwise_convolution, use_depthwise_convolution=instance_head_config.use_depthwise_convolution,
upsample_factor=instnace_head_config.upsample_factor, upsample_factor=instance_head_config.upsample_factor,
feature_fusion=instnace_head_config.feature_fusion, low_level=instance_head_config.low_level,
low_level=instnace_head_config.low_level, low_level_num_filters=instance_head_config.low_level_num_filters,
low_level_num_filters=instnace_head_config.low_level_num_filters,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
"""Tests for factory.py.""" """Tests for factory.py."""
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python.distribute import combinations from tensorflow.python.distribute import combinations
...@@ -74,13 +73,9 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -74,13 +73,9 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
decoder_type=['aspp', 'fpn'], decoder_type=['aspp', 'fpn'],
level=[2, 3, 4], level=[2, 3, 4],
low_level=[(4, 3), (3, 2)], low_level=[(4, 3), (3, 2)],
shared_decoder=[True, False], shared_decoder=[True, False]))
fusion_type=[
'pyramid_fusion',
'panoptic_fpn_fusion',
'panoptic_deeplab_fusion']))
def test_builder(self, input_size, backbone_type, level, def test_builder(self, input_size, backbone_type, level,
low_level, decoder_type, shared_decoder, fusion_type): low_level, decoder_type, shared_decoder):
num_classes = 10 num_classes = 10
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3]) shape=[None, input_size[0], input_size[1], 3])
...@@ -90,20 +85,18 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -90,20 +85,18 @@ class PanopticDeeplabBuilderTest(parameterized.TestCase, tf.test.TestCase):
input_size=input_size, input_size=input_size,
backbone=backbones.Backbone(type=backbone_type), backbone=backbones.Backbone(type=backbone_type),
decoder=decoders.Decoder(type=decoder_type), decoder=decoders.Decoder(type=decoder_type),
semantic_head=semantic_segmentation.SegmentationHead( semantic_head=panoptic_deeplab_cfg.SemanticHead(
level=level, level=level,
num_convs=1, num_convs=1,
kernel_size=5, kernel_size=5,
prediction_kernel_size=1, prediction_kernel_size=1,
low_level=low_level, low_level=low_level),
feature_fusion=fusion_type), instance_head=panoptic_deeplab_cfg.InstanceHead(
instance_head=panoptic_deeplab_cfg.InstanceCenterHead(
level=level, level=level,
num_convs=1, num_convs=1,
kernel_size=5, kernel_size=5,
prediction_kernel_size=1, prediction_kernel_size=1,
low_level=low_level, low_level=low_level),
feature_fusion=fusion_type),
shared_decoder=shared_decoder) shared_decoder=shared_decoder)
l2_regularizer = tf.keras.regularizers.l2(5e-5) l2_regularizer = tf.keras.regularizers.l2(5e-5)
......
...@@ -22,8 +22,7 @@ from tensorflow.python.distribute import combinations ...@@ -22,8 +22,7 @@ from tensorflow.python.distribute import combinations
from official.vision.beta.modeling import backbones from official.vision.beta.modeling import backbones
from official.vision.beta.modeling.decoders import aspp from official.vision.beta.modeling.decoders import aspp
from official.vision.beta.modeling.heads import segmentation_heads from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import panoptic_deeplab_heads
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import instance_center_head
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase): class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
...@@ -52,18 +51,16 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -52,18 +51,16 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
instance_decoder = aspp.ASPP( instance_decoder = aspp.ASPP(
level=level, dilation_rates=[6, 12, 18]) level=level, dilation_rates=[6, 12, 18])
semantic_head = segmentation_heads.SegmentationHead( semantic_head = panoptic_deeplab_heads.SemanticHead(
num_classes, num_classes,
level=level, level=level,
low_level=low_level, low_level=low_level,
low_level_num_filters=[64, 32], low_level_num_filters=(64, 32))
feature_fusion='panoptic_deeplab_fusion')
instance_head = instance_center_head.InstanceCenterHead( instance_head = panoptic_deeplab_heads.InstanceHead(
level=level, level=level,
low_level=low_level, low_level=low_level,
low_level_num_filters=[64, 32], low_level_num_filters=(64, 32))
feature_fusion='panoptic_deeplab_fusion')
model = panoptic_deeplab_model.PanopticDeeplabModel( model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone, backbone=backbone,
...@@ -114,18 +111,16 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -114,18 +111,16 @@ class PanopticDeeplabNetworkTest(parameterized.TestCase, tf.test.TestCase):
instance_decoder = aspp.ASPP( instance_decoder = aspp.ASPP(
level=level, dilation_rates=[6, 12, 18]) level=level, dilation_rates=[6, 12, 18])
semantic_head = segmentation_heads.SegmentationHead( semantic_head = panoptic_deeplab_heads.SemanticHead(
num_classes, num_classes,
level=level, level=level,
low_level=low_level, low_level=low_level,
low_level_num_filters=[64, 32], low_level_num_filters=(64, 32))
feature_fusion='panoptic_deeplab_fusion')
instance_head = instance_center_head.InstanceCenterHead( instance_head = panoptic_deeplab_heads.InstanceHead(
level=level, level=level,
low_level=low_level, low_level=low_level,
low_level_num_filters=[64, 32], low_level_num_filters=(64, 32))
feature_fusion='panoptic_deeplab_fusion')
model = panoptic_deeplab_model.PanopticDeeplabModel( model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone, backbone=backbone,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment