Unverified Commit 983ffd16 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

renamed `include_eval_masks`

parent e81f22d3
......@@ -48,13 +48,13 @@ class Parser(maskrcnn.Parser):
segmentation_ignore_label: int = 255
panoptic_ignore_label: int = 0
# Setting this to true will enable parsing category_mask and instance_mask
include_eval_masks: bool = True
include_panoptic_masks: bool = True
@dataclasses.dataclass
class TfExampleDecoder(maskrcnn.TfExampleDecoder):
"""A simple TF Example decoder config."""
# Setting this to true will enable decoding category_mask and instance_mask
include_eval_masks: bool = True
include_panoptic_masks: bool = True
@dataclasses.dataclass
......
......@@ -25,18 +25,18 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
"""Tensorflow Example proto decoder."""
def __init__(self, regenerate_source_id,
mask_binarize_threshold, include_eval_masks):
mask_binarize_threshold, include_panoptic_masks):
super(TfExampleDecoder, self).__init__(
include_mask=True,
regenerate_source_id=regenerate_source_id,
mask_binarize_threshold=None)
self._include_eval_masks = include_eval_masks
self._include_panoptic_masks= include_panoptic_masks
keys_to_features = {
'image/segmentation/class/encoded':
tf.io.FixedLenFeature((), tf.string, default_value='')}
if include_eval_masks:
if include_panoptic_masks:
keys_to_features.update({
'image/panoptic/category_mask':
tf.io.FixedLenFeature((), tf.string, default_value=''),
......@@ -54,7 +54,7 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
segmentation_mask.set_shape([None, None, 1])
decoded_tensors.update({'groundtruth_segmentation_mask': segmentation_mask})
if self._include_eval_masks:
if self._include_panoptic_masks:
category_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/category_mask'],
channels=1)
......@@ -96,7 +96,7 @@ class Parser(maskrcnn_input.Parser):
segmentation_groundtruth_padded_size=None,
segmentation_ignore_label=255,
panoptic_ignore_label=0,
include_eval_masks=True,
include_panoptic_masks=True,
dtype='float32'):
"""Initializes parameters for parsing annotations in the dataset.
......@@ -138,7 +138,7 @@ class Parser(maskrcnn_input.Parser):
used for training and evaluation.
panoptic_ignore_label: `int` the pixels with ignore label will not be
used by the PQ evaluator.
include_eval_masks: `bool`, if True, category_mask and instance_mask will
include_panoptic_masks: `bool`, if True, category_mask and instance_mask will
be parsed. Set this to true if PQ evaluator is enabled.
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
"""
......@@ -172,7 +172,7 @@ class Parser(maskrcnn_input.Parser):
self._segmentation_groundtruth_padded_size = segmentation_groundtruth_padded_size
self._segmentation_ignore_label = segmentation_ignore_label
self._panoptic_ignore_label = panoptic_ignore_label
self._include_eval_masks = include_eval_masks
self._include_panoptic_masks= include_panoptic_masks
def _parse_train_data(self, data):
"""Parses data for training.
......@@ -322,7 +322,7 @@ class Parser(maskrcnn_input.Parser):
'gt_segmentation_mask': segmentation_mask,
'gt_segmentation_valid_mask': segmentation_valid_mask})
if self._include_eval_masks:
if self._include_panoptic_masks:
panoptic_category_mask = _process_mask(
data['groundtruth_panoptic_category_mask'],
self._panoptic_ignore_label, image_info)
......
......@@ -122,7 +122,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
decoder = panoptic_maskrcnn_input.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
include_eval_masks=decoder_cfg.include_eval_masks)
include_panoptic_masks=decoder_cfg.include_panoptic_masks)
else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
......@@ -150,7 +150,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
.segmentation_groundtruth_padded_size,
segmentation_ignore_label=params.parser.segmentation_ignore_label,
panoptic_ignore_label=params.parser.panoptic_ignore_label,
include_eval_masks=params.parser.include_eval_masks)
include_panoptic_masks=params.parser.include_panoptic_masks)
reader = input_reader_factory.input_reader_generator(
params,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment