Unverified Commit 2ad1ec15 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added configs for `SemanticHead` and `InstanceHead`

parent 31a8e466
...@@ -12,18 +12,15 @@ ...@@ -12,18 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Panoptic Mask R-CNN configuration definition.""" """Panoptic Deeplab configuration definition."""
import dataclasses import dataclasses
from typing import List, Optional, Union from typing import List, Tuple, Union
from official.modeling import hyperparams from official.modeling import hyperparams
from official.vision.beta.configs import common from official.vision.beta.configs import common
from official.vision.beta.configs import backbones from official.vision.beta.configs import backbones
from official.vision.beta.configs import decoders from official.vision.beta.configs import decoders
from official.vision.beta.configs import semantic_segmentation
SEGMENTATION_HEAD = semantic_segmentation.SegmentationHead
_COCO_INPUT_PATH_BASE = 'coco/tfrecords' _COCO_INPUT_PATH_BASE = 'coco/tfrecords'
_COCO_TRAIN_EXAMPLES = 118287 _COCO_TRAIN_EXAMPLES = 118287
...@@ -31,17 +28,27 @@ _COCO_VAL_EXAMPLES = 5000 ...@@ -31,17 +28,27 @@ _COCO_VAL_EXAMPLES = 5000
@dataclasses.dataclass @dataclasses.dataclass
class InstanceCenterHead(semantic_segmentation.SegmentationHead): class PanopticDeeplabHead(hyperparams.Config):
"""Instance Center head config.""" """Panoptic Deeplab head config."""
# None, deeplabv3plus, panoptic_fpn_fusion, level: int = 3
# panoptic_deeplab_fusion or pyramid_fusion num_convs: int = 2
num_filters: int = 256
kernel_size: int = 5 kernel_size: int = 5
feature_fusion: Optional[str] = None use_depthwise_convolution: bool = False
low_level: Union[int, List[int]] = dataclasses.field( upsample_factor: int = 1
default_factory=lambda: [3, 2]) low_level: Union[List[int], Tuple[int]] = (3, 2)
low_level_num_filters: Union[int, List[int]] = dataclasses.field( low_level_num_filters: Union[List[int], Tuple[int]] = (64, 32)
default_factory=lambda: [64, 32])
@dataclasses.dataclass
class SemanticHead(PanopticDeeplabHead):
"""Semantic head config."""
prediction_kernel_size: int = 1
@dataclasses.dataclass
class InstanceHead(PanopticDeeplabHead):
"""Instance head config."""
prediction_kernel_size: int = 1
# pytype: disable=wrong-keyword-args # pytype: disable=wrong-keyword-args
@dataclasses.dataclass @dataclasses.dataclass
...@@ -55,7 +62,6 @@ class PanopticDeeplab(hyperparams.Config): ...@@ -55,7 +62,6 @@ class PanopticDeeplab(hyperparams.Config):
backbone: backbones.Backbone = backbones.Backbone( backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet()) type='resnet', resnet=backbones.ResNet())
decoder: decoders.Decoder = decoders.Decoder(type='aspp') decoder: decoders.Decoder = decoders.Decoder(type='aspp')
semantic_head: SEGMENTATION_HEAD = SEGMENTATION_HEAD() semantic_head: SemanticHead = SemanticHead()
instance_head: InstanceCenterHead = InstanceCenterHead( instance_head: InstanceHead = InstanceHead()
low_level=[3, 2])
shared_decoder: bool = False shared_decoder: bool = False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment