Commit e453835a authored by A. Unique TensorFlower

Merge pull request #10225 from srihari-humbarwadi:coco_tfrecords

PiperOrigin-RevId: 396682327
parents fec0338f 983ffd16
@@ -58,6 +58,14 @@ flags.DEFINE_string(
    'annotations - boxes and instance masks.')
flags.DEFINE_string('caption_annotations_file', '', 'File containing image '
                    'captions.')
flags.DEFINE_string('panoptic_annotations_file', '', 'File containing panoptic '
                    'annotations.')
flags.DEFINE_string('panoptic_masks_dir', '',
                    'Directory containing panoptic mask annotations.')
flags.DEFINE_boolean(
    'include_panoptic_masks', False, 'Whether to include category and '
    'instance masks in the result. These are required to run the PQ '
    'evaluator. default: False.')
flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')

@@ -66,6 +74,11 @@ FLAGS = flags.FLAGS
logger = tf.get_logger()
logger.setLevel(logging.INFO)
# Label values reserved for void (unlabeled) pixels in the generated masks.
_VOID_LABEL = 0
_VOID_INSTANCE_ID = 0
# All "thing" classes share one semantic class id; "stuff" category ids, which
# start above 90 in COCO panoptic annotations, are shifted down by this offset.
_THING_CLASS_ID = 1
_STUFF_CLASSES_OFFSET = 90
def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
  """Encode a COCO mask segmentation as PNG string."""

@@ -74,12 +87,79 @@ def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
  if not is_crowd:
    binary_mask = np.amax(binary_mask, axis=2)
  return tfrecord_lib.encode_mask_as_png(binary_mask)

def generate_coco_panoptics_masks(segments_info, mask_path,
                                  include_panoptic_masks,
                                  is_category_thing):
  """Creates masks for the panoptic segmentation task.

  Args:
    segments_info: a list of dicts, where each dict has keys: [u'id',
      u'category_id', u'area', u'bbox', u'iscrowd'], detailing information for
      each segment in the panoptic mask.
    mask_path: path to the panoptic mask.
    include_panoptic_masks: bool, when set to True, category and instance
      masks are included in the outputs. Set this to True when using the
      Panoptic Quality evaluator.
    is_category_thing: a dict with category ids as keys and 0/1 as values,
      indicating "stuff" and "thing" classes respectively.

  Returns:
    A dict with keys: [u'semantic_segmentation_mask', u'category_mask',
    u'instance_mask']. The dict contains 'category_mask' and 'instance_mask'
    only if `include_panoptic_masks` is set to True.
  """
  rgb_mask = tfrecord_lib.read_image(mask_path)
  r, g, b = np.split(rgb_mask, 3, axis=-1)

  # Decode the RGB-encoded panoptic mask to get segment ids.
  # See https://cocodataset.org/#format-data for the encoding scheme.
  segments_encoded_mask = (r + g * 256 + b * (256**2)).squeeze()

  semantic_segmentation_mask = np.ones_like(
      segments_encoded_mask, dtype=np.uint8) * _VOID_LABEL
  if include_panoptic_masks:
    category_mask = np.ones_like(
        segments_encoded_mask, dtype=np.uint8) * _VOID_LABEL
    instance_mask = np.ones_like(
        segments_encoded_mask, dtype=np.uint8) * _VOID_INSTANCE_ID

  for idx, segment in enumerate(segments_info):
    segment_id = segment['id']
    category_id = segment['category_id']

    if is_category_thing[category_id]:
      encoded_category_id = _THING_CLASS_ID
      instance_id = idx + 1
    else:
      encoded_category_id = category_id - _STUFF_CLASSES_OFFSET
      instance_id = _VOID_INSTANCE_ID

    segment_mask = (segments_encoded_mask == segment_id)
    semantic_segmentation_mask[segment_mask] = encoded_category_id

    if include_panoptic_masks:
      category_mask[segment_mask] = category_id
      instance_mask[segment_mask] = instance_id

  outputs = {
      'semantic_segmentation_mask': tfrecord_lib.encode_mask_as_png(
          semantic_segmentation_mask)
  }

  if include_panoptic_masks:
    outputs.update({
        'category_mask': tfrecord_lib.encode_mask_as_png(category_mask),
        'instance_mask': tfrecord_lib.encode_mask_as_png(instance_mask)
    })
  return outputs
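
The RGB mask decoded above follows the COCO panoptic format, which packs each segment id into the three color channels as `id = R + 256 * G + 256**2 * B`. A minimal round-trip sketch of that encoding (the segment id here is made up):

```python
import numpy as np

segment_id = 7013  # hypothetical id, as stored in segments_info[i]['id']
# Encode: split the id into R, G, B channels, least-significant byte first.
rgb = np.array([segment_id % 256,
                (segment_id // 256) % 256,
                (segment_id // 256**2) % 256], dtype=np.uint8)
# Decode: mirrors the expression used in generate_coco_panoptics_masks().
r, g, b = rgb.astype(np.int64)
assert r + g * 256 + b * (256**2) == segment_id
```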

def coco_annotations_to_lists(bbox_annotations, id_to_name_map,
                              image_height, image_width, include_masks):
  """Converts COCO annotations to feature lists."""
  data = dict((k, list()) for k in
              ['xmin', 'xmax', 'ymin', 'ymax', 'is_crowd',

@@ -160,9 +240,13 @@ def encode_caption_annotations(caption_annotations):
def create_tf_example(image,
                      image_dirs,
                      panoptic_masks_dir=None,
                      bbox_annotations=None,
                      id_to_name_map=None,
                      caption_annotations=None,
                      panoptic_annotation=None,
                      is_category_thing=None,
                      include_panoptic_masks=False,
                      include_masks=False):
  """Converts image and annotations to a tf.Example proto.

@@ -170,6 +254,7 @@ def create_tf_example(image,
    image: dict with keys: [u'license', u'file_name', u'coco_url', u'height',
      u'width', u'date_captured', u'flickr_url', u'id']
    image_dirs: list of directories containing the image files.
    panoptic_masks_dir: `str`, path to the directory containing panoptic masks.
    bbox_annotations:
      list of dicts with keys: [u'segmentation', u'area', u'iscrowd',
      u'image_id', u'bbox', u'category_id', u'id'] Notice that bounding box

@@ -182,6 +267,11 @@ def create_tf_example(image,
    id_to_name_map: a dict mapping category IDs to string names.
    caption_annotations:
      list of dict with keys: [u'id', u'image_id', u'str'].
    panoptic_annotation: dict with keys: [u'image_id', u'file_name',
      u'segments_info'], where the value for segments_info is a list of dicts,
      each containing information for a single segment in the mask.
    is_category_thing: a dict with category ids as keys and 0/1 as values,
      indicating "stuff" and "thing" classes respectively.
    include_panoptic_masks: `bool`, whether to include panoptic category and
      instance masks.
    include_masks: Whether to include instance segmentation masks
      (PNG encoded) in the result. default: False.
@@ -234,6 +324,26 @@ def create_tf_example(image,
    feature_dict.update(
        {'image/caption': tfrecord_lib.convert_to_feature(encoded_captions)})

  if panoptic_annotation:
    segments_info = panoptic_annotation['segments_info']
    panoptic_mask_filename = os.path.join(
        panoptic_masks_dir,
        panoptic_annotation['file_name'])
    encoded_panoptic_masks = generate_coco_panoptics_masks(
        segments_info, panoptic_mask_filename, include_panoptic_masks,
        is_category_thing)
    feature_dict.update(
        {'image/segmentation/class/encoded': tfrecord_lib.convert_to_feature(
            encoded_panoptic_masks['semantic_segmentation_mask'])})

    if include_panoptic_masks:
      feature_dict.update({
          'image/panoptic/category_mask': tfrecord_lib.convert_to_feature(
              encoded_panoptic_masks['category_mask']),
          'image/panoptic/instance_mask': tfrecord_lib.convert_to_feature(
              encoded_panoptic_masks['instance_mask'])
      })

  example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
  return example, num_annotations_skipped

@@ -287,6 +397,33 @@ def _load_caption_annotations(caption_annotations_file):
  return img_to_caption_annotation

def _load_panoptic_annotations(panoptic_annotations_file):
  """Loads panoptic annotations from file."""
  with tf.io.gfile.GFile(panoptic_annotations_file, 'r') as fid:
    panoptic_annotations = json.load(fid)

  img_to_panoptic_annotation = dict()
  logging.info('Building panoptic index.')
  for annotation in panoptic_annotations['annotations']:
    image_id = annotation['image_id']
    img_to_panoptic_annotation[image_id] = annotation

  is_category_thing = dict()
  for category_info in panoptic_annotations['categories']:
    is_category_thing[category_info['id']] = category_info['isthing'] == 1

  missing_annotation_count = 0
  images = panoptic_annotations['images']
  for image in images:
    image_id = image['id']
    if image_id not in img_to_panoptic_annotation:
      missing_annotation_count += 1
  logging.info(
      '%d images are missing panoptic annotations.', missing_annotation_count)

  return img_to_panoptic_annotation, is_category_thing
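
For reference, the loader above expects the standard COCO panoptic JSON layout; a minimal illustrative example (all ids, names, and file names here are made up):

```python
panoptic_annotations = {
    'images': [{'id': 9, 'height': 480, 'width': 640,
                'file_name': '000000000009.jpg'}],
    'annotations': [{
        'image_id': 9,
        'file_name': '000000000009.png',  # RGB-encoded panoptic mask
        'segments_info': [
            {'id': 3226956, 'category_id': 1, 'iscrowd': 0,
             'bbox': [50, 60, 100, 200], 'area': 4567},
        ],
    }],
    # 'isthing' distinguishes countable "thing" classes from "stuff".
    'categories': [{'id': 1, 'name': 'person', 'isthing': 1},
                   {'id': 184, 'name': 'grass', 'isthing': 0}],
}
```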

def _load_images_info(images_info_file):
  with tf.io.gfile.GFile(images_info_file, 'r') as fid:
    info_dict = json.load(fid)

@@ -294,11 +431,15 @@ def _load_images_info(images_info_file):

def generate_annotations(images, image_dirs,
                         panoptic_masks_dir=None,
                         img_to_obj_annotation=None,
                         img_to_caption_annotation=None,
                         img_to_panoptic_annotation=None,
                         is_category_thing=None,
                         id_to_name_map=None,
                         include_panoptic_masks=False,
                         include_masks=False):
  """Generator for COCO annotations."""
  for image in images:
    object_annotation = (img_to_obj_annotation.get(image['id'], None) if
                         img_to_obj_annotation else None)
@@ -306,8 +447,11 @@ def generate_annotations(images, image_dirs,
    caption_annotation = (img_to_caption_annotation.get(image['id'], None) if
                          img_to_caption_annotation else None)

    panoptic_annotation = (img_to_panoptic_annotation.get(image['id'], None) if
                           img_to_panoptic_annotation else None)

    yield (image, image_dirs, panoptic_masks_dir, object_annotation,
           id_to_name_map, caption_annotation, panoptic_annotation,
           is_category_thing, include_panoptic_masks, include_masks)

def _create_tf_record_from_coco_annotations(images_info_file,
@@ -316,6 +460,9 @@ def _create_tf_record_from_coco_annotations(images_info_file,
                                            num_shards,
                                            object_annotations_file=None,
                                            caption_annotations_file=None,
                                            panoptic_masks_dir=None,
                                            panoptic_annotations_file=None,
                                            include_panoptic_masks=False,
                                            include_masks=False):
  """Loads COCO annotation json files and converts to tf.Record format.

@@ -331,6 +478,10 @@ def _create_tf_record_from_coco_annotations(images_info_file,
    num_shards: Number of output files to create.
    object_annotations_file: JSON file containing bounding box annotations.
    caption_annotations_file: JSON file containing caption annotations.
    panoptic_masks_dir: Directory containing panoptic masks.
    panoptic_annotations_file: JSON file containing panoptic annotations.
    include_panoptic_masks: Whether to include 'category_mask' and
      'instance_mask', which are required by the panoptic quality evaluator.
    include_masks: Whether to include instance segmentation masks
      (PNG encoded) in the result. default: False.
  """
@@ -342,16 +493,29 @@ def _create_tf_record_from_coco_annotations(images_info_file,
  img_to_obj_annotation = None
  img_to_caption_annotation = None
  id_to_name_map = None
  img_to_panoptic_annotation = None
  is_category_thing = None

  if object_annotations_file:
    img_to_obj_annotation, id_to_name_map = (
        _load_object_annotations(object_annotations_file))

  if caption_annotations_file:
    img_to_caption_annotation = (
        _load_caption_annotations(caption_annotations_file))

  if panoptic_annotations_file:
    img_to_panoptic_annotation, is_category_thing = (
        _load_panoptic_annotations(panoptic_annotations_file))

  coco_annotations_iter = generate_annotations(
      images=images,
      image_dirs=image_dirs,
      panoptic_masks_dir=panoptic_masks_dir,
      img_to_obj_annotation=img_to_obj_annotation,
      img_to_caption_annotation=img_to_caption_annotation,
      img_to_panoptic_annotation=img_to_panoptic_annotation,
      is_category_thing=is_category_thing,
      id_to_name_map=id_to_name_map,
      include_panoptic_masks=include_panoptic_masks,
      include_masks=include_masks)

  num_skipped = tfrecord_lib.write_tf_record_dataset(
      output_path, coco_annotations_iter, create_tf_example, num_shards)
@@ -380,6 +544,9 @@ def main(_):
      FLAGS.num_shards,
      FLAGS.object_annotations_file,
      FLAGS.caption_annotations_file,
      FLAGS.panoptic_masks_dir,
      FLAGS.panoptic_annotations_file,
      FLAGS.include_panoptic_masks,
      FLAGS.include_masks)
...
@@ -100,8 +100,13 @@ def image_info_to_feature_dict(height, width, filename, image_id,
  }

def read_image(image_path):
  pil_image = Image.open(image_path)
  return np.asarray(pil_image)


def encode_mask_as_png(mask):
  pil_image = Image.fromarray(mask)
  output_io = io.BytesIO()
  pil_image.save(output_io, format='PNG')
  return output_io.getvalue()
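
A quick round-trip check of the two helpers above (mask values are arbitrary; PNG encoding is lossless for 8-bit data, which is why it is safe for label masks):

```python
import io
import numpy as np
from PIL import Image

mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 7  # arbitrary label values

png_bytes = encode_mask_as_png(mask)
decoded = np.asarray(Image.open(io.BytesIO(png_bytes)))
assert np.array_equal(decoded, mask)
```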
...
@@ -21,6 +21,7 @@ from typing import List, Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import maskrcnn
from official.vision.beta.configs import semantic_segmentation

@@ -46,11 +47,28 @@ class Parser(maskrcnn.Parser):
  segmentation_groundtruth_padded_size: List[int] = dataclasses.field(
      default_factory=list)
  segmentation_ignore_label: int = 255
  panoptic_ignore_label: int = 0
  # Setting this to true will enable parsing category_mask and instance_mask.
  include_panoptic_masks: bool = True


@dataclasses.dataclass
class TfExampleDecoder(common.TfExampleDecoder):
  """A simple TF Example decoder config."""
  # Setting this to true will enable decoding category_mask and instance_mask.
  include_panoptic_masks: bool = True


@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
  """Data decoder config."""
  simple_decoder: TfExampleDecoder = TfExampleDecoder()


@dataclasses.dataclass
class DataConfig(maskrcnn.DataConfig):
  """Input config for training."""
  decoder: DataDecoder = DataDecoder()
  parser: Parser = Parser()
...
@@ -24,25 +24,51 @@ from official.vision.beta.ops import preprocess_ops

class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
  """Tensorflow Example proto decoder."""

  def __init__(self, regenerate_source_id,
               mask_binarize_threshold, include_panoptic_masks):
    super(TfExampleDecoder, self).__init__(
        include_mask=True,
        regenerate_source_id=regenerate_source_id,
        mask_binarize_threshold=None)

    self._include_panoptic_masks = include_panoptic_masks
    keys_to_features = {
        'image/segmentation/class/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value='')}

    if include_panoptic_masks:
      keys_to_features.update({
          'image/panoptic/category_mask':
              tf.io.FixedLenFeature((), tf.string, default_value=''),
          'image/panoptic/instance_mask':
              tf.io.FixedLenFeature((), tf.string, default_value='')})
    self._segmentation_keys_to_features = keys_to_features

  def decode(self, serialized_example):
    decoded_tensors = super(TfExampleDecoder, self).decode(serialized_example)
    parsed_tensors = tf.io.parse_single_example(
        serialized_example, self._segmentation_keys_to_features)
    segmentation_mask = tf.io.decode_image(
        parsed_tensors['image/segmentation/class/encoded'],
        channels=1)
    segmentation_mask.set_shape([None, None, 1])
    decoded_tensors.update({'groundtruth_segmentation_mask': segmentation_mask})

    if self._include_panoptic_masks:
      category_mask = tf.io.decode_image(
          parsed_tensors['image/panoptic/category_mask'],
          channels=1)
      instance_mask = tf.io.decode_image(
          parsed_tensors['image/panoptic/instance_mask'],
          channels=1)
      category_mask.set_shape([None, None, 1])
      instance_mask.set_shape([None, None, 1])

      decoded_tensors.update({
          'groundtruth_panoptic_category_mask': category_mask,
          'groundtruth_panoptic_instance_mask': instance_mask})
    return decoded_tensors
@@ -69,6 +95,8 @@ class Parser(maskrcnn_input.Parser):
               segmentation_resize_eval_groundtruth=True,
               segmentation_groundtruth_padded_size=None,
               segmentation_ignore_label=255,
               panoptic_ignore_label=0,
               include_panoptic_masks=True,
               dtype='float32'):
    """Initializes parameters for parsing annotations in the dataset.

@@ -106,8 +134,12 @@ class Parser(maskrcnn_input.Parser):
      segmentation_groundtruth_padded_size: `Tensor` or `list` for [height,
        width]. When resize_eval_groundtruth is set to False, the groundtruth
        masks are padded to this size.
      segmentation_ignore_label: `int`, pixels with the ignore label are not
        used for training and evaluation.
      panoptic_ignore_label: `int`, pixels with the ignore label are not used
        by the PQ evaluator.
      include_panoptic_masks: `bool`, if True, category_mask and instance_mask
        will be parsed. Set this to True if the PQ evaluator is enabled.
      dtype: `str`, data type. One of {`bfloat16`, `float32`, `float16`}.
    """
    super(Parser, self).__init__(
@@ -139,6 +171,8 @@ class Parser(maskrcnn_input.Parser):
          'specified when segmentation_resize_eval_groundtruth is False.')
    self._segmentation_groundtruth_padded_size = segmentation_groundtruth_padded_size
    self._segmentation_ignore_label = segmentation_ignore_label
    self._panoptic_ignore_label = panoptic_ignore_label
    self._include_panoptic_masks = include_panoptic_masks

  def _parse_train_data(self, data):
    """Parses data for training.
@@ -250,39 +284,54 @@ class Parser(maskrcnn_input.Parser):
        shape [height_l, width_l, 4] representing anchor boxes at each
        level.
    """
    def _process_mask(mask, ignore_label, image_info):
      mask = tf.cast(mask, dtype=tf.float32)
      mask = tf.reshape(mask, shape=[1, data['height'], data['width'], 1])
      mask += 1

      if self._segmentation_resize_eval_groundtruth:
        # Resizes eval masks to match input image sizes. In that case, mean IoU
        # is computed on output_size not the original size of the images.
        image_scale = image_info[2, :]
        offset = image_info[3, :]
        mask = preprocess_ops.resize_and_crop_masks(
            mask, image_scale, self._output_size, offset)
      else:
        mask = tf.image.pad_to_bounding_box(
            mask, 0, 0,
            self._segmentation_groundtruth_padded_size[0],
            self._segmentation_groundtruth_padded_size[1])
      mask -= 1

      # Assign ignore label to the padded region.
      mask = tf.where(
          tf.equal(mask, -1),
          ignore_label * tf.ones_like(mask),
          mask)
      mask = tf.squeeze(mask, axis=0)
      return mask

    image, labels = super(Parser, self)._parse_eval_data(data)
    image_info = labels['image_info']

    segmentation_mask = _process_mask(
        data['groundtruth_segmentation_mask'],
        self._segmentation_ignore_label, image_info)
    segmentation_valid_mask = tf.not_equal(
        segmentation_mask, self._segmentation_ignore_label)
    labels['groundtruths'].update({
        'gt_segmentation_mask': segmentation_mask,
        'gt_segmentation_valid_mask': segmentation_valid_mask})

    if self._include_panoptic_masks:
      panoptic_category_mask = _process_mask(
          data['groundtruth_panoptic_category_mask'],
          self._panoptic_ignore_label, image_info)
      panoptic_instance_mask = _process_mask(
          data['groundtruth_panoptic_instance_mask'],
          self._panoptic_ignore_label, image_info)

      labels['groundtruths'].update({
          'gt_panoptic_category_mask': panoptic_category_mask,
          'gt_panoptic_instance_mask': panoptic_instance_mask})

    return image, labels
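
`_process_mask` shifts labels up by one before resizing or padding so that the zero-filled border introduced by those ops can be told apart from a real label 0; shifting back leaves padded pixels at -1, which are then overwritten with the ignore label. A standalone NumPy sketch of the same trick (shapes and values are illustrative):

```python
import numpy as np

mask = np.array([[0, 1], [2, 3]], dtype=np.float32)  # label 0 is a valid class
ignore_label = 255

shifted = mask + 1                           # valid labels now start at 1
padded = np.pad(shifted, ((0, 1), (0, 1)))   # zero padding, like pad_to_bounding_box
restored = padded - 1                        # padded pixels land on -1
restored[restored == -1] = ignore_label      # mark them as ignored

assert np.array_equal(restored[:2, :2], mask)   # original labels intact
assert (restored[2, :] == ignore_label).all()   # border is ignored
```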
@@ -121,7 +121,8 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
    if params.decoder.type == 'simple_decoder':
      decoder = panoptic_maskrcnn_input.TfExampleDecoder(
          regenerate_source_id=decoder_cfg.regenerate_source_id,
          mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
          include_panoptic_masks=decoder_cfg.include_panoptic_masks)
    else:
      raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
@@ -147,7 +148,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
        .segmentation_resize_eval_groundtruth,
        segmentation_groundtruth_padded_size=params.parser
        .segmentation_groundtruth_padded_size,
        segmentation_ignore_label=params.parser.segmentation_ignore_label,
        panoptic_ignore_label=params.parser.panoptic_ignore_label,
        include_panoptic_masks=params.parser.include_panoptic_masks)
    reader = input_reader_factory.input_reader_generator(
        params,
...