Unverified Commit 485f4618 authored by srihari-humbarwadi

parse and decode category_mask and instance_mask

parent 895e68a0
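For context, the decoder below expects the two panoptic masks to arrive as encoded image bytes in the tf.Example. A minimal writer-side sketch, assuming PNG encoding and showing only the keys this commit touches (the helper names are illustrative, not part of the change):

import tensorflow as tf

def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def make_panoptic_example(semantic_png, category_png, instance_png):
  # Only the segmentation/panoptic keys are shown; a real example also carries
  # the usual image and detection features consumed by the base decoder.
  features = {
      'image/segmentation/class/encoded': _bytes_feature(semantic_png),
      'image/panoptic/category_mask': _bytes_feature(category_png),
      'image/panoptic/instance_mask': _bytes_feature(instance_png),
  }
  return tf.train.Example(features=tf.train.Features(feature=features))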
@@ -24,25 +24,51 @@ from official.vision.beta.ops import preprocess_ops
 class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
   """Tensorflow Example proto decoder."""
 
-  def __init__(self, regenerate_source_id, mask_binarize_threshold):
+  def __init__(self, regenerate_source_id,
+               mask_binarize_threshold, include_eval_masks):
     super(TfExampleDecoder, self).__init__(
         include_mask=True,
         regenerate_source_id=regenerate_source_id,
         mask_binarize_threshold=None)
-    self._segmentation_keys_to_features = {
+    self._include_eval_masks = include_eval_masks
+    keys_to_features = {
         'image/segmentation/class/encoded':
-            tf.io.FixedLenFeature((), tf.string, default_value='')
-    }
+            tf.io.FixedLenFeature((), tf.string, default_value='')}
+
+    if include_eval_masks:
+      keys_to_features.update({
+          'image/panoptic/category_mask':
+              tf.io.FixedLenFeature((), tf.string, default_value=''),
+          'image/panoptic/instance_mask':
+              tf.io.FixedLenFeature((), tf.string, default_value='')})
+    self._segmentation_keys_to_features = keys_to_features
 
   def decode(self, serialized_example):
     decoded_tensors = super(TfExampleDecoder, self).decode(serialized_example)
-    segmentation_parsed_tensors = tf.io.parse_single_example(
+    parsed_tensors = tf.io.parse_single_example(
         serialized_example, self._segmentation_keys_to_features)
     segmentation_mask = tf.io.decode_image(
-        segmentation_parsed_tensors['image/segmentation/class/encoded'],
+        parsed_tensors['image/segmentation/class/encoded'],
         channels=1)
     segmentation_mask.set_shape([None, None, 1])
     decoded_tensors.update({'groundtruth_segmentation_mask': segmentation_mask})
+
+    if self._include_eval_masks:
+      category_mask = tf.io.decode_image(
+          parsed_tensors['image/panoptic/category_mask'],
+          channels=1)
+      instance_mask = tf.io.decode_image(
+          parsed_tensors['image/panoptic/instance_mask'],
+          channels=1)
+      category_mask.set_shape([None, None, 1])
+      instance_mask.set_shape([None, None, 1])
+      decoded_tensors.update({
+          'groundtruth_panoptic_category_mask':
+              category_mask,
+          'groundtruth_panoptic_instance_mask':
+              instance_mask})
     return decoded_tensors
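A hedged usage sketch of the updated decoder (the TFRecord path is a placeholder): with include_eval_masks=True, the decoded dictionary carries the two panoptic tensors alongside the semantic segmentation mask.

decoder = TfExampleDecoder(
    regenerate_source_id=False,
    mask_binarize_threshold=None,
    include_eval_masks=True)

dataset = tf.data.TFRecordDataset('/path/to/panoptic_val.tfrecord')  # placeholder path
for serialized_example in dataset.take(1):
  decoded = decoder.decode(serialized_example)
  # Each mask is a [height, width, 1] tensor decoded from the encoded bytes.
  category_mask = decoded['groundtruth_panoptic_category_mask']
  instance_mask = decoded['groundtruth_panoptic_instance_mask']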
@@ -69,6 +95,8 @@ class Parser(maskrcnn_input.Parser):
                segmentation_resize_eval_groundtruth=True,
                segmentation_groundtruth_padded_size=None,
                segmentation_ignore_label=255,
+               panoptic_ignore_label=0,
+               include_eval_masks=True,
                dtype='float32'):
     """Initializes parameters for parsing annotations in the dataset.
@@ -106,8 +134,12 @@ class Parser(maskrcnn_input.Parser):
       segmentation_groundtruth_padded_size: `Tensor` or `list` for [height,
         width]. When resize_eval_groundtruth is set to False, the groundtruth
         masks are padded to this size.
-      segmentation_ignore_label: `int` the pixel with ignore label will not used
-        for training and evaluation.
+      segmentation_ignore_label: `int` pixels with the ignore label will not be
+        used for training and evaluation.
+      panoptic_ignore_label: `int` pixels with the ignore label will not be
+        used by the PQ evaluator.
+      include_eval_masks: `bool`, if True, the category_mask and instance_mask
+        will be parsed. Set this to True if the PQ evaluator is enabled.
       dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
     """
     super(Parser, self).__init__(
@@ -139,6 +171,8 @@ class Parser(maskrcnn_input.Parser):
           'specified when segmentation_resize_eval_groundtruth is False.')
     self._segmentation_groundtruth_padded_size = segmentation_groundtruth_padded_size
     self._segmentation_ignore_label = segmentation_ignore_label
+    self._panoptic_ignore_label = panoptic_ignore_label
+    self._include_eval_masks = include_eval_masks
 
   def _parse_train_data(self, data):
     """Parses data for training.
@@ -250,39 +284,54 @@ class Parser(maskrcnn_input.Parser):
         shape [height_l, width_l, 4] representing anchor boxes at each
           level.
     """
-    segmentation_mask = tf.cast(
-        data['groundtruth_segmentation_mask'], tf.float32)
-    segmentation_mask = tf.reshape(
-        segmentation_mask, shape=[1, data['height'], data['width'], 1])
-    segmentation_mask += 1
+    def _process_mask(mask, ignore_label, image_info):
+      mask = tf.cast(mask, dtype=tf.float32)
+      mask = tf.reshape(mask, shape=[1, data['height'], data['width'], 1])
+      mask += 1
+
+      if self._segmentation_resize_eval_groundtruth:
+        # Resizes eval masks to match input image sizes. In that case, mean IoU
+        # is computed on output_size not the original size of the images.
+        image_scale = image_info[2, :]
+        offset = image_info[3, :]
+        mask = preprocess_ops.resize_and_crop_masks(
+            mask, image_scale, self._output_size, offset)
+      else:
+        mask = tf.image.pad_to_bounding_box(
+            mask, 0, 0,
+            self._segmentation_groundtruth_padded_size[0],
+            self._segmentation_groundtruth_padded_size[1])
+      mask -= 1
+      # Assign ignore label to the padded region.
+      mask = tf.where(
+          tf.equal(mask, -1),
+          ignore_label * tf.ones_like(mask),
+          mask)
+      mask = tf.squeeze(mask, axis=0)
+      return mask
 
     image, labels = super(Parser, self)._parse_eval_data(data)
+    image_info = labels['image_info']
 
-    if self._segmentation_resize_eval_groundtruth:
-      # Resizes eval masks to match input image sizes. In that case, mean IoU
-      # is computed on output_size not the original size of the images.
-      image_info = labels['image_info']
-      image_scale = image_info[2, :]
-      offset = image_info[3, :]
-      segmentation_mask = preprocess_ops.resize_and_crop_masks(
-          segmentation_mask, image_scale, self._output_size, offset)
-    else:
-      segmentation_mask = tf.image.pad_to_bounding_box(
-          segmentation_mask, 0, 0,
-          self._segmentation_groundtruth_padded_size[0],
-          self._segmentation_groundtruth_padded_size[1])
-    segmentation_mask -= 1
-    # Assign ignore label to the padded region.
-    segmentation_mask = tf.where(
-        tf.equal(segmentation_mask, -1),
-        self._segmentation_ignore_label * tf.ones_like(segmentation_mask),
-        segmentation_mask)
-    segmentation_mask = tf.squeeze(segmentation_mask, axis=0)
+    segmentation_mask = _process_mask(
+        data['groundtruth_segmentation_mask'],
+        self._segmentation_ignore_label, image_info)
     segmentation_valid_mask = tf.not_equal(
         segmentation_mask, self._segmentation_ignore_label)
+
     labels['groundtruths'].update({
         'gt_segmentation_mask': segmentation_mask,
         'gt_segmentation_valid_mask': segmentation_valid_mask})
+
+    if self._include_eval_masks:
+      panoptic_category_mask = _process_mask(
+          data['groundtruth_panoptic_category_mask'],
+          self._panoptic_ignore_label, image_info)
+      panoptic_instance_mask = _process_mask(
+          data['groundtruth_panoptic_instance_mask'],
+          self._panoptic_ignore_label, image_info)
+      labels['groundtruths'].update({
+          'gt_panoptic_category_mask': panoptic_category_mask,
+          'gt_panoptic_instance_mask': panoptic_instance_mask})
+
     return image, labels
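The shift-by-one in _process_mask is what lets the padded region receive the ignore label without clobbering genuine label-0 pixels. A small standalone numeric sketch of that trick (values are illustrative):

import tensorflow as tf

mask = tf.constant([[0., 7.]])         # real labels, including a legitimate 0
mask = tf.reshape(mask, [1, 1, 2, 1])  # [batch, height, width, channels]
mask += 1                              # labels become 1 and 8

# Pad to width 4; pad_to_bounding_box fills the two new pixels with 0.
padded = tf.image.pad_to_bounding_box(mask, 0, 0, 1, 4)

padded -= 1                            # real labels back to 0 and 7, padding -> -1
ignore_label = 255.0
result = tf.where(
    tf.equal(padded, -1),
    ignore_label * tf.ones_like(padded),
    padded)
print(tf.squeeze(result).numpy())      # [  0.   7. 255. 255.]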