Unverified Commit 895e68a0 authored by srihari-humbarwadi's avatar srihari-humbarwadi
Browse files

generate category_mask and instance_mask for PQ evaluator

parent 6a3d7567
......@@ -61,6 +61,10 @@ flags.DEFINE_string('panoptic_annotations_file', '', 'File containing panoptic '
'annotations.')
flags.DEFINE_string('panoptic_masks_dir', '', 'Directory containing panoptic '
'masks.')
# Opt-in flag: when set, the category and instance masks used by the PQ
# evaluator are written alongside the semantic segmentation mask.
flags.DEFINE_boolean(
    'include_panoptic_eval_masks', False,
    'Whether to include category and instance masks in the result. These are '
    'required to run the PQ evaluator. default: False.')
flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')
......@@ -70,6 +74,7 @@ logger = tf.get_logger()
logger.setLevel(logging.INFO)
# Fill value the semantic/category masks are initialised with before any
# segment is painted into them.
_VOID_LABEL = 0
# Instance id used as the initial fill of the instance mask and assigned to
# segments whose category is not a "thing".
_VOID_INSTANCE_ID = 0
# Single class id written to the semantic segmentation mask for every segment
# whose category is a "thing".
_THING_CLASS_ID = 1
# Subtracted from the category id of non-thing segments before it is written
# to the semantic segmentation mask.
_STUFF_CLASSES_OFFSET = 90
......@@ -83,27 +88,50 @@ def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
return tfrecord_lib.encode_mask_as_png(binary_mask)
def generate_coco_panoptics_masks(segments_info, mask_path,
                                  include_panoptic_eval_masks,
                                  is_category_thing):
  """Builds PNG-encoded segmentation masks from a COCO panoptic mask image.

  Args:
    segments_info: Iterable of segment dicts; each must provide the keys
      'id' and 'category_id'.
    mask_path: Path of the RGB-encoded panoptic mask, read through
      `tfrecord_lib.read_image`.
    include_panoptic_eval_masks: Whether to also build 'category_mask' and
      'instance_mask' in addition to the semantic segmentation mask.
    is_category_thing: Lookup indexed by category id whose value is truthy
      when the category is a "thing".

  Returns:
    A dict that always holds 'semantic_segmentation_mask' and, when
    `include_panoptic_eval_masks` is set, 'category_mask' and
    'instance_mask'. Each value is the result of
    `tfrecord_lib.encode_mask_as_png`.
  """
  rgb_mask = tfrecord_lib.read_image(mask_path)
  # Widen the channels before decoding. With 8-bit channels, `g * 256` relies
  # on value-based promotion and raises OverflowError under NumPy 2 (NEP 50).
  # uint32 is the dtype the legacy promotion produced for uint8 input.
  # NOTE(review): assumes read_image returns an integer (H, W, 3) array --
  # confirm against tfrecord_lib.
  r, g, b = np.split(np.asarray(rgb_mask, dtype=np.uint32), 3, axis=-1)

  # decode rgb encoded panoptic mask to get segments ids
  # refer https://cocodataset.org/#format-data
  segments_encoded_mask = (r + g * 256 + b * (256**2)).squeeze()

  semantic_segmentation_mask = np.ones_like(segments_encoded_mask) * _VOID_LABEL
  if include_panoptic_eval_masks:
    category_mask = np.ones_like(segments_encoded_mask) * _VOID_LABEL
    instance_mask = np.ones_like(segments_encoded_mask) * _VOID_INSTANCE_ID

  for idx, segment in enumerate(segments_info):
    segment_id = segment['id']
    category_id = segment['category_id']

    if is_category_thing[category_id]:
      # Every thing category shares one semantic class; instances are told
      # apart by their 1-based position in `segments_info`.
      encoded_category_id = _THING_CLASS_ID
      instance_id = idx + 1
    else:
      encoded_category_id = category_id - _STUFF_CLASSES_OFFSET
      instance_id = _VOID_INSTANCE_ID

    segment_mask = (segments_encoded_mask == segment_id)
    semantic_segmentation_mask[segment_mask] = encoded_category_id

    if include_panoptic_eval_masks:
      # The category mask keeps the original, un-remapped category id.
      category_mask[segment_mask] = category_id
      instance_mask[segment_mask] = instance_id

  outputs = {
      'semantic_segmentation_mask': tfrecord_lib.encode_mask_as_png(
          semantic_segmentation_mask)
  }

  if include_panoptic_eval_masks:
    outputs.update({
        'category_mask': tfrecord_lib.encode_mask_as_png(category_mask),
        'instance_mask': tfrecord_lib.encode_mask_as_png(instance_mask)
    })
  return outputs
def coco_annotations_to_lists(bbox_annotations, id_to_name_map,
......@@ -195,6 +223,7 @@ def create_tf_example(image,
caption_annotations=None,
panoptic_annotation=None,
is_category_thing=None,
include_panoptic_eval_masks=False,
include_masks=False):
"""Converts image and annotations to a tf.Example proto.
......@@ -271,11 +300,20 @@ def create_tf_example(image,
panoptic_mask_filename = os.path.join(
panoptic_masks_dir,
panoptic_annotation['file_name'])
encoded_panoptic_mask_png = coco_panoptic_segmentation_to_mask_png(
segments_info, panoptic_mask_filename, is_category_thing)
encoded_panoptic_masks = generate_coco_panoptics_masks(
segments_info, panoptic_mask_filename, include_panoptic_eval_masks,
is_category_thing)
feature_dict.update(
{'image/segmentation/class/encoded': tfrecord_lib.convert_to_feature(
encoded_panoptic_mask_png)})
encoded_panoptic_masks['semantic_segmentation_mask'])})
if include_panoptic_eval_masks:
feature_dict.update({
'image/panoptic/category_mask': tfrecord_lib.convert_to_feature(
encoded_panoptic_masks['category_mask']),
'image/panoptic/instance_mask': tfrecord_lib.convert_to_feature(
encoded_panoptic_masks['instance_mask'])
})
example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
return example, num_annotations_skipped
......@@ -368,6 +406,7 @@ def generate_annotations(images, image_dirs,
img_to_panoptic_annotation=None,
is_category_thing=None,
id_to_name_map=None,
include_panoptic_eval_masks=False,
include_masks=False):
"""Generator for COCO annotations."""
for image in images:
......@@ -381,7 +420,7 @@ def generate_annotations(images, image_dirs,
img_to_panoptic_annotation else None)
yield (image, image_dirs, panoptic_masks_dir, object_annotation,
id_to_name_map, caption_annotaion, panoptic_annotation,
is_category_thing, include_masks)
is_category_thing, include_panoptic_eval_masks, include_masks)
def _create_tf_record_from_coco_annotations(images_info_file,
......@@ -392,6 +431,7 @@ def _create_tf_record_from_coco_annotations(images_info_file,
caption_annotations_file=None,
panoptic_masks_dir=None,
panoptic_annotations_file=None,
include_panoptic_eval_masks=False,
include_masks=False):
"""Loads COCO annotation json files and converts to tf.Record format.
......@@ -407,6 +447,10 @@ def _create_tf_record_from_coco_annotations(images_info_file,
num_shards: Number of output files to create.
object_annotations_file: JSON file containing bounding box annotations.
caption_annotations_file: JSON file containing caption annotations.
panoptic_masks_dir: Directory containing panoptic masks.
panoptic_annotations_file: JSON file containing panoptic annotations.
include_panoptic_eval_masks: Whether to include 'category_mask'
and 'instance_mask', which are required by the panoptic quality evaluator.
include_masks: Whether to include instance segmentations masks
(PNG encoded) in the result. default: False.
"""
......@@ -429,9 +473,16 @@ def _create_tf_record_from_coco_annotations(images_info_file,
_load_panoptic_annotations(panoptic_annotations_file))
coco_annotations_iter = generate_annotations(
images, image_dirs, panoptic_masks_dir, img_to_obj_annotation,
img_to_caption_annotation, img_to_panoptic_annotation, is_category_thing,
id_to_name_map=id_to_name_map, include_masks=include_masks)
images=images,
image_dirs=image_dirs,
panoptic_masks_dir=panoptic_masks_dir,
img_to_obj_annotation=img_to_obj_annotation,
img_to_caption_annotation=img_to_caption_annotation,
img_to_panoptic_annotation=img_to_panoptic_annotation,
is_category_thing=is_category_thing,
id_to_name_map=id_to_name_map,
include_panoptic_eval_masks=include_panoptic_eval_masks,
include_masks=include_masks)
num_skipped = tfrecord_lib.write_tf_record_dataset(
output_path, coco_annotations_iter, create_tf_example, num_shards)
......@@ -462,6 +513,7 @@ def main(_):
FLAGS.caption_annotations_file,
FLAGS.panoptic_masks_dir,
FLAGS.panoptic_annotations_file,
FLAGS.include_panoptic_eval_masks,
FLAGS.include_masks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment