Commit ad83b2db authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 460499239
parent 0cda06fa
......@@ -39,8 +39,8 @@ class TfExampleDecoder(decoder.Decoder):
self._regenerate_source_id = regenerate_source_id
self._keys_to_features = {
'image/encoded': tf.io.FixedLenFeature((), tf.string),
'image/height': tf.io.FixedLenFeature((), tf.int64),
'image/width': tf.io.FixedLenFeature((), tf.int64),
'image/height': tf.io.FixedLenFeature((), tf.int64, -1),
'image/width': tf.io.FixedLenFeature((), tf.int64, -1),
'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
......@@ -148,6 +148,18 @@ class TfExampleDecoder(decoder.Decoder):
boxes = self._decode_boxes(parsed_tensors)
classes = self._decode_classes(parsed_tensors)
areas = self._decode_areas(parsed_tensors)
decode_image_shape = tf.logical_or(
tf.equal(parsed_tensors['image/height'], -1),
tf.equal(parsed_tensors['image/width'], -1))
image_shape = tf.cast(tf.shape(image), dtype=tf.int64)
parsed_tensors['image/height'] = tf.where(decode_image_shape,
image_shape[0],
parsed_tensors['image/height'])
parsed_tensors['image/width'] = tf.where(decode_image_shape, image_shape[1],
parsed_tensors['image/width'])
is_crowds = tf.cond(
tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
......
......@@ -26,18 +26,21 @@ from official.vision.dataloaders import tfexample_utils
class TfExampleDecoderTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
(100, 100, 0, True),
(100, 100, 1, True),
(100, 100, 2, True),
(100, 100, 0, False),
(100, 100, 1, False),
(100, 100, 2, False),
(100, 100, 0, True, True),
(100, 100, 1, True, True),
(100, 100, 2, True, True),
(100, 100, 0, False, True),
(100, 100, 1, False, True),
(100, 100, 2, False, True),
(100, 100, 0, True, False),
(100, 100, 1, True, False),
(100, 100, 2, True, False),
(100, 100, 0, False, False),
(100, 100, 1, False, False),
(100, 100, 2, False, False),
)
def test_result_shape(self,
image_height,
image_width,
num_instances,
regenerate_source_id):
def test_result_shape(self, image_height, image_width, num_instances,
regenerate_source_id, fill_image_size):
decoder = tf_example_decoder.TfExampleDecoder(
include_mask=True, regenerate_source_id=regenerate_source_id)
......@@ -45,7 +48,9 @@ class TfExampleDecoderTest(tf.test.TestCase, parameterized.TestCase):
image_height=image_height,
image_width=image_width,
image_channel=3,
num_instances=num_instances).SerializeToString()
num_instances=num_instances,
fill_image_size=fill_image_size,
).SerializeToString()
decoded_tensors = decoder.decode(
tf.convert_to_tensor(value=serialized_example))
......
......@@ -194,9 +194,12 @@ def create_3d_image_test_example(image_height: int, image_width: int,
return tf.train.Example(features=tf.train.Features(feature=feature))
def create_detection_test_example(image_height: int, image_width: int,
image_channel: int,
num_instances: int) -> tf.train.Example:
def create_detection_test_example(
image_height: int,
image_width: int,
image_channel: int,
num_instances: int,
fill_image_size: bool = True) -> tf.train.Example:
"""Creates and returns a test example containing box and mask annotations.
Args:
......@@ -204,6 +207,7 @@ def create_detection_test_example(image_height: int, image_width: int,
image_width: The width of test image.
image_channel: The channel of test image.
num_instances: The number of object instances per image.
    fill_image_size: Whether to add the image height and width features to the
      example. If False, the decoder must infer the shape from the encoded
      image instead.
Returns:
A tf.train.Example for testing.
......@@ -233,36 +237,41 @@ def create_detection_test_example(image_height: int, image_width: int,
for _ in range(num_instances):
mask = make_image_bytes([image_height, image_width], fmt='PNG')
masks.append(mask)
return tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded': (tf.train.Feature(
bytes_list=tf.train.BytesList(value=[image]))),
'image/source_id': (tf.train.Feature(
bytes_list=tf.train.BytesList(value=[DUMP_SOURCE_ID]))),
'image/height': (tf.train.Feature(
int64_list=tf.train.Int64List(value=[image_height]))),
'image/width': (tf.train.Feature(
int64_list=tf.train.Int64List(value=[image_width]))),
'image/object/bbox/xmin': (tf.train.Feature(
float_list=tf.train.FloatList(value=xmins))),
'image/object/bbox/xmax': (tf.train.Feature(
float_list=tf.train.FloatList(value=xmaxs))),
'image/object/bbox/ymin': (tf.train.Feature(
float_list=tf.train.FloatList(value=ymins))),
'image/object/bbox/ymax': (tf.train.Feature(
float_list=tf.train.FloatList(value=ymaxs))),
'image/object/class/label': (tf.train.Feature(
int64_list=tf.train.Int64List(value=labels))),
'image/object/class/text': (tf.train.Feature(
bytes_list=tf.train.BytesList(value=labels_text))),
'image/object/is_crowd': (tf.train.Feature(
int64_list=tf.train.Int64List(value=is_crowds))),
'image/object/area': (tf.train.Feature(
float_list=tf.train.FloatList(value=areas))),
'image/object/mask': (tf.train.Feature(
bytes_list=tf.train.BytesList(value=masks))),
}))
feature = {
'image/encoded':
(tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))),
'image/source_id': (tf.train.Feature(
bytes_list=tf.train.BytesList(value=[DUMP_SOURCE_ID]))),
'image/object/bbox/xmin':
(tf.train.Feature(float_list=tf.train.FloatList(value=xmins))),
'image/object/bbox/xmax':
(tf.train.Feature(float_list=tf.train.FloatList(value=xmaxs))),
'image/object/bbox/ymin':
(tf.train.Feature(float_list=tf.train.FloatList(value=ymins))),
'image/object/bbox/ymax':
(tf.train.Feature(float_list=tf.train.FloatList(value=ymaxs))),
'image/object/class/label':
(tf.train.Feature(int64_list=tf.train.Int64List(value=labels))),
'image/object/class/text':
(tf.train.Feature(bytes_list=tf.train.BytesList(value=labels_text))),
'image/object/is_crowd':
(tf.train.Feature(int64_list=tf.train.Int64List(value=is_crowds))),
'image/object/area':
(tf.train.Feature(float_list=tf.train.FloatList(value=areas))),
'image/object/mask':
(tf.train.Feature(bytes_list=tf.train.BytesList(value=masks))),
}
if fill_image_size:
feature.update({
'image/height': (tf.train.Feature(
int64_list=tf.train.Int64List(value=[image_height]))),
'image/width': (tf.train.Feature(
int64_list=tf.train.Int64List(value=[image_width]))),
})
return tf.train.Example(features=tf.train.Features(feature=feature))
def create_segmentation_test_example(image_height: int, image_width: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment