"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "020efa74ab2ffc9784642b597d4d2cb2be3f130e"
Unverified Commit 00cbcad1 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

added config flag to turn on/off panoptic eval masks in dataloader

parent 485f4618
...@@ -21,6 +21,7 @@ from typing import List, Optional ...@@ -21,6 +21,7 @@ from typing import List, Optional
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import optimization from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import maskrcnn from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation from official.vision.beta.configs import semantic_segmentation
...@@ -46,10 +47,26 @@ class Parser(maskrcnn.Parser): ...@@ -46,10 +47,26 @@ class Parser(maskrcnn.Parser):
segmentation_groundtruth_padded_size: List[int] = dataclasses.field( segmentation_groundtruth_padded_size: List[int] = dataclasses.field(
default_factory=list) default_factory=list)
segmentation_ignore_label: int = 255 segmentation_ignore_label: int = 255
panoptic_ignore_label: int = 0
# Setting this to true will enable parsing category_mask and instance_mask
include_eval_masks: bool = True
@dataclasses.dataclass
class TfExampleDecoder(common.TfExampleDecoder):
"""A simple TF Example decoder config."""
# Setting this to true will enable decoding category_mask and instance_mask
include_eval_masks: bool = True
@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
"""Data decoder config."""
simple_decoder: TfExampleDecoder = TfExampleDecoder()
@dataclasses.dataclass @dataclasses.dataclass
class DataConfig(maskrcnn.DataConfig): class DataConfig(maskrcnn.DataConfig):
"""Input config for training.""" """Input config for training."""
decoder: DataDecoder = DataDecoder()
parser: Parser = Parser() parser: Parser = Parser()
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -121,7 +121,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -121,7 +121,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
if params.decoder.type == 'simple_decoder': if params.decoder.type == 'simple_decoder':
decoder = panoptic_maskrcnn_input.TfExampleDecoder( decoder = panoptic_maskrcnn_input.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id, regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold) mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
include_eval_masks=decoder_cfg.include_eval_masks)
else: else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type)) raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
...@@ -147,7 +148,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -147,7 +148,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
.segmentation_resize_eval_groundtruth, .segmentation_resize_eval_groundtruth,
segmentation_groundtruth_padded_size=params.parser segmentation_groundtruth_padded_size=params.parser
.segmentation_groundtruth_padded_size, .segmentation_groundtruth_padded_size,
segmentation_ignore_label=params.parser.segmentation_ignore_label) segmentation_ignore_label=params.parser.segmentation_ignore_label,
panoptic_ignore_label=params.parse.panoptic_ignore_label,
include_eval_masks=params.parser.include_eval_masks)
reader = input_reader_factory.input_reader_generator( reader = input_reader_factory.input_reader_generator(
params, params,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment