Commit c3b9742b authored by Vivek Rathod, committed by TF Object Detection Team


1. Modify GenerateEmbeddingDataFn to consume key_value tuple to support file formats with (key, value) entries.
2. Copy all fields from input tf examples before adding context fields.
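
For readers skimming the diff, here is a minimal, self-contained sketch of the keyed-pipeline pattern that change 1 introduces: add_keys attaches a key to each serialized record, the DoFn consumes and re-emits (key, value) tuples, and drop_keys strips the key again before writing. The ToyEmbeddingFn class and the in-memory Create source are illustrative stand-ins for GenerateEmbeddingDataFn and the TFRecord reader; only add_keys, drop_keys, and the tuple plumbing mirror the committed code.

import apache_beam as beam


def add_keys(serialized_example):
  # Key each record by a hash of its bytes, as in the commit.
  key = hash(serialized_example)
  return key, serialized_example


def drop_keys(key_value_tuple):
  # Strip the key again before the records are written out.
  return key_value_tuple[1]


class ToyEmbeddingFn(beam.DoFn):
  """Stand-in for GenerateEmbeddingDataFn: consumes and re-emits keyed tuples."""

  def process(self, key_value):
    key, value = key_value
    # The real DoFn runs inference here and returns a keyed tf.train.Example.
    yield key, value.upper()


if __name__ == '__main__':
  with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | 'Create' >> beam.Create([b'example-one', b'example-two'])
        | 'AddKeys' >> beam.Map(add_keys)
        | 'Process' >> beam.ParDo(ToyEmbeddingFn())
        | 'DropKeys' >> beam.Map(drop_keys)
        | 'Print' >> beam.Map(print))

Keeping the key attached end to end is what lets the same DoFn be reused with sources whose records are already (key, value) entries, which is the motivation stated in change 1.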

PiperOrigin-RevId: 324629673
parent db87dd5b
@@ -63,6 +63,15 @@ except ModuleNotFoundError:
   pass
 
 
+def add_keys(serialized_example):
+  key = hash(serialized_example)
+  return key, serialized_example
+
+
+def drop_keys(key_value_tuple):
+  return key_value_tuple[1]
+
+
 class GenerateEmbeddingDataFn(beam.DoFn):
   """Generates embedding data for camera trap images.
@@ -97,11 +106,12 @@ class GenerateEmbeddingDataFn(beam.DoFn):
     with self.session_lock:
       self._detect_fn = tf.saved_model.load(self._model_dir)
 
-  def process(self, tfrecord_entry):
-    return self._run_inference_and_generate_embedding(tfrecord_entry)
+  def process(self, tfexample_key_value):
+    return self._run_inference_and_generate_embedding(tfexample_key_value)
 
-  def _run_inference_and_generate_embedding(self, tfrecord_entry):
-    input_example = tf.train.Example.FromString(tfrecord_entry)
+  def _run_inference_and_generate_embedding(self, tfexample_key_value):
+    key, tfexample = tfexample_key_value
+    input_example = tf.train.Example.FromString(tfexample)
 
     # Convert date_captured datetime string to unix time integer and store
     def get_date_captured(example):
@@ -161,11 +171,12 @@ class GenerateEmbeddingDataFn(beam.DoFn):
         (date_captured - datetime.datetime.fromtimestamp(0)).total_seconds())
     example = tf.train.Example()
+    example.CopyFrom(input_example)
     example.features.feature['image/unix_time'].float_list.value.extend(
         [unix_time])
 
     detections = self._detect_fn.signatures['serving_default'](
-        (tf.expand_dims(tf.convert_to_tensor(tfrecord_entry), 0)))
+        (tf.expand_dims(tf.convert_to_tensor(tfexample), 0)))
     detection_features = detections['detection_features']
     detection_boxes = detections['detection_boxes']
     num_detections = detections['num_detections']
@@ -230,60 +241,8 @@ class GenerateEmbeddingDataFn(beam.DoFn):
     example.features.feature['image/embedding_count'].int64_list.value.append(
         embedding_count)
 
-    # Add other essential example attributes
-    example.features.feature['image/encoded'].bytes_list.value.extend(
-        input_example.features.feature['image/encoded'].bytes_list.value)
-    example.features.feature['image/height'].int64_list.value.extend(
-        input_example.features.feature['image/height'].int64_list.value)
-    example.features.feature['image/width'].int64_list.value.extend(
-        input_example.features.feature['image/width'].int64_list.value)
-    example.features.feature['image/source_id'].bytes_list.value.extend(
-        input_example.features.feature['image/source_id'].bytes_list.value)
-    example.features.feature['image/location'].bytes_list.value.extend(
-        input_example.features.feature['image/location'].bytes_list.value)
-    example.features.feature['image/date_captured'].bytes_list.value.extend(
-        input_example.features.feature['image/date_captured'].bytes_list.value)
-    example.features.feature['image/class/text'].bytes_list.value.extend(
-        input_example.features.feature['image/class/text'].bytes_list.value)
-    example.features.feature['image/class/label'].int64_list.value.extend(
-        input_example.features.feature['image/class/label'].int64_list.value)
-    example.features.feature['image/seq_id'].bytes_list.value.extend(
-        input_example.features.feature['image/seq_id'].bytes_list.value)
-    example.features.feature['image/seq_num_frames'].int64_list.value.extend(
-        input_example.features.feature['image/seq_num_frames'].int64_list.value)
-    example.features.feature['image/seq_frame_num'].int64_list.value.extend(
-        input_example.features.feature['image/seq_frame_num'].int64_list.value)
-    example.features.feature['image/object/bbox/ymax'].float_list.value.extend(
-        input_example.features.feature[
-            'image/object/bbox/ymax'].float_list.value)
-    example.features.feature['image/object/bbox/ymin'].float_list.value.extend(
-        input_example.features.feature[
-            'image/object/bbox/ymin'].float_list.value)
-    example.features.feature['image/object/bbox/xmax'].float_list.value.extend(
-        input_example.features.feature[
-            'image/object/bbox/xmax'].float_list.value)
-    example.features.feature['image/object/bbox/xmin'].float_list.value.extend(
-        input_example.features.feature[
-            'image/object/bbox/xmin'].float_list.value)
-    example.features.feature[
-        'image/object/class/score'].float_list.value.extend(
-            input_example.features.feature[
-                'image/object/class/score'].float_list.value)
-    example.features.feature[
-        'image/object/class/label'].int64_list.value.extend(
-            input_example.features.feature[
-                'image/object/class/label'].int64_list.value)
-    example.features.feature[
-        'image/object/class/text'].bytes_list.value.extend(
-            input_example.features.feature[
-                'image/object/class/text'].bytes_list.value)
 
     self._num_examples_processed.inc(1)
-    return [example]
+    return [(key, example)]
 
 
 def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
@@ -303,16 +262,17 @@ def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
   """
   input_collection = (
      pipeline | 'ReadInputTFRecord' >> beam.io.tfrecordio.ReadFromTFRecord(
-          input_tfrecord,
-          coder=beam.coders.BytesCoder()))
+          input_tfrecord, coder=beam.coders.BytesCoder())
+      | 'AddKeys' >> beam.Map(add_keys))
   output_collection = input_collection | 'ExtractEmbedding' >> beam.ParDo(
       GenerateEmbeddingDataFn(model_dir, top_k_embedding_count,
                               bottom_k_embedding_count))
   output_collection = output_collection | 'Reshuffle' >> beam.Reshuffle()
-  _ = output_collection | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord(
-      output_tfrecord,
-      num_shards=num_shards,
-      coder=beam.coders.ProtoCoder(tf.train.Example))
+  _ = output_collection | 'DropKeys' >> beam.Map(
+      drop_keys) | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord(
+          output_tfrecord,
+          num_shards=num_shards,
+          coder=beam.coders.ProtoCoder(tf.train.Example))
 
 
 def parse_args(argv):
@@ -395,4 +355,3 @@ def main(argv=None, save_main_session=True):
 
 if __name__ == '__main__':
   main()
@@ -258,8 +258,8 @@ class GenerateEmbeddingData(tf.test.TestCase):
     self.assertAllEqual(tf.train.Example.FromString(
         generated_example).features.feature['image/object/class/text']
         .bytes_list.value, [b'hyena'])
-    output = inference_fn.process(generated_example)
-    output_example = output[0]
+    output = inference_fn.process(('dummy_key', generated_example))
+    output_example = output[0][1]
     self.assert_expected_example(output_example)
 
   def test_generate_embedding_data_with_top_k_boxes(self):
@@ -276,8 +276,8 @@ class GenerateEmbeddingData(tf.test.TestCase):
     self.assertAllEqual(
         tf.train.Example.FromString(generated_example).features
         .feature['image/object/class/text'].bytes_list.value, [b'hyena'])
-    output = inference_fn.process(generated_example)
-    output_example = output[0]
+    output = inference_fn.process(('dummy_key', generated_example))
+    output_example = output[0][1]
     self.assert_expected_example(output_example, topk=True)
 
   def test_generate_embedding_data_with_bottom_k_boxes(self):
@@ -294,8 +294,8 @@ class GenerateEmbeddingData(tf.test.TestCase):
     self.assertAllEqual(
         tf.train.Example.FromString(generated_example).features
         .feature['image/object/class/text'].bytes_list.value, [b'hyena'])
-    output = inference_fn.process(generated_example)
-    output_example = output[0]
+    output = inference_fn.process(('dummy_key', generated_example))
+    output_example = output[0][1]
     self.assert_expected_example(output_example, botk=True)
 
   def test_beam_pipeline(self):
...
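
On change 2, the diff replaces the long field-by-field copy with a single example.CopyFrom(input_example) issued before the context fields are added. A brief sketch of the protobuf semantics this relies on follows: CopyFrom clears the target and copies every field of the source, so features appended afterwards sit on top of a complete copy. The feature values below are made up for illustration.

import tensorflow as tf

input_example = tf.train.Example()
input_example.features.feature['image/source_id'].bytes_list.value.append(
    b'img_0001')
input_example.features.feature['image/height'].int64_list.value.append(480)

example = tf.train.Example()
example.CopyFrom(input_example)  # copies every field of the source proto
example.features.feature['image/unix_time'].float_list.value.append(1.6e9)

# All original fields survive alongside the newly added one.
assert (example.features.feature['image/source_id'].bytes_list.value ==
        [b'img_0001'])
assert example.features.feature['image/height'].int64_list.value == [480]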