Commit 721a6fa4 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480986311
parent c60499b1
...@@ -48,18 +48,24 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder): ...@@ -48,18 +48,24 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
panoptic_category_mask_key: panoptic_category_mask_key:
tf.io.FixedLenFeature((), tf.string, default_value=''), tf.io.FixedLenFeature((), tf.string, default_value=''),
panoptic_instance_mask_key: panoptic_instance_mask_key:
tf.io.FixedLenFeature((), tf.string, default_value='')}) tf.io.FixedLenFeature((), tf.string, default_value='')
})
self._segmentation_keys_to_features = keys_to_features self._segmentation_keys_to_features = keys_to_features
def decode_segmentation_mask(self, parsed_tensors):
segmentation_mask = tf.io.decode_image(
parsed_tensors['image/segmentation/class/encoded'], channels=1)
segmentation_mask.set_shape([None, None, 1])
return segmentation_mask
def decode(self, serialized_example): def decode(self, serialized_example):
decoded_tensors = super(TfExampleDecoder, self).decode(serialized_example) decoded_tensors = super(TfExampleDecoder, self).decode(serialized_example)
parsed_tensors = tf.io.parse_single_example( parsed_tensors = tf.io.parse_single_example(
serialized_example, self._segmentation_keys_to_features) serialized_example, self._segmentation_keys_to_features)
segmentation_mask = tf.io.decode_image( decoded_tensors.update({
parsed_tensors['image/segmentation/class/encoded'], 'groundtruth_segmentation_mask':
channels=1) self.decode_segmentation_mask(parsed_tensors)
segmentation_mask.set_shape([None, None, 1]) })
decoded_tensors.update({'groundtruth_segmentation_mask': segmentation_mask})
if self._include_panoptic_masks: if self._include_panoptic_masks:
category_mask = tf.io.decode_image( category_mask = tf.io.decode_image(
...@@ -221,18 +227,21 @@ class Parser(maskrcnn_input.Parser): ...@@ -221,18 +227,21 @@ class Parser(maskrcnn_input.Parser):
are supposed to be used in computing the segmentation loss while are supposed to be used in computing the segmentation loss while
training. training.
""" """
# (height, width, num_channels = 1)
# All the operations below support num_channels >= 1.
segmentation_mask = data['groundtruth_segmentation_mask'] segmentation_mask = data['groundtruth_segmentation_mask']
# Flips image randomly during training. # Flips image randomly during training.
if self.aug_rand_hflip: if self.aug_rand_hflip:
masks = data['groundtruth_instance_masks'] masks = data['groundtruth_instance_masks']
num_image_channels = data['image'].shape.as_list()[-1]
image_mask = tf.concat([data['image'], segmentation_mask], axis=2) image_mask = tf.concat([data['image'], segmentation_mask], axis=2)
image_mask, boxes, masks = preprocess_ops.random_horizontal_flip( image_mask, boxes, masks = preprocess_ops.random_horizontal_flip(
image_mask, data['groundtruth_boxes'], masks) image_mask, data['groundtruth_boxes'], masks)
segmentation_mask = image_mask[:, :, -1:] image = image_mask[:, :, :num_image_channels]
image = image_mask[:, :, :-1] segmentation_mask = image_mask[:, :, num_image_channels:]
data['image'] = image data['image'] = image
data['groundtruth_boxes'] = boxes data['groundtruth_boxes'] = boxes
...@@ -244,14 +253,14 @@ class Parser(maskrcnn_input.Parser): ...@@ -244,14 +253,14 @@ class Parser(maskrcnn_input.Parser):
image_scale = image_info[2, :] image_scale = image_info[2, :]
offset = image_info[3, :] offset = image_info[3, :]
segmentation_mask = tf.reshape( # (height, width, num_channels = 1)
segmentation_mask, shape=[1, data['height'], data['width']])
segmentation_mask = tf.cast(segmentation_mask, tf.float32) segmentation_mask = tf.cast(segmentation_mask, tf.float32)
# Pad label and make sure the padded region assigned to the ignore label. # Pad label and make sure the padded region assigned to the ignore label.
# The label is first offset by +1 and then padded with 0. # The label is first offset by +1 and then padded with 0.
segmentation_mask += 1 segmentation_mask += 1
segmentation_mask = tf.expand_dims(segmentation_mask, axis=3) # (1, height, width, num_channels = 1)
segmentation_mask = tf.expand_dims(segmentation_mask, axis=0)
segmentation_mask = preprocess_ops.resize_and_crop_masks( segmentation_mask = preprocess_ops.resize_and_crop_masks(
segmentation_mask, image_scale, self._output_size, offset) segmentation_mask, image_scale, self._output_size, offset)
segmentation_mask -= 1 segmentation_mask -= 1
...@@ -259,6 +268,7 @@ class Parser(maskrcnn_input.Parser): ...@@ -259,6 +268,7 @@ class Parser(maskrcnn_input.Parser):
tf.equal(segmentation_mask, -1), tf.equal(segmentation_mask, -1),
self._segmentation_ignore_label * tf.ones_like(segmentation_mask), self._segmentation_ignore_label * tf.ones_like(segmentation_mask),
segmentation_mask) segmentation_mask)
# (height, width, num_channels = 1)
segmentation_mask = tf.squeeze(segmentation_mask, axis=0) segmentation_mask = tf.squeeze(segmentation_mask, axis=0)
segmentation_valid_mask = tf.not_equal( segmentation_valid_mask = tf.not_equal(
segmentation_mask, self._segmentation_ignore_label) segmentation_mask, self._segmentation_ignore_label)
...@@ -291,9 +301,13 @@ class Parser(maskrcnn_input.Parser): ...@@ -291,9 +301,13 @@ class Parser(maskrcnn_input.Parser):
shape [height_l, width_l, 4] representing anchor boxes at each shape [height_l, width_l, 4] representing anchor boxes at each
level. level.
""" """
def _process_mask(mask, ignore_label, image_info): def _process_mask(mask, ignore_label, image_info):
# (height, width, num_channels = 1)
# All the operations below support num_channels >= 1.
mask = tf.cast(mask, dtype=tf.float32) mask = tf.cast(mask, dtype=tf.float32)
mask = tf.reshape(mask, shape=[1, data['height'], data['width'], 1]) # (1, height, width, num_channels = 1)
mask = tf.expand_dims(mask, axis=0)
mask += 1 mask += 1
if self._segmentation_resize_eval_groundtruth: if self._segmentation_resize_eval_groundtruth:
...@@ -314,12 +328,14 @@ class Parser(maskrcnn_input.Parser): ...@@ -314,12 +328,14 @@ class Parser(maskrcnn_input.Parser):
tf.equal(mask, -1), tf.equal(mask, -1),
ignore_label * tf.ones_like(mask), ignore_label * tf.ones_like(mask),
mask) mask)
# (height, width, num_channels = 1)
mask = tf.squeeze(mask, axis=0) mask = tf.squeeze(mask, axis=0)
return mask return mask
image, labels = super(Parser, self)._parse_eval_data(data) image, labels = super(Parser, self)._parse_eval_data(data)
image_info = labels['image_info'] image_info = labels['image_info']
# (height, width, num_channels = 1)
segmentation_mask = _process_mask( segmentation_mask = _process_mask(
data['groundtruth_segmentation_mask'], data['groundtruth_segmentation_mask'],
self._segmentation_ignore_label, image_info) self._segmentation_ignore_label, image_info)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment