Commit ca9cf75f authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 353289158
parent 9f3f80e2
......@@ -17,8 +17,6 @@
A decoder to decode string tensors containing serialized tensorflow.Example
protos for object detection.
"""
import csv
# Import libraries
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
......@@ -172,44 +170,3 @@ class TfExampleDecoder(decoder.Decoder):
'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
})
return decoded_tensors
class TfExampleDecoderLabelMap(TfExampleDecoder):
"""Tensorflow Example proto decoder."""
def __init__(self, label_map, include_mask=False, regenerate_source_id=False):
super(TfExampleDecoderLabelMap, self).__init__(
include_mask=include_mask, regenerate_source_id=regenerate_source_id)
self._keys_to_features.update({
'image/object/class/text': tf.io.VarLenFeature(tf.string),
})
name_to_id = self._process_label_map(label_map)
self._name_to_id_table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(
keys=tf.constant(list(name_to_id.keys()), dtype=tf.string),
values=tf.constant(list(name_to_id.values()), dtype=tf.int64)),
default_value=-1)
def _process_label_map(self, label_map):
if label_map.endswith('.csv'):
name_to_id = self._process_csv(label_map)
else:
raise ValueError('The label map file is in incorrect format.')
return name_to_id
def _process_csv(self, label_map):
name_to_id = {}
with tf.io.gfile.GFile(label_map, 'r') as f:
reader = csv.reader(f, delimiter=',')
for row in reader:
if len(row) != 2:
raise ValueError('Each row of the csv label map file must be in '
'`id,name` format. length = {}'.format(len(row)))
id_index = int(row[0])
name = row[1]
name_to_id[name] = id_index
return name_to_id
def _decode_classes(self, parsed_tensors):
return self._name_to_id_table.lookup(
parsed_tensors['image/object/class/text'])
......@@ -94,16 +94,7 @@ class RetinaNetTask(base_task.Task):
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
parser = retinanet_input.Parser(
output_size=self.task_config.model.input_size[:2],
min_level=self.task_config.model.min_level,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment