Unverified Commit 6ee54a60 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added `PanopticDeeplabModel`

parent a6a14de7
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Panoptic Mask R-CNN configuration definition."""
import dataclasses
from typing import List, Optional, Union
from official.modeling import hyperparams
from official.vision.beta.configs import common
from official.vision.beta.configs import backbones
from official.vision.beta.configs import decoders
from official.vision.beta.configs import semantic_segmentation
SEGMENTATION_HEAD = semantic_segmentation.SegmentationHead
_COCO_INPUT_PATH_BASE = 'coco/tfrecords'
_COCO_TRAIN_EXAMPLES = 118287
_COCO_VAL_EXAMPLES = 5000
@dataclasses.dataclass
class InstanceCenterHead(semantic_segmentation.SegmentationHead):
"""Instance Center head config."""
# None, deeplabv3plus, panoptic_fpn_fusion,
# panoptic_deeplab_fusion or pyramid_fusion
kernel_size: int = 5
feature_fusion: Optional[str] = None
low_level: Union[int, List[int]] = dataclasses.field(
default_factory=lambda: [3, 2])
low_level_num_filters: Union[int, List[int]] = dataclasses.field(
default_factory=lambda: [64, 32])
# pytype: disable=wrong-keyword-args
@dataclasses.dataclass
class PanopticDeeplab(hyperparams.Config):
"""Panoptic Mask R-CNN model config."""
num_classes: int = 0
input_size: List[int] = dataclasses.field(default_factory=list)
min_level: int = 3
max_level: int = 6
norm_activation: common.NormActivation = common.NormActivation()
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet())
decoder: decoders.Decoder = decoders.Decoder(type='aspp')
semantic_head: SEGMENTATION_HEAD = SEGMENTATION_HEAD()
instance_head: InstanceCenterHead = InstanceCenterHead(
low_level=[3, 2])
shared_decoder: bool = False
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build Panoptic Deeplab model."""
from typing import Any, Mapping, Optional, Union
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Vision')
class PanopticDeeplabModel(tf.keras.Model):
"""Panoptic Deeplab model."""
def __init__(
self,
backbone: tf.keras.Model,
semantic_decoder: tf.keras.Model,
semantic_head: tf.keras.layers.Layer,
instance_head: tf.keras.layers.Layer,
instance_decoder: Optional[tf.keras.Model] = None,
**kwargs):
"""
Args:
backbone: a backbone network.
semantic_decoder: a decoder network. E.g. FPN.
semantic_head: segmentation head.
instance_head: instance center head .
instance_decoder: Optional decoder network for instance predictions.
**kwargs: keyword arguments to be passed.
"""
super(PanopticDeeplabModel, self).__init__(**kwargs)
self._config_dict = {
'backbone': backbone,
'semantic_decoder': semantic_decoder,
'instance_decoder': instance_decoder,
'semantic_head': semantic_head,
'instance_head': instance_head
}
self.backbone = backbone
self.semantic_decoder = semantic_decoder
self.instance_decoder = instance_decoder
self.semantic_head = semantic_head
self.instance_head = instance_head
def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
if training is None:
training = tf.keras.backend.learning_phase()
backbone_features = self.backbone(inputs, training=training)
semantic_features = self.semantic_decoder(
backbone_features, training=training)
if self.instance_decoder is None:
instance_features = semantic_features
else:
instance_features = self.instance_decoder(
backbone_features, training=training)
segmentation_outputs = self.semantic_head(
(backbone_features, semantic_features),
training=training)
instance_outputs = self.instance_head(
(backbone_features, instance_features),
training=training)
outputs = {
'segmentation_outputs': segmentation_outputs,
'instance_center_prediction':
instance_outputs['instance_center_prediction'],
'instance_center_regression':
instance_outputs['instance_center_regression'],
}
return outputs
@property
def checkpoint_items(
self) -> Mapping[str, Union[tf.keras.Model, tf.keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(
backbone=self.backbone,
semantic_decoder=self.semantic_decoder,
semantic_head=self.semantic_head,
instance_head=self.instance_head)
if self.instance_decoder is not None:
items.update(instance_decoder=self.instance_decoder)
return items
def get_config(self) -> Mapping[str, Any]:
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment