Commit c3b9742b authored by Vivek Rathod's avatar Vivek Rathod Committed by TF Object Detection Team
Browse files

1. Modify GenerateEmbeddingDataFn to consume key_value tuple to support file...

1. Modify GenerateEmbeddingDataFn to consume key_value tuple to support file formats with (key, value) entries.
2. Copy all fields form input tf examples before adding context fields.

PiperOrigin-RevId: 324629673
parent db87dd5b
......@@ -63,6 +63,15 @@ except ModuleNotFoundError:
pass
def add_keys(serialized_example):
key = hash(serialized_example)
return key, serialized_example
def drop_keys(key_value_tuple):
return key_value_tuple[1]
class GenerateEmbeddingDataFn(beam.DoFn):
"""Generates embedding data for camera trap images.
......@@ -97,11 +106,12 @@ class GenerateEmbeddingDataFn(beam.DoFn):
with self.session_lock:
self._detect_fn = tf.saved_model.load(self._model_dir)
def process(self, tfrecord_entry):
return self._run_inference_and_generate_embedding(tfrecord_entry)
def process(self, tfexample_key_value):
return self._run_inference_and_generate_embedding(tfexample_key_value)
def _run_inference_and_generate_embedding(self, tfrecord_entry):
input_example = tf.train.Example.FromString(tfrecord_entry)
def _run_inference_and_generate_embedding(self, tfexample_key_value):
key, tfexample = tfexample_key_value
input_example = tf.train.Example.FromString(tfexample)
# Convert date_captured datetime string to unix time integer and store
def get_date_captured(example):
......@@ -161,11 +171,12 @@ class GenerateEmbeddingDataFn(beam.DoFn):
(date_captured - datetime.datetime.fromtimestamp(0)).total_seconds())
example = tf.train.Example()
example.CopyFrom(input_example)
example.features.feature['image/unix_time'].float_list.value.extend(
[unix_time])
detections = self._detect_fn.signatures['serving_default'](
(tf.expand_dims(tf.convert_to_tensor(tfrecord_entry), 0)))
(tf.expand_dims(tf.convert_to_tensor(tfexample), 0)))
detection_features = detections['detection_features']
detection_boxes = detections['detection_boxes']
num_detections = detections['num_detections']
......@@ -230,60 +241,8 @@ class GenerateEmbeddingDataFn(beam.DoFn):
example.features.feature['image/embedding_count'].int64_list.value.append(
embedding_count)
# Add other essential example attributes
example.features.feature['image/encoded'].bytes_list.value.extend(
input_example.features.feature['image/encoded'].bytes_list.value)
example.features.feature['image/height'].int64_list.value.extend(
input_example.features.feature['image/height'].int64_list.value)
example.features.feature['image/width'].int64_list.value.extend(
input_example.features.feature['image/width'].int64_list.value)
example.features.feature['image/source_id'].bytes_list.value.extend(
input_example.features.feature['image/source_id'].bytes_list.value)
example.features.feature['image/location'].bytes_list.value.extend(
input_example.features.feature['image/location'].bytes_list.value)
example.features.feature['image/date_captured'].bytes_list.value.extend(
input_example.features.feature['image/date_captured'].bytes_list.value)
example.features.feature['image/class/text'].bytes_list.value.extend(
input_example.features.feature['image/class/text'].bytes_list.value)
example.features.feature['image/class/label'].int64_list.value.extend(
input_example.features.feature['image/class/label'].int64_list.value)
example.features.feature['image/seq_id'].bytes_list.value.extend(
input_example.features.feature['image/seq_id'].bytes_list.value)
example.features.feature['image/seq_num_frames'].int64_list.value.extend(
input_example.features.feature['image/seq_num_frames'].int64_list.value)
example.features.feature['image/seq_frame_num'].int64_list.value.extend(
input_example.features.feature['image/seq_frame_num'].int64_list.value)
example.features.feature['image/object/bbox/ymax'].float_list.value.extend(
input_example.features.feature[
'image/object/bbox/ymax'].float_list.value)
example.features.feature['image/object/bbox/ymin'].float_list.value.extend(
input_example.features.feature[
'image/object/bbox/ymin'].float_list.value)
example.features.feature['image/object/bbox/xmax'].float_list.value.extend(
input_example.features.feature[
'image/object/bbox/xmax'].float_list.value)
example.features.feature['image/object/bbox/xmin'].float_list.value.extend(
input_example.features.feature[
'image/object/bbox/xmin'].float_list.value)
example.features.feature[
'image/object/class/score'].float_list.value.extend(
input_example.features.feature[
'image/object/class/score'].float_list.value)
example.features.feature[
'image/object/class/label'].int64_list.value.extend(
input_example.features.feature[
'image/object/class/label'].int64_list.value)
example.features.feature[
'image/object/class/text'].bytes_list.value.extend(
input_example.features.feature[
'image/object/class/text'].bytes_list.value)
self._num_examples_processed.inc(1)
return [example]
return [(key, example)]
def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
......@@ -303,13 +262,14 @@ def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
"""
input_collection = (
pipeline | 'ReadInputTFRecord' >> beam.io.tfrecordio.ReadFromTFRecord(
input_tfrecord,
coder=beam.coders.BytesCoder()))
input_tfrecord, coder=beam.coders.BytesCoder())
| 'AddKeys' >> beam.Map(add_keys))
output_collection = input_collection | 'ExtractEmbedding' >> beam.ParDo(
GenerateEmbeddingDataFn(model_dir, top_k_embedding_count,
bottom_k_embedding_count))
output_collection = output_collection | 'Reshuffle' >> beam.Reshuffle()
_ = output_collection | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord(
_ = output_collection | 'DropKeys' >> beam.Map(
drop_keys) | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord(
output_tfrecord,
num_shards=num_shards,
coder=beam.coders.ProtoCoder(tf.train.Example))
......@@ -395,4 +355,3 @@ def main(argv=None, save_main_session=True):
if __name__ == '__main__':
main()
......@@ -258,8 +258,8 @@ class GenerateEmbeddingData(tf.test.TestCase):
self.assertAllEqual(tf.train.Example.FromString(
generated_example).features.feature['image/object/class/text']
.bytes_list.value, [b'hyena'])
output = inference_fn.process(generated_example)
output_example = output[0]
output = inference_fn.process(('dummy_key', generated_example))
output_example = output[0][1]
self.assert_expected_example(output_example)
def test_generate_embedding_data_with_top_k_boxes(self):
......@@ -276,8 +276,8 @@ class GenerateEmbeddingData(tf.test.TestCase):
self.assertAllEqual(
tf.train.Example.FromString(generated_example).features
.feature['image/object/class/text'].bytes_list.value, [b'hyena'])
output = inference_fn.process(generated_example)
output_example = output[0]
output = inference_fn.process(('dummy_key', generated_example))
output_example = output[0][1]
self.assert_expected_example(output_example, topk=True)
def test_generate_embedding_data_with_bottom_k_boxes(self):
......@@ -294,8 +294,8 @@ class GenerateEmbeddingData(tf.test.TestCase):
self.assertAllEqual(
tf.train.Example.FromString(generated_example).features
.feature['image/object/class/text'].bytes_list.value, [b'hyena'])
output = inference_fn.process(generated_example)
output_example = output[0]
output = inference_fn.process(('dummy_key', generated_example))
output_example = output[0][1]
self.assert_expected_example(output_example, botk=True)
def test_beam_pipeline(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment