Commit 101282ad authored by Sara Beery, committed by TF Object Detection Team

Updating Context R-CNN dataset tools to fix a bug in context feature bank building logic when the number of context examples in the time horizon is greater than the specified maximum, and adding capabilities to track and save the number of embeddings stored per image.

PiperOrigin-RevId: 345548385
parent 1e205552
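
For readers skimming the diff below: the fix changes the grouping unit from number of examples to number of stored embeddings. A minimal, standalone sketch of the corrected behavior (illustrative names only, not the module's actual code):

# Minimal sketch of the corrected splitting behavior, assuming each
# element is an (image_id, embedding_count) pair already sorted by time.
# The real code operates on tf.train.Examples and Beam keys.
def split_by_embedding_count(sorted_examples, max_elements):
  """Greedily cuts a time-sorted group into subsets whose total
  embedding count stays within max_elements; the final subset holds
  the remainder."""
  subsets = []
  current, current_count = [], 0
  for image_id, embedding_count in sorted_examples:
    if current and current_count + embedding_count > max_elements:
      subsets.append(current)
      current, current_count = [], 0
    current.append(image_id)
    current_count += embedding_count
  if current:
    subsets.append(current)
  return subsets

# Six images, one of which stores two embeddings, with a maximum of 4:
# prints [['a', 'b', 'c', 'd'], ['e', 'f']] -- the cut respects
# embedding totals, not example counts.
print(split_by_embedding_count(
    [('a', 1), ('b', 1), ('c', 1), ('d', 1), ('e', 2), ('f', 1)], 4))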
@@ -294,20 +294,46 @@ class SortGroupedDataFn(beam.DoFn):
     sorted_example_list = sorted(example_list, key=sorting_fn)
+    num_embeddings = 0
+    for example in sorted_example_list:
+      num_embeddings += example.features.feature[
+          'image/embedding_count'].int64_list.value[0]
     self._num_examples_processed.inc(1)
-    if len(sorted_example_list) > self._max_num_elements_in_context_features:
+    # To handle cases where there are more context embeddings within
+    # the time horizon than the specified maximum, we split the context group
+    # into subsets sequentially in time, with each subset having the maximum
+    # number of context embeddings except the final one, which holds the
+    # remainder.
+    if num_embeddings > self._max_num_elements_in_context_features:
       leftovers = sorted_example_list
       output_list = []
       count = 0
       self._too_many_elements.inc(1)
-      while len(leftovers) > self._max_num_elements_in_context_features:
+      num_embeddings = 0
+      max_idx = 0
+      for idx, example in enumerate(leftovers):
+        num_embeddings += example.features.feature[
+            'image/embedding_count'].int64_list.value[0]
+        if num_embeddings <= self._max_num_elements_in_context_features:
+          max_idx = idx
+      while num_embeddings > self._max_num_elements_in_context_features:
         self._split_elements.inc(1)
         new_key = key + six.ensure_binary('_' + str(count))
-        new_list = leftovers[:self._max_num_elements_in_context_features]
+        new_list = leftovers[:max_idx]
         output_list.append((new_key, new_list))
-        leftovers = leftovers[:self._max_num_elements_in_context_features]
+        leftovers = leftovers[max_idx:]
         count += 1
+        num_embeddings = 0
+        max_idx = 0
+        for idx, example in enumerate(leftovers):
+          num_embeddings += example.features.feature[
+              'image/embedding_count'].int64_list.value[0]
+          if num_embeddings <= self._max_num_elements_in_context_features:
+            max_idx = idx
       new_key = key + six.ensure_binary('_' + str(count))
       output_list.append((new_key, leftovers))
     else:
       output_list = [(key, sorted_example_list)]
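
The heart of the original bug is visible in the two `leftovers` lines above: the old code re-sliced the prefix it had just emitted instead of advancing past it, so the loop never progressed, the first subset was written out twice, and everything beyond it was silently dropped. In isolation:

# Simplified illustration of the slicing fix (plain list, cut point 3).
leftovers = [1, 2, 3, 4, 5]

# Old (buggy): leftovers[:3] re-selects the prefix that was just
# emitted, so the loop never advances and 4, 5 are silently lost.
assert leftovers[:3] == [1, 2, 3]

# Fixed: slicing from the cut point leaves the true remainder.
assert leftovers[3:] == [4, 5]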
@@ -454,11 +480,14 @@ class GenerateContextFn(beam.DoFn):
       example_embedding = list(example.features.feature[
           'image/embedding'].float_list.value)
       context_features.extend(example_embedding)
-      example_image_id = example.features.feature[
-          'image/source_id'].bytes_list.value[0]
+      num_embeddings = example.features.feature[
+          'image/embedding_count'].int64_list.value[0]
+      example_image_id = example.features.feature[
+          'image/source_id'].bytes_list.value[0]
+      for _ in range(num_embeddings):
+        example.features.feature[
+            'context_features_idx'].int64_list.value.append(count)
+        count += 1
       context_features_image_id_list.append(example_image_id)
       if not example_embedding:
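
This hunk changes the bookkeeping so that `context_features_idx` gains one entry per stored embedding rather than one per image, with the running `count` keeping bank slots contiguous. A rough sketch of the resulting index assignment (hypothetical data, plain Python):

# Hedged sketch: slot indices in the context feature bank when images
# contribute different numbers of embeddings.
images = [('img_0', 2), ('img_1', 1), ('img_2', 3)]  # (id, embedding_count)

count = 0
slots = {}
for image_id, num_embeddings in images:
  # One bank index per embedding, not one per image.
  slots[image_id] = list(range(count, count + num_embeddings))
  count += num_embeddings

print(slots)  # {'img_0': [0, 1], 'img_1': [2], 'img_2': [3, 4, 5]}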
@@ -98,7 +98,8 @@ class GenerateContextDataTest(tf.test.TestCase):
             six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 0, 0)))),
         'image/embedding': FloatListFeature([0.1, 0.2, 0.3]),
         'image/embedding_score': FloatListFeature([0.9]),
-        'image/embedding_length': Int64Feature(3)
+        'image/embedding_length': Int64Feature(3),
+        'image/embedding_count': Int64Feature(1)
     }))
@@ -127,7 +128,8 @@ class GenerateContextDataTest(tf.test.TestCase):
             six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 1, 0)))),
         'image/embedding': FloatListFeature([0.4, 0.5, 0.6]),
         'image/embedding_score': FloatListFeature([0.9]),
-        'image/embedding_length': Int64Feature(3)
+        'image/embedding_length': Int64Feature(3),
+        'image/embedding_count': Int64Feature(1)
     }))
     return example.SerializeToString()
@@ -142,13 +142,14 @@ class GenerateEmbeddingDataFn(beam.DoFn):
   session_lock = threading.Lock()

   def __init__(self, model_dir, top_k_embedding_count,
-               bottom_k_embedding_count):
+               bottom_k_embedding_count, embedding_type='final_box_features'):
     """Initialization function.

     Args:
       model_dir: A directory containing saved model.
       top_k_embedding_count: the number of high-confidence embeddings to store
       bottom_k_embedding_count: the number of low-confidence embeddings to store
+      embedding_type: One of 'final_box_features', 'rpn_box_features'
     """
     self._model_dir = model_dir
     self._session = None
@@ -156,6 +157,7 @@ class GenerateEmbeddingDataFn(beam.DoFn):
         'embedding_data_generation', 'num_tf_examples_processed')
     self._top_k_embedding_count = top_k_embedding_count
     self._bottom_k_embedding_count = bottom_k_embedding_count
+    self._embedding_type = embedding_type

   def setup(self):
     self._load_inference_model()
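
Because `embedding_type` is keyword-with-default, existing callers of `GenerateEmbeddingDataFn` keep their behavior. A hedged construction example (the model path is a placeholder; construction is cheap since the model loads in `setup`):

# Assumes GenerateEmbeddingDataFn is in scope (it is defined in the
# module patched above). The first call is equivalent to the
# pre-change behavior.
fn_final = GenerateEmbeddingDataFn('/tmp/exported_model', 5, 5)
fn_rpn = GenerateEmbeddingDataFn('/tmp/exported_model', 5, 5,
                                 embedding_type='rpn_box_features')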
@@ -188,7 +190,12 @@ class GenerateEmbeddingDataFn(beam.DoFn):
     detections = self._detect_fn.signatures['serving_default'](
         (tf.expand_dims(tf.convert_to_tensor(tfexample), 0)))
-    detection_features = detections['detection_features']
+    if self._embedding_type == 'final_box_features':
+      detection_features = detections['detection_features']
+    elif self._embedding_type == 'rpn_box_features':
+      detection_features = detections['cropped_rpn_box_features']
+    else:
+      raise ValueError('embedding type not supported')
     detection_boxes = detections['detection_boxes']
     num_detections = detections['num_detections']
     detection_scores = detections['detection_scores']
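
A note on the design: the new branch fails fast on unknown types instead of silently defaulting. The same dispatch, isolated from Beam (stand-in detections dict; the key names come from the diff, the values are dummies):

def select_detection_features(detections, embedding_type):
  # Mirrors the branch added above; raises on unsupported types.
  if embedding_type == 'final_box_features':
    return detections['detection_features']
  elif embedding_type == 'rpn_box_features':
    return detections['cropped_rpn_box_features']
  raise ValueError('embedding type not supported')

fake = {'detection_features': 'final', 'cropped_rpn_box_features': 'rpn'}
assert select_detection_features(fake, 'final_box_features') == 'final'
assert select_detection_features(fake, 'rpn_box_features') == 'rpn'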
@@ -245,7 +252,7 @@ class GenerateEmbeddingDataFn(beam.DoFn):

 def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
                        top_k_embedding_count, bottom_k_embedding_count,
-                       num_shards):
+                       num_shards, embedding_type):
   """Returns a beam pipeline to run object detection inference.

   Args:
@@ -257,6 +264,7 @@ def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
     top_k_embedding_count: The number of high-confidence embeddings to store.
     bottom_k_embedding_count: The number of low-confidence embeddings to store.
     num_shards: The number of output shards.
+    embedding_type: Which features to embed.
   """
   input_collection = (
       pipeline | 'ReadInputTFRecord' >> beam.io.tfrecordio.ReadFromTFRecord(
@@ -264,7 +272,7 @@ def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
       | 'AddKeys' >> beam.Map(add_keys))
   output_collection = input_collection | 'ExtractEmbedding' >> beam.ParDo(
       GenerateEmbeddingDataFn(model_dir, top_k_embedding_count,
-                              bottom_k_embedding_count))
+                              bottom_k_embedding_count, embedding_type))
   output_collection = output_collection | 'Reshuffle' >> beam.Reshuffle()
   _ = output_collection | 'DropKeys' >> beam.Map(
       drop_keys) | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord(
@@ -315,6 +323,12 @@ def parse_args(argv):
       dest='num_shards',
       default=0,
       help='Number of output shards.')
+  parser.add_argument(
+      '--embedding_type',
+      dest='embedding_type',
+      default='final_box_features',
+      help='What features to embed, supports `final_box_features`, '
+      '`rpn_box_features`.')
   beam_args, pipeline_args = parser.parse_known_args(argv)
   return beam_args, pipeline_args
@@ -346,7 +360,8 @@ def main(argv=None, save_main_session=True):
         args.embedding_model_dir,
         args.top_k_embedding_count,
         args.bottom_k_embedding_count,
-        args.num_shards)
+        args.num_shards,
+        args.embedding_type)
     p.run()
@@ -307,12 +307,14 @@ class GenerateEmbeddingData(tf.test.TestCase):
     top_k_embedding_count = 1
     bottom_k_embedding_count = 0
     num_shards = 1
+    embedding_type = 'final_box_features'
     pipeline_options = beam.options.pipeline_options.PipelineOptions(
         runner='DirectRunner')
     p = beam.Pipeline(options=pipeline_options)
     generate_embedding_data.construct_pipeline(
         p, input_tfrecord, output_tfrecord, saved_model_path,
-        top_k_embedding_count, bottom_k_embedding_count, num_shards)
+        top_k_embedding_count, bottom_k_embedding_count, num_shards,
+        embedding_type)
     p.run()
     filenames = tf.io.gfile.glob(
         output_tfrecord + '-?????-of-?????')
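
For downstream users, the updated test doubles as a usage template. A hedged sketch of wiring the new argument through `construct_pipeline` (the paths and the import line are assumptions; the positional order matches the diff):

import apache_beam as beam
# Assumed import path, mirroring the test module's usage.
from object_detection.dataset_tools.context_rcnn import generate_embedding_data

pipeline_options = beam.options.pipeline_options.PipelineOptions(
    runner='DirectRunner')
p = beam.Pipeline(options=pipeline_options)
generate_embedding_data.construct_pipeline(
    p, 'input.tfrecord', 'embeddings.tfrecord', '/tmp/exported_model',
    top_k_embedding_count=1, bottom_k_embedding_count=0, num_shards=1,
    embedding_type='rpn_box_features')
p.run()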