"git@developer.sourcefind.cn:change/sglang.git" did not exist on "f4f8a1b4d822eeaab95ea625d3f6b128baf4d0d6"
Unverified Commit 983ffd16 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

renamed `include_eval_masks`

parent e81f22d3
...@@ -48,13 +48,13 @@ class Parser(maskrcnn.Parser): ...@@ -48,13 +48,13 @@ class Parser(maskrcnn.Parser):
segmentation_ignore_label: int = 255 segmentation_ignore_label: int = 255
panoptic_ignore_label: int = 0 panoptic_ignore_label: int = 0
# Setting this to true will enable parsing category_mask and instance_mask # Setting this to true will enable parsing category_mask and instance_mask
include_eval_masks: bool = True include_panoptic_masks: bool = True
@dataclasses.dataclass @dataclasses.dataclass
class TfExampleDecoder(maskrcnn.TfExampleDecoder): class TfExampleDecoder(maskrcnn.TfExampleDecoder):
"""A simple TF Example decoder config.""" """A simple TF Example decoder config."""
# Setting this to true will enable decoding category_mask and instance_mask # Setting this to true will enable decoding category_mask and instance_mask
include_eval_masks: bool = True include_panoptic_masks: bool = True
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -25,18 +25,18 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder): ...@@ -25,18 +25,18 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
"""Tensorflow Example proto decoder.""" """Tensorflow Example proto decoder."""
def __init__(self, regenerate_source_id, def __init__(self, regenerate_source_id,
mask_binarize_threshold, include_eval_masks): mask_binarize_threshold, include_panoptic_masks):
super(TfExampleDecoder, self).__init__( super(TfExampleDecoder, self).__init__(
include_mask=True, include_mask=True,
regenerate_source_id=regenerate_source_id, regenerate_source_id=regenerate_source_id,
mask_binarize_threshold=None) mask_binarize_threshold=None)
self._include_eval_masks = include_eval_masks self._include_panoptic_masks= include_panoptic_masks
keys_to_features = { keys_to_features = {
'image/segmentation/class/encoded': 'image/segmentation/class/encoded':
tf.io.FixedLenFeature((), tf.string, default_value='')} tf.io.FixedLenFeature((), tf.string, default_value='')}
if include_eval_masks: if include_panoptic_masks:
keys_to_features.update({ keys_to_features.update({
'image/panoptic/category_mask': 'image/panoptic/category_mask':
tf.io.FixedLenFeature((), tf.string, default_value=''), tf.io.FixedLenFeature((), tf.string, default_value=''),
...@@ -54,7 +54,7 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder): ...@@ -54,7 +54,7 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
segmentation_mask.set_shape([None, None, 1]) segmentation_mask.set_shape([None, None, 1])
decoded_tensors.update({'groundtruth_segmentation_mask': segmentation_mask}) decoded_tensors.update({'groundtruth_segmentation_mask': segmentation_mask})
if self._include_eval_masks: if self._include_panoptic_masks:
category_mask = tf.io.decode_image( category_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/category_mask'], parsed_tensors['image/panoptic/category_mask'],
channels=1) channels=1)
...@@ -96,7 +96,7 @@ class Parser(maskrcnn_input.Parser): ...@@ -96,7 +96,7 @@ class Parser(maskrcnn_input.Parser):
segmentation_groundtruth_padded_size=None, segmentation_groundtruth_padded_size=None,
segmentation_ignore_label=255, segmentation_ignore_label=255,
panoptic_ignore_label=0, panoptic_ignore_label=0,
include_eval_masks=True, include_panoptic_masks=True,
dtype='float32'): dtype='float32'):
"""Initializes parameters for parsing annotations in the dataset. """Initializes parameters for parsing annotations in the dataset.
...@@ -138,7 +138,7 @@ class Parser(maskrcnn_input.Parser): ...@@ -138,7 +138,7 @@ class Parser(maskrcnn_input.Parser):
used for training and evaluation. used for training and evaluation.
panoptic_ignore_label: `int` the pixels with ignore label will not be panoptic_ignore_label: `int` the pixels with ignore label will not be
used by the PQ evaluator. used by the PQ evaluator.
include_eval_masks: `bool`, if True, category_mask and instance_mask will include_panoptic_masks: `bool`, if True, category_mask and instance_mask will
be parsed. Set this to true if PQ evaluator is enabled. be parsed. Set this to true if PQ evaluator is enabled.
dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}. dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
""" """
...@@ -172,7 +172,7 @@ class Parser(maskrcnn_input.Parser): ...@@ -172,7 +172,7 @@ class Parser(maskrcnn_input.Parser):
self._segmentation_groundtruth_padded_size = segmentation_groundtruth_padded_size self._segmentation_groundtruth_padded_size = segmentation_groundtruth_padded_size
self._segmentation_ignore_label = segmentation_ignore_label self._segmentation_ignore_label = segmentation_ignore_label
self._panoptic_ignore_label = panoptic_ignore_label self._panoptic_ignore_label = panoptic_ignore_label
self._include_eval_masks = include_eval_masks self._include_panoptic_masks= include_panoptic_masks
def _parse_train_data(self, data): def _parse_train_data(self, data):
"""Parses data for training. """Parses data for training.
...@@ -322,7 +322,7 @@ class Parser(maskrcnn_input.Parser): ...@@ -322,7 +322,7 @@ class Parser(maskrcnn_input.Parser):
'gt_segmentation_mask': segmentation_mask, 'gt_segmentation_mask': segmentation_mask,
'gt_segmentation_valid_mask': segmentation_valid_mask}) 'gt_segmentation_valid_mask': segmentation_valid_mask})
if self._include_eval_masks: if self._include_panoptic_masks:
panoptic_category_mask = _process_mask( panoptic_category_mask = _process_mask(
data['groundtruth_panoptic_category_mask'], data['groundtruth_panoptic_category_mask'],
self._panoptic_ignore_label, image_info) self._panoptic_ignore_label, image_info)
......
...@@ -122,7 +122,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -122,7 +122,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
decoder = panoptic_maskrcnn_input.TfExampleDecoder( decoder = panoptic_maskrcnn_input.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id, regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold, mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
include_eval_masks=decoder_cfg.include_eval_masks) include_panoptic_masks=decoder_cfg.include_panoptic_masks)
else: else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type)) raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
...@@ -150,7 +150,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -150,7 +150,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
.segmentation_groundtruth_padded_size, .segmentation_groundtruth_padded_size,
segmentation_ignore_label=params.parser.segmentation_ignore_label, segmentation_ignore_label=params.parser.segmentation_ignore_label,
panoptic_ignore_label=params.parser.panoptic_ignore_label, panoptic_ignore_label=params.parser.panoptic_ignore_label,
include_eval_masks=params.parser.include_eval_masks) include_panoptic_masks=params.parser.include_panoptic_masks)
reader = input_reader_factory.input_reader_generator( reader = input_reader_factory.input_reader_generator(
params, params,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment