Commit 471451cc authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 441585903
parent 860d3101
...@@ -59,6 +59,8 @@ class TfExampleDecoder(common.TfExampleDecoder): ...@@ -59,6 +59,8 @@ class TfExampleDecoder(common.TfExampleDecoder):
"""A simple TF Example decoder config.""" """A simple TF Example decoder config."""
# Setting this to true will enable decoding category_mask and instance_mask. # Setting this to true will enable decoding category_mask and instance_mask.
include_panoptic_masks: bool = True include_panoptic_masks: bool = True
panoptic_category_mask_key: str = 'image/panoptic/category_mask'
panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -24,23 +24,30 @@ from official.vision.ops import preprocess_ops ...@@ -24,23 +24,30 @@ from official.vision.ops import preprocess_ops
class TfExampleDecoder(tf_example_decoder.TfExampleDecoder): class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
"""Tensorflow Example proto decoder.""" """Tensorflow Example proto decoder."""
def __init__(self, regenerate_source_id, def __init__(
mask_binarize_threshold, include_panoptic_masks): self,
regenerate_source_id: bool,
mask_binarize_threshold: float,
include_panoptic_masks: bool,
panoptic_category_mask_key: str = 'image/panoptic/category_mask',
panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'):
super(TfExampleDecoder, self).__init__( super(TfExampleDecoder, self).__init__(
include_mask=True, include_mask=True,
regenerate_source_id=regenerate_source_id, regenerate_source_id=regenerate_source_id,
mask_binarize_threshold=None) mask_binarize_threshold=None)
self._include_panoptic_masks = include_panoptic_masks self._include_panoptic_masks = include_panoptic_masks
self._panoptic_category_mask_key = panoptic_category_mask_key
self._panoptic_instance_mask_key = panoptic_instance_mask_key
keys_to_features = { keys_to_features = {
'image/segmentation/class/encoded': 'image/segmentation/class/encoded':
tf.io.FixedLenFeature((), tf.string, default_value='')} tf.io.FixedLenFeature((), tf.string, default_value='')}
if include_panoptic_masks: if include_panoptic_masks:
keys_to_features.update({ keys_to_features.update({
'image/panoptic/category_mask': panoptic_category_mask_key:
tf.io.FixedLenFeature((), tf.string, default_value=''), tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/panoptic/instance_mask': panoptic_instance_mask_key:
tf.io.FixedLenFeature((), tf.string, default_value='')}) tf.io.FixedLenFeature((), tf.string, default_value='')})
self._segmentation_keys_to_features = keys_to_features self._segmentation_keys_to_features = keys_to_features
...@@ -56,10 +63,10 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder): ...@@ -56,10 +63,10 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
if self._include_panoptic_masks: if self._include_panoptic_masks:
category_mask = tf.io.decode_image( category_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/category_mask'], parsed_tensors[self._panoptic_category_mask_key],
channels=1) channels=1)
instance_mask = tf.io.decode_image( instance_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/instance_mask'], parsed_tensors[self._panoptic_instance_mask_key],
channels=1) channels=1)
category_mask.set_shape([None, None, 1]) category_mask.set_shape([None, None, 1])
instance_mask.set_shape([None, None, 1]) instance_mask.set_shape([None, None, 1])
......
...@@ -123,7 +123,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask): ...@@ -123,7 +123,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
decoder = panoptic_maskrcnn_input.TfExampleDecoder( decoder = panoptic_maskrcnn_input.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id, regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold, mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
include_panoptic_masks=decoder_cfg.include_panoptic_masks) include_panoptic_masks=decoder_cfg.include_panoptic_masks,
panoptic_category_mask_key=decoder_cfg.panoptic_category_mask_key,
panoptic_instance_mask_key=decoder_cfg.panoptic_instance_mask_key)
else: else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type)) raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment