Unverified Commit c3282abe authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added `build_panoptic_deeplab` in panoptic factory

parent 8a8d5fab
...@@ -20,7 +20,10 @@ from official.vision.beta.modeling import backbones ...@@ -20,7 +20,10 @@ from official.vision.beta.modeling import backbones
from official.vision.beta.modeling import factory as models_factory from official.vision.beta.modeling import factory as models_factory
from official.vision.beta.modeling.decoders import factory as decoder_factory from official.vision.beta.modeling.decoders import factory as decoder_factory
from official.vision.beta.modeling.heads import segmentation_heads from official.vision.beta.modeling.heads import segmentation_heads
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_deeplab as panoptic_deeplab_cfg
from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg from official.vision.beta.projects.panoptic_maskrcnn.configs import panoptic_maskrcnn as panoptic_maskrcnn_cfg
from official.vision.beta.projects.panoptic_maskrcnn.modeling.heads import instance_center_head
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_deeplab_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model from official.vision.beta.projects.panoptic_maskrcnn.modeling import panoptic_maskrcnn_model
from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator from official.vision.beta.projects.panoptic_maskrcnn.modeling.layers import panoptic_segmentation_generator
...@@ -82,6 +85,7 @@ def build_panoptic_maskrcnn( ...@@ -82,6 +85,7 @@ def build_panoptic_maskrcnn(
num_classes=segmentation_config.num_classes, num_classes=segmentation_config.num_classes,
level=segmentation_head_config.level, level=segmentation_head_config.level,
num_convs=segmentation_head_config.num_convs, num_convs=segmentation_head_config.num_convs,
kernel_size=segmentation_head_config.kernel_size,
prediction_kernel_size=segmentation_head_config.prediction_kernel_size, prediction_kernel_size=segmentation_head_config.prediction_kernel_size,
num_filters=segmentation_head_config.num_filters, num_filters=segmentation_head_config.num_filters,
upsample_factor=segmentation_head_config.upsample_factor, upsample_factor=segmentation_head_config.upsample_factor,
...@@ -141,3 +145,88 @@ def build_panoptic_maskrcnn( ...@@ -141,3 +145,88 @@ def build_panoptic_maskrcnn(
aspect_ratios=model_config.anchor.aspect_ratios, aspect_ratios=model_config.anchor.aspect_ratios,
anchor_size=model_config.anchor.anchor_size) anchor_size=model_config.anchor.anchor_size)
return model return model
def build_panoptic_deeplab(
input_specs: tf.keras.layers.InputSpec,
model_config: panoptic_deeplab_cfg.PanopticDeeplab,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
"""Builds Panoptic Deeplab model.
Args:
input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
model_config: Config instance for the panoptic maskrcnn model.
l2_regularizer: Optional `tf.keras.regularizers.Regularizer`, if specified,
the model is built with the provided regularization layer.
Returns:
tf.keras.Model for the panoptic segmentation model.
"""
norm_activation_config = model_config.norm_activation
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
semantic_decoder = decoder_factory.build_decoder(
input_specs=backbone.output_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
if model_config.shared_decoder:
instance_decoder = None
else:
# TODO(srihari-humbarwadi): decouple semantic and
# instance decoder types
instance_decoder = decoder_factory.build_decoder(
input_specs=backbone.output_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
semantic_head_config = model_config.semantic_head
instnace_head_config = model_config.instance_head
semantic_head = segmentation_heads.SegmentationHead(
num_classes=model_config.num_classes,
level=semantic_head_config.level,
num_convs=semantic_head_config.num_convs,
kernel_size=semantic_head_config.kernel_size,
prediction_kernel_size=semantic_head_config.prediction_kernel_size,
num_filters=semantic_head_config.num_filters,
use_depthwise_convolution=semantic_head_config.use_depthwise_convolution,
upsample_factor=semantic_head_config.upsample_factor,
feature_fusion=semantic_head_config.feature_fusion,
low_level=semantic_head_config.low_level,
low_level_num_filters=semantic_head_config.low_level_num_filters,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
instance_head = instance_center_head.InstanceCenterHead(
level=instnace_head_config.level,
num_convs=instnace_head_config.num_convs,
kernel_size=instnace_head_config.kernel_size,
prediction_kernel_size=instnace_head_config.prediction_kernel_size,
num_filters=instnace_head_config.num_filters,
use_depthwise_convolution=instnace_head_config.use_depthwise_convolution,
upsample_factor=instnace_head_config.upsample_factor,
feature_fusion=instnace_head_config.feature_fusion,
low_level=instnace_head_config.low_level,
low_level_num_filters=instnace_head_config.low_level_num_filters,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
model = panoptic_deeplab_model.PanopticDeeplabModel(
backbone=backbone,
semantic_decoder=semantic_decoder,
instance_decoder=instance_decoder,
semantic_head=semantic_head,
instance_head=instance_head)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment