"vscode:/vscode.git/clone" did not exist on "27a02fd1a1ebe46de67f230216f66ccbdf0e91d2"
Unverified Commit 1e2ada24 authored by Jonathan Huang's avatar Jonathan Huang Committed by GitHub
Browse files

Merge pull request #2692 from tombstone/fix_example_decoder

temporarily change tf_example_decoder to not depend on BackupHandler.
parents 59b96e9a 64f0761b
...@@ -113,24 +113,10 @@ class TfExampleDecoder(data_decoder.DataDecoder): ...@@ -113,24 +113,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
slim_example_decoder.ItemHandlerCallback( slim_example_decoder.ItemHandlerCallback(
['image/object/mask', 'image/height', 'image/width'], ['image/object/mask', 'image/height', 'image/width'],
self._reshape_instance_masks)) self._reshape_instance_masks))
if label_map_proto_file: # TODO: Add label_handler that decodes from 'image/object/class/text'
label_map = label_map_util.get_label_map_dict(label_map_proto_file, # primarily after the recent tf.contrib.slim changes make into a release
use_display_name) # supported by cloudml.
# We use a default_value of -1, but we expect all labels to be contained label_handler = slim_example_decoder.Tensor('image/object/class/label')
# in the label map.
table = tf.contrib.lookup.HashTable(
initializer=tf.contrib.lookup.KeyValueTensorInitializer(
keys=tf.constant(list(label_map.keys())),
values=tf.constant(list(label_map.values()), dtype=tf.int64)),
default_value=-1)
# If the label_map_proto is provided, try to use it in conjunction with
# the class text, and fall back to a materialized ID.
label_handler = slim_example_decoder.BackupHandler(
slim_example_decoder.LookupTensor(
'image/object/class/text', table, default_value=''),
slim_example_decoder.Tensor('image/object/class/label'))
else:
label_handler = slim_example_decoder.Tensor('image/object/class/label')
self.items_to_handlers[ self.items_to_handlers[
fields.InputDataFields.groundtruth_classes] = label_handler fields.InputDataFields.groundtruth_classes] = label_handler
......
...@@ -168,48 +168,6 @@ class TfExampleDecoderTest(tf.test.TestCase): ...@@ -168,48 +168,6 @@ class TfExampleDecoderTest(tf.test.TestCase):
self.assertAllEqual(bbox_classes, self.assertAllEqual(bbox_classes,
tensor_dict[fields.InputDataFields.groundtruth_classes]) tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeObjectLabelWithMapping(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor)
bbox_classes_text = ['cat', 'dog']
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
self._BytesFeature(encoded_jpeg),
'image/format':
self._BytesFeature('jpeg'),
'image/object/class/text':
self._BytesFeature(bbox_classes_text),
})).SerializeToString()
label_map_string = """
item {
id:3
name:'cat'
}
item {
id:1
name:'dog'
}
"""
label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
with tf.gfile.Open(label_map_path, 'wb') as f:
f.write(label_map_string)
example_decoder = tf_example_decoder.TfExampleDecoder(
label_map_proto_file=label_map_path)
tensor_dict = example_decoder.decode(tf.convert_to_tensor(example))
self.assertAllEqual((tensor_dict[fields.InputDataFields.groundtruth_classes]
.get_shape().as_list()), [None])
with self.test_session() as sess:
sess.run(tf.tables_initializer())
tensor_dict = sess.run(tensor_dict)
self.assertAllEqual([3, 1],
tensor_dict[fields.InputDataFields.groundtruth_classes])
def testDecodeObjectArea(self): def testDecodeObjectArea(self):
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8) image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = self._EncodeImage(image_tensor) encoded_jpeg = self._EncodeImage(image_tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment