"tests/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "59ad273784c865f8bb6bbd4bfedbd6a24fdcfc73"
Commit ce417629 authored by Vivek Rathod
Browse files

update tf example decoder and fix the breaking input reader builder tests.

parent 141ed951
...@@ -14,6 +14,7 @@ py_library( ...@@ -14,6 +14,7 @@ py_library(
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/core:data_decoder", "//tensorflow_models/object_detection/core:data_decoder",
"//tensorflow_models/object_detection/core:standard_fields", "//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/utils:label_map_util",
], ],
) )
......
...@@ -22,6 +22,7 @@ import tensorflow as tf ...@@ -22,6 +22,7 @@ import tensorflow as tf
from object_detection.core import data_decoder from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.utils import label_map_util
slim_example_decoder = tf.contrib.slim.tfexample_decoder slim_example_decoder = tf.contrib.slim.tfexample_decoder
...@@ -29,28 +30,59 @@ slim_example_decoder = tf.contrib.slim.tfexample_decoder ...@@ -29,28 +30,59 @@ slim_example_decoder = tf.contrib.slim.tfexample_decoder
class TfExampleDecoder(data_decoder.DataDecoder): class TfExampleDecoder(data_decoder.DataDecoder):
"""Tensorflow Example proto decoder.""" """Tensorflow Example proto decoder."""
def __init__(self): def __init__(self,
"""Constructor sets keys_to_features and items_to_handlers.""" load_instance_masks=False,
label_map_proto_file=None,
use_display_name=False):
"""Constructor sets keys_to_features and items_to_handlers.
Args:
load_instance_masks: whether or not to load and handle instance masks.
label_map_proto_file: a file path to a
object_detection.protos.StringIntLabelMap proto. If provided, then the
mapped IDs of 'image/object/class/text' will take precedence over the
existing 'image/object/class/label' ID. Also, if provided, it is
assumed that 'image/object/class/text' will be in the data.
use_display_name: whether or not to use the `display_name` for label
mapping (instead of `name`). Only used if label_map_proto_file is
provided.
"""
self.keys_to_features = { self.keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 'image/encoded':
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), tf.FixedLenFeature((), tf.string, default_value=''),
'image/filename': tf.FixedLenFeature((), tf.string, default_value=''), 'image/format':
'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''), tf.FixedLenFeature((), tf.string, default_value='jpeg'),
'image/source_id': tf.FixedLenFeature((), tf.string, default_value=''), 'image/filename':
'image/height': tf.FixedLenFeature((), tf.int64, 1), tf.FixedLenFeature((), tf.string, default_value=''),
'image/width': tf.FixedLenFeature((), tf.int64, 1), 'image/key/sha256':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/source_id':
tf.FixedLenFeature((), tf.string, default_value=''),
'image/height':
tf.FixedLenFeature((), tf.int64, 1),
'image/width':
tf.FixedLenFeature((), tf.int64, 1),
# Object boxes and classes. # Object boxes and classes.
'image/object/bbox/xmin': tf.VarLenFeature(tf.float32), 'image/object/bbox/xmin':
'image/object/bbox/xmax': tf.VarLenFeature(tf.float32), tf.VarLenFeature(tf.float32),
'image/object/bbox/ymin': tf.VarLenFeature(tf.float32), 'image/object/bbox/xmax':
'image/object/bbox/ymax': tf.VarLenFeature(tf.float32), tf.VarLenFeature(tf.float32),
'image/object/class/label': tf.VarLenFeature(tf.int64), 'image/object/bbox/ymin':
'image/object/area': tf.VarLenFeature(tf.float32), tf.VarLenFeature(tf.float32),
'image/object/is_crowd': tf.VarLenFeature(tf.int64), 'image/object/bbox/ymax':
'image/object/difficult': tf.VarLenFeature(tf.int64), tf.VarLenFeature(tf.float32),
# Instance masks and classes. 'image/object/class/label':
'image/segmentation/object': tf.VarLenFeature(tf.int64), tf.VarLenFeature(tf.int64),
'image/segmentation/object/class': tf.VarLenFeature(tf.int64) 'image/object/class/text':
tf.VarLenFeature(tf.string),
'image/object/area':
tf.VarLenFeature(tf.float32),
'image/object/is_crowd':
tf.VarLenFeature(tf.int64),
'image/object/difficult':
tf.VarLenFeature(tf.int64),
'image/object/group_of':
tf.VarLenFeature(tf.int64),
} }
self.items_to_handlers = { self.items_to_handlers = {
fields.InputDataFields.image: slim_example_decoder.Image( fields.InputDataFields.image: slim_example_decoder.Image(
...@@ -65,22 +97,42 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -65,22 +97,42 @@ class TfExampleDecoder(data_decoder.DataDecoder):
fields.InputDataFields.groundtruth_boxes: ( fields.InputDataFields.groundtruth_boxes: (
slim_example_decoder.BoundingBox( slim_example_decoder.BoundingBox(
['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')), ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')),
fields.InputDataFields.groundtruth_classes: (
slim_example_decoder.Tensor('image/object/class/label')),
fields.InputDataFields.groundtruth_area: slim_example_decoder.Tensor( fields.InputDataFields.groundtruth_area: slim_example_decoder.Tensor(
'image/object/area'), 'image/object/area'),
fields.InputDataFields.groundtruth_is_crowd: ( fields.InputDataFields.groundtruth_is_crowd: (
slim_example_decoder.Tensor('image/object/is_crowd')), slim_example_decoder.Tensor('image/object/is_crowd')),
fields.InputDataFields.groundtruth_difficult: ( fields.InputDataFields.groundtruth_difficult: (
slim_example_decoder.Tensor('image/object/difficult')), slim_example_decoder.Tensor('image/object/difficult')),
# Instance masks and classes. fields.InputDataFields.groundtruth_group_of: (
fields.InputDataFields.groundtruth_instance_masks: ( slim_example_decoder.Tensor('image/object/group_of'))
slim_example_decoder.ItemHandlerCallback(
['image/segmentation/object', 'image/height', 'image/width'],
self._reshape_instance_masks)),
fields.InputDataFields.groundtruth_instance_classes: (
slim_example_decoder.Tensor('image/segmentation/object/class')),
} }
if load_instance_masks:
self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.float32)
self.items_to_handlers[
fields.InputDataFields.groundtruth_instance_masks] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/mask', 'image/height', 'image/width'],
self._reshape_instance_masks))
if label_map_proto_file:
label_map = label_map_util.get_label_map_dict(label_map_proto_file,
use_display_name)
# We use a default_value of -1, but we expect all labels to be contained
# in the label map.
table = tf.contrib.lookup.HashTable(
initializer=tf.contrib.lookup.KeyValueTensorInitializer(
keys=tf.constant(list(label_map.keys())),
values=tf.constant(list(label_map.values()), dtype=tf.int64)),
default_value=-1)
# If the label_map_proto is provided, try to use it in conjunction with
# the class text, and fall back to a materialized ID.
label_handler = slim_example_decoder.BackupHandler(
slim_example_decoder.LookupTensor(
'image/object/class/text', table, default_value=''),
slim_example_decoder.Tensor('image/object/class/label'))
else:
label_handler = slim_example_decoder.Tensor('image/object/class/label')
self.items_to_handlers[
fields.InputDataFields.groundtruth_classes] = label_handler
def decode(self, tf_example_string_tensor): def decode(self, tf_example_string_tensor):
"""Decodes serialized tensorflow example and returns a tensor dictionary. """Decodes serialized tensorflow example and returns a tensor dictionary.
...@@ -106,14 +158,14 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -106,14 +158,14 @@ class TfExampleDecoder(data_decoder.DataDecoder):
[None] containing containing object mask area in pixel squared. [None] containing containing object mask area in pixel squared.
fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
[None] indicating if the boxes enclose a crowd. [None] indicating if the boxes enclose a crowd.
Optional:
fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
[None] indicating if the boxes represent `difficult` instances. [None] indicating if the boxes represent `difficult` instances.
fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
[None] indicating if the boxes represent `group_of` instances.
fields.InputDataFields.groundtruth_instance_masks - 3D int64 tensor of fields.InputDataFields.groundtruth_instance_masks - 3D int64 tensor of
shape [None, None, None] containing instance masks. shape [None, None, None] containing instance masks.
fields.InputDataFields.groundtruth_instance_classes - 1D int64 tensor
of shape [None] containing classes for the instance masks.
""" """
serialized_example = tf.reshape(tf_example_string_tensor, shape=[]) serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features, decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
self.items_to_handlers) self.items_to_handlers)
...@@ -135,13 +187,14 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -135,13 +187,14 @@ class TfExampleDecoder(data_decoder.DataDecoder):
keys_to_tensors: a dictionary from keys to tensors. keys_to_tensors: a dictionary from keys to tensors.
Returns: Returns:
A 3-D boolean tensor of shape [num_instances, height, width]. A 3-D float tensor of shape [num_instances, height, width] with values
in {0, 1}.
""" """
masks = keys_to_tensors['image/segmentation/object']
if isinstance(masks, tf.SparseTensor):
masks = tf.sparse_tensor_to_dense(masks)
height = keys_to_tensors['image/height'] height = keys_to_tensors['image/height']
width = keys_to_tensors['image/width'] width = keys_to_tensors['image/width']
to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32) to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
masks = keys_to_tensors['image/object/mask']
return tf.cast(tf.reshape(masks, to_shape), tf.bool) if isinstance(masks, tf.SparseTensor):
masks = tf.sparse_tensor_to_dense(masks)
masks = tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)
return tf.cast(masks, tf.float32)
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Tests for object_detection.data_decoders.tf_example_decoder.""" """Tests for object_detection.data_decoders.tf_example_decoder."""
import os
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -51,6 +52,8 @@ class TfExampleDecoderTest(tf.test.TestCase): ...@@ -51,6 +52,8 @@ class TfExampleDecoderTest(tf.test.TestCase):
return tf.train.Feature(float_list=tf.train.FloatList(value=value)) return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _BytesFeature(self, value): def _BytesFeature(self, value):
if isinstance(value, list):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def testDecodeJpegImage(self): def testDecodeJpegImage(self):
...@@ -165,6 +168,48 @@ class TfExampleDecoderTest(tf.test.TestCase): ...@@ -165,6 +168,48 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual(bbox_classes, self.assertAllEqual(bbox_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes]) tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeObjectLabelWithMapping(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
bbox_classes_text = ['cat', 'dog']
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
self._BytesFeature(encoded_jpeg),
'image/format':
self._BytesFeature('jpeg'),
'image/object/class/text':
self._BytesFeature(bbox_classes_text),
})).SerializeToString()
label_map_string = """
item {
id:3
name:'cat'
}
item {
id:1
name:'dog'
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
example_decoder = tf_example_decoder.TfExampleDecoder(
label_map_proto_file=label_map_path)
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual((tensor_dict[fields.InputDataFields.groundtruth_classes]
.get_shape().as_list()), [None])
with self.test_session() as sess:
sess.run(tf.tables_initializer())
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual([3, 1],
tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeObjectArea(self): def testDecodeObjectArea(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8) image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor) encoded_jpeg = self._EncodeImage(image_tensor)
...@@ -232,6 +277,30 @@ class TfExampleDecoderTest(tf.test.TestCase): ...@@ -232,6 +277,30 @@ class TfExampleDecoderTest(tf.test.TestCase):
tensor_dict[ tensor_dict[
fields.InputDataFields.groundtruth_difficult]) fields.InputDataFields.groundtruth_difficult])
def testDecodeObjectGroupOf(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
object_group_of = [0, 1]
example = tf.train.Example(features=tf.train.Features(
feature={
'image/encoded': self._BytesFeature(encoded_jpeg),
'image/format': self._BytesFeature('jpeg'),
'image/object/group_of': self._Int64Feature(object_group_of),
})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder()
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual((tensor_dict[
fields.InputDataFields.groundtruth_group_of].get_shape().as_list()),
[None])
with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual(
[bool(item) for item in object_group_of],
tensor_dict[fields.InputDataFields.groundtruth_group_of])
def testDecodeInstanceSegmentation(self): def testDecodeInstanceSegmentation(self):
num_instances = 4 num_instances = 4
image_height = 5 image_height = 5
...@@ -244,13 +313,14 @@ class TfExampleDecoderTest(tf.test.TestCase): ...@@ -244,13 +313,14 @@ class TfExampleDecoderTest(tf.test.TestCase):
encoded_jpeg = self._EncodeImage(image_tensor) encoded_jpeg = self._EncodeImage(image_tensor)
# Randomly generate instance segmentation masks. # Randomly generate instance segmentation masks.
instance_segmentation = ( instance_masks = (
np.random.randint(2, size=(num_instances, np.random.randint(2, size=(num_instances,
image_height, image_height,
image_width)).astype(np.int64)) image_width)).astype(np.float32))
instance_masks_flattened = np.reshape(instance_masks, [-1])
# Randomly generate class labels for each instance. # Randomly generate class labels for each instance.
instance_segmentation_classes = np.random.randint( object_classes = np.random.randint(
100, size=(num_instances)).astype(np.int64) 100, size=(num_instances)).astype(np.int64)
example = tf.train.Example(features=tf.train.Features(feature={ example = tf.train.Example(features=tf.train.Features(feature={
...@@ -258,11 +328,11 @@ class TfExampleDecoderTest(tf.test.TestCase): ...@@ -258,11 +328,11 @@ class TfExampleDecoderTest(tf.test.TestCase):
'image/format': self._BytesFeature('jpeg'), 'image/format': self._BytesFeature('jpeg'),
'image/height': self._Int64Feature([image_height]), 'image/height': self._Int64Feature([image_height]),
'image/width': self._Int64Feature([image_width]), 'image/width': self._Int64Feature([image_width]),
'image/segmentation/object': self._Int64Feature( 'image/object/mask': self._FloatFeature(instance_masks_flattened),
instance_segmentation.flatten()), 'image/object/class/label': self._Int64Feature(
'image/segmentation/object/class': self._Int64Feature( object_classes)})).SerializeToString()
instance_segmentation_classes)})).SerializeToString() example_decoder = tf_example_decoder.TfExampleDecoder(
example_decoder = tf_example_decoder.TfExampleDecoder() load_instance_masks=True)
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example)) tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual(( self.assertAllEqual((
...@@ -270,18 +340,52 @@ class TfExampleDecoderTest(tf.test.TestCase): ...@@ -270,18 +340,52 @@ class TfExampleDecoderTest(tf.test.TestCase):
get_shape().as_list()), [None, None, None]) get_shape().as_list()), [None, None, None])
self.assertAllEqual(( self.assertAllEqual((
tensor_dict[fields.InputDataFields.groundtruth_instance_classes]. tensor_dict[fields.InputDataFields.groundtruth_classes].
get_shape().as_list()), [None]) get_shape().as_list()), [None])
with self.test_session() as sess: with self.test_session() as sess:
tensor_dict = sess.run(tensor_dict) tensor_dict = sess.run(tensor_dict)
self.assertAllEqual( self.assertAllEqual(
instance_segmentation.astype(np.bool), instance_masks.astype(np.float32),
tensor_dict[fields.InputDataFields.groundtruth_instance_masks]) tensor_dict[fields.InputDataFields.groundtruth_instance_masks])
self.assertAllEqual( self.assertAllEqual(
instance_segmentation_classes, object_classes,
tensor_dict[fields.InputDataFields.groundtruth_instance_classes]) tensor_dict[fields.InputDataFields.groundtruth_classes])
def testInstancesNotAvailableByDefault(self):
num_instances = 4
image_height = 5
image_width = 3
# Randomly generate image.
image_tensor = np.random.randint(255, size=(image_height,
image_width,
3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
# Randomly generate instance segmentation masks.
instance_masks = (
np.random.randint(2, size=(num_instances,
image_height,
image_width)).astype(np.float32))
instance_masks_flattened = np.reshape(instance_masks, [-1])
# Randomly generate class labels for each instance.
object_classes = np.random.randint(
100, size=(num_instances)).astype(np.int64)
example = tf.train.Example(features=tf.train.Features(feature={
'image/encoded': self._BytesFeature(encoded_jpeg),
'image/format': self._BytesFeature('jpeg'),
'image/height': self._Int64Feature([image_height]),
'image/width': self._Int64Feature([image_width]),
'image/object/mask': self._FloatFeature(instance_masks_flattened),
'image/object/class/label': self._Int64Feature(
object_classes)})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder()
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertTrue(fields.InputDataFields.groundtruth_instance_masks
not in tensor_dict)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment