Unverified Commit 9eed3c5c authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added config for `PanopticSegmentationGenerator`

parent b7ca1c0f
...@@ -20,6 +20,7 @@ from typing import List, Optional ...@@ -20,6 +20,7 @@ from typing import List, Optional
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.vision.beta.configs import maskrcnn from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation from official.vision.beta.configs import semantic_segmentation
...@@ -52,6 +53,20 @@ class DataConfig(maskrcnn.DataConfig): ...@@ -52,6 +53,20 @@ class DataConfig(maskrcnn.DataConfig):
"""Input config for training.""" """Input config for training."""
parser: Parser = Parser() parser: Parser = Parser()
# @dataclasses.dataclass
@dataclasses.dataclass
class PanopticSegmentationGenerator(hyperparams.Config):
output_size: List[int] = dataclasses.field(
default_factory=list)
stuff_classes_offset: int = 0
mask_binarize_threshold: float = 0.5
score_threshold: float = 0.05
things_class_label: int = 1
void_class_label: int = 0
void_instance_id: int = 0
@dataclasses.dataclass @dataclasses.dataclass
class PanopticMaskRCNN(maskrcnn.MaskRCNN): class PanopticMaskRCNN(maskrcnn.MaskRCNN):
"""Panoptic Mask R-CNN model config.""" """Panoptic Mask R-CNN model config."""
...@@ -60,6 +75,8 @@ class PanopticMaskRCNN(maskrcnn.MaskRCNN): ...@@ -60,6 +75,8 @@ class PanopticMaskRCNN(maskrcnn.MaskRCNN):
include_mask = True include_mask = True
shared_backbone: bool = True shared_backbone: bool = True
shared_decoder: bool = True shared_decoder: bool = True
panoptic_segmentation_generator: PanopticSegmentationGenerator = \
PanopticSegmentationGenerator()
@dataclasses.dataclass @dataclasses.dataclass
...@@ -112,8 +129,11 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig: ...@@ -112,8 +129,11 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
init_checkpoint_modules=['backbone'], init_checkpoint_modules=['backbone'],
model=PanopticMaskRCNN( model=PanopticMaskRCNN(
num_classes=91, input_size=[1024, 1024, 3], num_classes=91, input_size=[1024, 1024, 3],
panoptic_segmentation_generator=PanopticSegmentationGenerator(
output_size=[1024, 1024],
stuff_classes_offset=90),
segmentation_model=SEGMENTATION_MODEL( segmentation_model=SEGMENTATION_MODEL(
num_classes=91, num_classes=110,
head=SEGMENTATION_HEAD(level=3))), head=SEGMENTATION_HEAD(level=3))),
losses=Losses(l2_weight_decay=0.00004), losses=Losses(l2_weight_decay=0.00004),
train_data=DataConfig( train_data=DataConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment