Unverified Commit 9eed3c5c authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added config for `PanopticSegmentationGenerator`

parent b7ca1c0f
......@@ -20,6 +20,7 @@ from typing import List, Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation
......@@ -52,6 +53,20 @@ class DataConfig(maskrcnn.DataConfig):
"""Input config for training."""
parser: Parser = Parser()
# @dataclasses.dataclass
@dataclasses.dataclass
class PanopticSegmentationGenerator(hyperparams.Config):
output_size: List[int] = dataclasses.field(
default_factory=list)
stuff_classes_offset: int = 0
mask_binarize_threshold: float = 0.5
score_threshold: float = 0.05
things_class_label: int = 1
void_class_label: int = 0
void_instance_id: int = 0
@dataclasses.dataclass
class PanopticMaskRCNN(maskrcnn.MaskRCNN):
"""Panoptic Mask R-CNN model config."""
......@@ -60,6 +75,8 @@ class PanopticMaskRCNN(maskrcnn.MaskRCNN):
include_mask = True
shared_backbone: bool = True
shared_decoder: bool = True
panoptic_segmentation_generator: PanopticSegmentationGenerator = \
PanopticSegmentationGenerator()
@dataclasses.dataclass
......@@ -112,8 +129,11 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
init_checkpoint_modules=['backbone'],
model=PanopticMaskRCNN(
num_classes=91, input_size=[1024, 1024, 3],
panoptic_segmentation_generator=PanopticSegmentationGenerator(
output_size=[1024, 1024],
stuff_classes_offset=90),
segmentation_model=SEGMENTATION_MODEL(
num_classes=91,
num_classes=110,
head=SEGMENTATION_HEAD(level=3))),
losses=Losses(l2_weight_decay=0.00004),
train_data=DataConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment