Unverified commit 10ea7bdd authored by srihari-humbarwadi

added coco-panoptic groundtruths to tf_example

parent 4edd2a73
@@ -39,7 +39,6 @@ import numpy as np
 from pycocotools import mask
 import tensorflow as tf
-import multiprocessing as mp
 from official.vision.beta.data import tfrecord_lib

@@ -58,6 +57,10 @@ flags.DEFINE_string(
                     'annotations - boxes and instance masks.')
 flags.DEFINE_string('caption_annotations_file', '', 'File containing image '
                     'captions.')
+flags.DEFINE_string('panoptic_annotations_file', '', 'File containing panoptic '
+                    'annotations.')
+flags.DEFINE_string('panoptic_masks_dir', '', 'Directory containing panoptic '
+                    'masks.')
 flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
 flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')

@@ -66,6 +69,9 @@ FLAGS = flags.FLAGS
 logger = tf.get_logger()
 logger.setLevel(logging.INFO)
+_VOID_LABEL = 0
+_THING_CLASS_ID = 1
+_STUFF_CLASSES_OFFSET = 90

 def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
   """Encode a COCO mask segmentation as PNG string."""
@@ -74,7 +80,30 @@ def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
   if not is_crowd:
     binary_mask = np.amax(binary_mask, axis=2)
-  return tfrecord_lib.encode_binary_mask_as_png(binary_mask)
+  return tfrecord_lib.encode_mask_as_png(binary_mask)
+
+
+def coco_panoptic_segmentation_to_mask_png(segments_info, mask_path,
+                                           is_category_thing):
+  rgb_mask = tfrecord_lib.read_image(mask_path)
+  r, g, b = np.split(rgb_mask, 3, axis=-1)
+
+  # decode rgb encoded panoptic mask to get segments ids
+  # refer https://cocodataset.org/#format-data
+  segments_encoded_mask = (r + g * 256 + b * (256**2)).squeeze()
+
+  category_mask = np.ones_like(segments_encoded_mask) * _VOID_LABEL
+  for segment in segments_info:
+    segment_id = segment['id']
+    category_id = segment['category_id']
+    if is_category_thing[category_id]:
+      category_id = _THING_CLASS_ID
+    else:
+      category_id -= _STUFF_CLASSES_OFFSET
+    category_mask[segments_encoded_mask == segment_id] = category_id
+
+  return tfrecord_lib.encode_mask_as_png(category_mask)
+
+
 def coco_annotations_to_lists(bbox_annotations, id_to_name_map,
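For reference, the new helper follows the COCO panoptic PNG format referenced in the comment above: each pixel stores its segment id as R + G * 256 + B * 256**2, every "thing" segment is then collapsed to the single _THING_CLASS_ID, and "stuff" category ids are shifted down by _STUFF_CLASSES_OFFSET. A self-contained sketch of that decoding and remapping on a made-up 2x2 mask (the pixel values, segment ids, and category ids below are illustrative only):

import numpy as np

_VOID_LABEL = 0
_THING_CLASS_ID = 1
_STUFF_CLASSES_OFFSET = 90

# Made-up 2x2 RGB panoptic mask holding two segments.
rgb_mask = np.array([[[152, 101, 0], [152, 101, 0]],
                     [[153, 0, 0], [153, 0, 0]]], dtype=np.uint8)

# Same decoding as the diff; cast first so the arithmetic cannot wrap in uint8.
r, g, b = np.split(rgb_mask.astype(np.int32), 3, axis=-1)
segments_encoded_mask = (r + g * 256 + b * (256**2)).squeeze()
# -> [[26008 26008]
#     [  153   153]]

# Illustrative segments_info entries and isthing lookup for those two segments.
segments_info = [{'id': 26008, 'category_id': 1},    # a "thing" category
                 {'id': 153, 'category_id': 118}]    # a "stuff" category
is_category_thing = {1: True, 118: False}

category_mask = np.ones_like(segments_encoded_mask) * _VOID_LABEL
for segment in segments_info:
  category_id = segment['category_id']
  if is_category_thing[category_id]:
    category_id = _THING_CLASS_ID          # all things collapse to one class id
  else:
    category_id -= _STUFF_CLASSES_OFFSET   # stuff ids shifted down by the offset
  category_mask[segments_encoded_mask == segment['id']] = category_id

print(category_mask)
# [[ 1  1]
#  [28 28]]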
@@ -160,9 +189,12 @@ def encode_caption_annotations(caption_annotations):
 def create_tf_example(image,
                       image_dirs,
+                      panoptic_masks_dir=None,
                       bbox_annotations=None,
                       id_to_name_map=None,
                       caption_annotations=None,
+                      panoptic_annotation=None,
+                      is_category_thing=None,
                       include_masks=False):
   """Converts image and annotations to a tf.Example proto.
@@ -234,6 +266,17 @@ def create_tf_example(image,
     feature_dict.update(
         {'image/caption': tfrecord_lib.convert_to_feature(encoded_captions)})
+
+  if panoptic_annotation:
+    segments_info = panoptic_annotation['segments_info']
+    panoptic_mask_filename = os.path.join(
+        panoptic_masks_dir,
+        panoptic_annotation['file_name'])
+    encoded_panoptic_mask_png = coco_panoptic_segmentation_to_mask_png(
+        segments_info, panoptic_mask_filename, is_category_thing)
+    feature_dict.update(
+        {'image/segmentation/class/encoded': tfrecord_lib.convert_to_feature(
+            encoded_panoptic_mask_png)})
   example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
   return example, num_annotations_skipped
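On the consumer side, the new 'image/segmentation/class/encoded' feature can be decoded back with standard tf.io ops. A minimal sketch, assuming a serialized tf.Example produced by this script; the function name is made up, and dtype=tf.uint16 assumes PIL wrote the category mask as a 16-bit grayscale PNG (which it does for integer-typed numpy arrays); if the mask were written as 8-bit, the default uint8 dtype applies instead:

import tensorflow as tf


def parse_panoptic_category_mask(serialized_example):
  # Only the key added in this commit is parsed here; a real input pipeline
  # would also read the image, box and instance-mask features.
  features = tf.io.parse_single_example(
      serialized_example,
      {'image/segmentation/class/encoded':
           tf.io.FixedLenFeature((), tf.string, default_value='')})
  # Decode the PNG back into a (height, width, 1) map of per-pixel category ids.
  return tf.io.decode_png(
      features['image/segmentation/class/encoded'],
      channels=1, dtype=tf.uint16)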
@@ -287,6 +330,30 @@ def _load_caption_annotations(caption_annotations_file):
   return img_to_caption_annotation
+
+
+def _load_panoptic_annotations(panoptic_annotations_file):
+  with tf.io.gfile.GFile(panoptic_annotations_file, 'r') as fid:
+    panoptic_annotations = json.load(fid)
+
+  img_to_panoptic_annotation = dict()
+  logging.info('Building panoptic index.')
+  for annotation in panoptic_annotations['annotations']:
+    image_id = annotation['image_id']
+    img_to_panoptic_annotation[image_id] = annotation
+
+  is_category_thing = dict()
+  for category_info in panoptic_annotations['categories']:
+    is_category_thing[category_info['id']] = category_info['isthing'] == 1
+
+  missing_annotation_count = 0
+  images = panoptic_annotations['images']
+  for image in images:
+    image_id = image['id']
+    if image_id not in img_to_panoptic_annotation:
+      missing_annotation_count += 1
+  logging.info('%d images are missing panoptic annotations.',
+               missing_annotation_count)
+
+  return img_to_panoptic_annotation, is_category_thing

 def _load_images_info(images_info_file):
   with tf.io.gfile.GFile(images_info_file, 'r') as fid:
     info_dict = json.load(fid)
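_load_panoptic_annotations expects the stock COCO panoptic JSON layout. An abbreviated, made-up example limited to the keys this script actually reads (ids, file names, and the stuff category name are illustrative):

panoptic_annotations = {
    'images': [
        {'id': 9, 'file_name': '000000000009.jpg'},
    ],
    'annotations': [
        {'image_id': 9,
         'file_name': '000000000009.png',  # RGB-encoded panoptic mask
         'segments_info': [
             {'id': 3226956, 'category_id': 1},    # a thing segment
             {'id': 6979964, 'category_id': 118},  # a stuff segment
         ]},
    ],
    'categories': [
        {'id': 1, 'name': 'person', 'isthing': 1},
        {'id': 118, 'name': 'example-stuff', 'isthing': 0},
    ],
}

The loader returns an image_id -> annotation dict plus a category_id -> isthing lookup; create_tf_example then joins them with panoptic_masks_dir to locate, remap, and encode the mask for each image.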
@@ -294,11 +361,14 @@ def _load_images_info(images_info_file):
 def generate_annotations(images, image_dirs,
+                         panoptic_masks_dir=None,
                          img_to_obj_annotation=None,
-                         img_to_caption_annotation=None, id_to_name_map=None,
+                         img_to_caption_annotation=None,
+                         img_to_panoptic_annotation=None,
+                         is_category_thing=None,
+                         id_to_name_map=None,
                          include_masks=False):
   """Generator for COCO annotations."""
   for image in images:
     object_annotation = (img_to_obj_annotation.get(image['id'], None) if
                          img_to_obj_annotation else None)

@@ -306,8 +376,11 @@ def generate_annotations(images, image_dirs,
     caption_annotaion = (img_to_caption_annotation.get(image['id'], None) if
                          img_to_caption_annotation else None)
-    yield (image, image_dirs, object_annotation, id_to_name_map,
-           caption_annotaion, include_masks)
+    panoptic_annotation = (img_to_panoptic_annotation.get(image['id'], None) if
+                           img_to_panoptic_annotation else None)
+
+    yield (image, image_dirs, panoptic_masks_dir, object_annotation,
+           id_to_name_map, caption_annotaion, panoptic_annotation,
+           is_category_thing, include_masks)

 def _create_tf_record_from_coco_annotations(images_info_file,
@@ -316,6 +389,8 @@ def _create_tf_record_from_coco_annotations(images_info_file,
                                             num_shards,
                                             object_annotations_file=None,
                                             caption_annotations_file=None,
+                                            panoptic_masks_dir=None,
+                                            panoptic_annotations_file=None,
                                             include_masks=False):
   """Loads COCO annotation json files and converts to tf.Record format.

@@ -348,9 +423,13 @@ def _create_tf_record_from_coco_annotations(images_info_file,
   if caption_annotations_file:
     img_to_caption_annotation = (
         _load_caption_annotations(caption_annotations_file))
+  if panoptic_annotations_file:
+    img_to_panoptic_annotation, is_category_thing = (
+        _load_panoptic_annotations(panoptic_annotations_file))

   coco_annotations_iter = generate_annotations(
-      images, image_dirs, img_to_obj_annotation, img_to_caption_annotation,
+      images, image_dirs, panoptic_masks_dir, img_to_obj_annotation,
+      img_to_caption_annotation, img_to_panoptic_annotation, is_category_thing,
       id_to_name_map=id_to_name_map, include_masks=include_masks)

   num_skipped = tfrecord_lib.write_tf_record_dataset(

@@ -380,6 +459,8 @@ def main(_):
       FLAGS.num_shards,
       FLAGS.object_annotations_file,
       FLAGS.caption_annotations_file,
+      FLAGS.panoptic_masks_dir,
+      FLAGS.panoptic_annotations_file,
       FLAGS.include_masks)
......
@@ -99,9 +99,12 @@ def image_info_to_feature_dict(height, width, filename, image_id,
       'image/format': convert_to_feature(encoded_format.encode('utf8')),
   }

+def read_image(image_path):
+  pil_image = Image.open(image_path)
+  return np.asarray(pil_image)
+
+
-def encode_binary_mask_as_png(binary_mask):
-  pil_image = Image.fromarray(binary_mask)
+def encode_mask_as_png(mask):
+  pil_image = Image.fromarray(mask)
   output_io = io.BytesIO()
   pil_image.save(output_io, format='PNG')
   return output_io.getvalue()
......
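The tfrecord_lib hunk above adds a read_image helper for loading the RGB panoptic mask PNGs and renames encode_binary_mask_as_png to encode_mask_as_png, since the panoptic category mask is no longer binary. A small round-trip sketch using local copies of the two helpers (the temporary file is only for illustration):

import io
import tempfile

import numpy as np
from PIL import Image


def read_image(image_path):
  # Local copy of the new tfrecord_lib.read_image helper.
  pil_image = Image.open(image_path)
  return np.asarray(pil_image)


def encode_mask_as_png(mask):
  # Local copy of the renamed tfrecord_lib.encode_mask_as_png helper.
  pil_image = Image.fromarray(mask)
  output_io = io.BytesIO()
  pil_image.save(output_io, format='PNG')
  return output_io.getvalue()


# A multi-class (non-binary) category mask survives the PNG round trip.
mask = np.zeros((4, 4), dtype=np.uint8)
mask[:2, :2] = 28  # e.g. a remapped stuff class id
png_bytes = encode_mask_as_png(mask)

with tempfile.NamedTemporaryFile(suffix='.png') as tmp:
  tmp.write(png_bytes)
  tmp.flush()
  assert np.array_equal(read_image(tmp.name), mask)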