Unverified Commit 6112b4ce authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

move `stuff_classes_offset` to model config

parent 4c764049
...@@ -77,7 +77,6 @@ class DataConfig(maskrcnn.DataConfig): ...@@ -77,7 +77,6 @@ class DataConfig(maskrcnn.DataConfig):
class PanopticSegmentationGenerator(hyperparams.Config): class PanopticSegmentationGenerator(hyperparams.Config):
output_size: List[int] = dataclasses.field( output_size: List[int] = dataclasses.field(
default_factory=list) default_factory=list)
stuff_classes_offset: int = 0
mask_binarize_threshold: float = 0.5 mask_binarize_threshold: float = 0.5
score_threshold: float = 0.05 score_threshold: float = 0.05
things_class_label: int = 1 things_class_label: int = 1
...@@ -93,9 +92,9 @@ class PanopticMaskRCNN(maskrcnn.MaskRCNN): ...@@ -93,9 +92,9 @@ class PanopticMaskRCNN(maskrcnn.MaskRCNN):
include_mask = True include_mask = True
shared_backbone: bool = True shared_backbone: bool = True
shared_decoder: bool = True shared_decoder: bool = True
stuff_classes_offset: int = 0
generate_panoptic_masks: bool = True generate_panoptic_masks: bool = True
panoptic_segmentation_generator: PanopticSegmentationGenerator = \ panoptic_segmentation_generator: PanopticSegmentationGenerator = PanopticSegmentationGenerator()
PanopticSegmentationGenerator()
@dataclasses.dataclass @dataclasses.dataclass
...@@ -176,8 +175,8 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig: ...@@ -176,8 +175,8 @@ def panoptic_maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
model=PanopticMaskRCNN( model=PanopticMaskRCNN(
num_classes=91, input_size=[1024, 1024, 3], num_classes=91, input_size=[1024, 1024, 3],
panoptic_segmentation_generator=PanopticSegmentationGenerator( panoptic_segmentation_generator=PanopticSegmentationGenerator(
output_size=[1024, 1024], output_size=[1024, 1024]),
stuff_classes_offset=90), stuff_classes_offset=90,
segmentation_model=SEGMENTATION_MODEL( segmentation_model=SEGMENTATION_MODEL(
num_classes=num_semantic_segmentation_classes, num_classes=num_semantic_segmentation_classes,
head=SEGMENTATION_HEAD(level=3))), head=SEGMENTATION_HEAD(level=3))),
......
...@@ -100,7 +100,7 @@ def build_panoptic_maskrcnn( ...@@ -100,7 +100,7 @@ def build_panoptic_maskrcnn(
panoptic_segmentation_generator.PanopticSegmentationGenerator( panoptic_segmentation_generator.PanopticSegmentationGenerator(
output_size=postprocessing_config.output_size, output_size=postprocessing_config.output_size,
max_num_detections=max_num_detections, max_num_detections=max_num_detections,
stuff_classes_offset=postprocessing_config.stuff_classes_offset, stuff_classes_offset=model_config.stuff_classes_offset,
mask_binarize_threshold=mask_binarize_threshold, mask_binarize_threshold=mask_binarize_threshold,
score_threshold=postprocessing_config.score_threshold, score_threshold=postprocessing_config.score_threshold,
things_class_label=postprocessing_config.things_class_label, things_class_label=postprocessing_config.things_class_label,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment