Commit 101282ad authored by Sara Beery's avatar Sara Beery Committed by TF Object Detection Team
Browse files

Updating Context R-CNN dataset tools to fix a bug in context feature bank...

Updating Context R-CNN dataset tools to fix a bug in context feature bank building logic when the number of context examples in the time horizon is greater than the specified maximum, and adding capabilities to track and save the number of embeddings stored per image.

PiperOrigin-RevId: 345548385
parent 1e205552
...@@ -294,20 +294,46 @@ class SortGroupedDataFn(beam.DoFn): ...@@ -294,20 +294,46 @@ class SortGroupedDataFn(beam.DoFn):
sorted_example_list = sorted(example_list, key=sorting_fn) sorted_example_list = sorted(example_list, key=sorting_fn)
num_embeddings = 0
for example in sorted_example_list:
num_embeddings += example.features.feature[
'image/embedding_count'].int64_list.value[0]
self._num_examples_processed.inc(1) self._num_examples_processed.inc(1)
if len(sorted_example_list) > self._max_num_elements_in_context_features: # To handle cases where there are more context embeddings within
# the time horizon than the specified maximum, we split the context group
# into subsets sequentially in time, with each subset having the maximum
# number of context embeddings except the final one, which holds the
# remainder.
if num_embeddings > self._max_num_elements_in_context_features:
leftovers = sorted_example_list leftovers = sorted_example_list
output_list = [] output_list = []
count = 0 count = 0
self._too_many_elements.inc(1) self._too_many_elements.inc(1)
while len(leftovers) > self._max_num_elements_in_context_features: num_embeddings = 0
max_idx = 0
for idx, example in enumerate(leftovers):
num_embeddings += example.features.feature[
'image/embedding_count'].int64_list.value[0]
if num_embeddings <= self._max_num_elements_in_context_features:
max_idx = idx
while num_embeddings > self._max_num_elements_in_context_features:
self._split_elements.inc(1) self._split_elements.inc(1)
new_key = key + six.ensure_binary('_' + str(count)) new_key = key + six.ensure_binary('_' + str(count))
new_list = leftovers[:self._max_num_elements_in_context_features] new_list = leftovers[:max_idx]
output_list.append((new_key, new_list)) output_list.append((new_key, new_list))
leftovers = leftovers[:self._max_num_elements_in_context_features] leftovers = leftovers[max_idx:]
count += 1 count += 1
num_embeddings = 0
max_idx = 0
for idx, example in enumerate(leftovers):
num_embeddings += example.features.feature[
'image/embedding_count'].int64_list.value[0]
if num_embeddings <= self._max_num_elements_in_context_features:
max_idx = idx
new_key = key + six.ensure_binary('_' + str(count))
output_list.append((new_key, leftovers))
else: else:
output_list = [(key, sorted_example_list)] output_list = [(key, sorted_example_list)]
...@@ -454,12 +480,15 @@ class GenerateContextFn(beam.DoFn): ...@@ -454,12 +480,15 @@ class GenerateContextFn(beam.DoFn):
example_embedding = list(example.features.feature[ example_embedding = list(example.features.feature[
'image/embedding'].float_list.value) 'image/embedding'].float_list.value)
context_features.extend(example_embedding) context_features.extend(example_embedding)
example.features.feature[ num_embeddings = example.features.feature[
'context_features_idx'].int64_list.value.append(count) 'image/embedding_count'].int64_list.value[0]
count += 1
example_image_id = example.features.feature[ example_image_id = example.features.feature[
'image/source_id'].bytes_list.value[0] 'image/source_id'].bytes_list.value[0]
context_features_image_id_list.append(example_image_id) for _ in range(num_embeddings):
example.features.feature[
'context_features_idx'].int64_list.value.append(count)
count += 1
context_features_image_id_list.append(example_image_id)
if not example_embedding: if not example_embedding:
example_embedding.append(np.zeros(self._context_feature_length)) example_embedding.append(np.zeros(self._context_feature_length))
......
...@@ -98,7 +98,8 @@ class GenerateContextDataTest(tf.test.TestCase): ...@@ -98,7 +98,8 @@ class GenerateContextDataTest(tf.test.TestCase):
six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 0, 0)))), six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 0, 0)))),
'image/embedding': FloatListFeature([0.1, 0.2, 0.3]), 'image/embedding': FloatListFeature([0.1, 0.2, 0.3]),
'image/embedding_score': FloatListFeature([0.9]), 'image/embedding_score': FloatListFeature([0.9]),
'image/embedding_length': Int64Feature(3) 'image/embedding_length': Int64Feature(3),
'image/embedding_count': Int64Feature(1)
})) }))
...@@ -127,7 +128,8 @@ class GenerateContextDataTest(tf.test.TestCase): ...@@ -127,7 +128,8 @@ class GenerateContextDataTest(tf.test.TestCase):
six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 1, 0)))), six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 1, 0)))),
'image/embedding': FloatListFeature([0.4, 0.5, 0.6]), 'image/embedding': FloatListFeature([0.4, 0.5, 0.6]),
'image/embedding_score': FloatListFeature([0.9]), 'image/embedding_score': FloatListFeature([0.9]),
'image/embedding_length': Int64Feature(3) 'image/embedding_length': Int64Feature(3),
'image/embedding_count': Int64Feature(1)
})) }))
return example.SerializeToString() return example.SerializeToString()
......
...@@ -142,13 +142,14 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -142,13 +142,14 @@ class GenerateEmbeddingDataFn(beam.DoFn):
session_lock = threading.Lock() session_lock = threading.Lock()
def __init__(self, model_dir, top_k_embedding_count, def __init__(self, model_dir, top_k_embedding_count,
bottom_k_embedding_count): bottom_k_embedding_count, embedding_type='final_box_features'):
"""Initialization function. """Initialization function.
Args: Args:
model_dir: A directory containing saved model. model_dir: A directory containing saved model.
top_k_embedding_count: the number of high-confidence embeddings to store top_k_embedding_count: the number of high-confidence embeddings to store
bottom_k_embedding_count: the number of low-confidence embeddings to store bottom_k_embedding_count: the number of low-confidence embeddings to store
embedding_type: One of 'final_box_features', 'rpn_box_features'
""" """
self._model_dir = model_dir self._model_dir = model_dir
self._session = None self._session = None
...@@ -156,6 +157,7 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -156,6 +157,7 @@ class GenerateEmbeddingDataFn(beam.DoFn):
'embedding_data_generation', 'num_tf_examples_processed') 'embedding_data_generation', 'num_tf_examples_processed')
self._top_k_embedding_count = top_k_embedding_count self._top_k_embedding_count = top_k_embedding_count
self._bottom_k_embedding_count = bottom_k_embedding_count self._bottom_k_embedding_count = bottom_k_embedding_count
self._embedding_type = embedding_type
def setup(self): def setup(self):
self._load_inference_model() self._load_inference_model()
...@@ -188,7 +190,12 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -188,7 +190,12 @@ class GenerateEmbeddingDataFn(beam.DoFn):
detections = self._detect_fn.signatures['serving_default']( detections = self._detect_fn.signatures['serving_default'](
(tf.expand_dims(tf.convert_to_tensor(tfexample), 0))) (tf.expand_dims(tf.convert_to_tensor(tfexample), 0)))
detection_features = detections['detection_features'] if self._embedding_type == 'final_box_features':
detection_features = detections['detection_features']
elif self._embedding_type == 'rpn_box_features':
detection_features = detections['cropped_rpn_box_features']
else:
raise ValueError('embedding type not supported')
detection_boxes = detections['detection_boxes'] detection_boxes = detections['detection_boxes']
num_detections = detections['num_detections'] num_detections = detections['num_detections']
detection_scores = detections['detection_scores'] detection_scores = detections['detection_scores']
...@@ -245,7 +252,7 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -245,7 +252,7 @@ class GenerateEmbeddingDataFn(beam.DoFn):
def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir, def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
top_k_embedding_count, bottom_k_embedding_count, top_k_embedding_count, bottom_k_embedding_count,
num_shards): num_shards, embedding_type):
"""Returns a beam pipeline to run object detection inference. """Returns a beam pipeline to run object detection inference.
Args: Args:
...@@ -257,6 +264,7 @@ def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir, ...@@ -257,6 +264,7 @@ def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
top_k_embedding_count: The number of high-confidence embeddings to store. top_k_embedding_count: The number of high-confidence embeddings to store.
bottom_k_embedding_count: The number of low-confidence embeddings to store. bottom_k_embedding_count: The number of low-confidence embeddings to store.
num_shards: The number of output shards. num_shards: The number of output shards.
embedding_type: Which features to embed.
""" """
input_collection = ( input_collection = (
pipeline | 'ReadInputTFRecord' >> beam.io.tfrecordio.ReadFromTFRecord( pipeline | 'ReadInputTFRecord' >> beam.io.tfrecordio.ReadFromTFRecord(
...@@ -264,7 +272,7 @@ def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir, ...@@ -264,7 +272,7 @@ def construct_pipeline(pipeline, input_tfrecord, output_tfrecord, model_dir,
| 'AddKeys' >> beam.Map(add_keys)) | 'AddKeys' >> beam.Map(add_keys))
output_collection = input_collection | 'ExtractEmbedding' >> beam.ParDo( output_collection = input_collection | 'ExtractEmbedding' >> beam.ParDo(
GenerateEmbeddingDataFn(model_dir, top_k_embedding_count, GenerateEmbeddingDataFn(model_dir, top_k_embedding_count,
bottom_k_embedding_count)) bottom_k_embedding_count, embedding_type))
output_collection = output_collection | 'Reshuffle' >> beam.Reshuffle() output_collection = output_collection | 'Reshuffle' >> beam.Reshuffle()
_ = output_collection | 'DropKeys' >> beam.Map( _ = output_collection | 'DropKeys' >> beam.Map(
drop_keys) | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord( drop_keys) | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord(
...@@ -315,6 +323,12 @@ def parse_args(argv): ...@@ -315,6 +323,12 @@ def parse_args(argv):
dest='num_shards', dest='num_shards',
default=0, default=0,
help='Number of output shards.') help='Number of output shards.')
parser.add_argument(
'--embedding_type',
dest='embedding_type',
default='final_box_features',
help='What features to embed, supports `final_box_features`, '
'`rpn_box_features`.')
beam_args, pipeline_args = parser.parse_known_args(argv) beam_args, pipeline_args = parser.parse_known_args(argv)
return beam_args, pipeline_args return beam_args, pipeline_args
...@@ -346,7 +360,8 @@ def main(argv=None, save_main_session=True): ...@@ -346,7 +360,8 @@ def main(argv=None, save_main_session=True):
args.embedding_model_dir, args.embedding_model_dir,
args.top_k_embedding_count, args.top_k_embedding_count,
args.bottom_k_embedding_count, args.bottom_k_embedding_count,
args.num_shards) args.num_shards,
args.embedding_type)
p.run() p.run()
......
...@@ -307,12 +307,14 @@ class GenerateEmbeddingData(tf.test.TestCase): ...@@ -307,12 +307,14 @@ class GenerateEmbeddingData(tf.test.TestCase):
top_k_embedding_count = 1 top_k_embedding_count = 1
bottom_k_embedding_count = 0 bottom_k_embedding_count = 0
num_shards = 1 num_shards = 1
embedding_type = 'final_box_features'
pipeline_options = beam.options.pipeline_options.PipelineOptions( pipeline_options = beam.options.pipeline_options.PipelineOptions(
runner='DirectRunner') runner='DirectRunner')
p = beam.Pipeline(options=pipeline_options) p = beam.Pipeline(options=pipeline_options)
generate_embedding_data.construct_pipeline( generate_embedding_data.construct_pipeline(
p, input_tfrecord, output_tfrecord, saved_model_path, p, input_tfrecord, output_tfrecord, saved_model_path,
top_k_embedding_count, bottom_k_embedding_count, num_shards) top_k_embedding_count, bottom_k_embedding_count, num_shards,
embedding_type)
p.run() p.run()
filenames = tf.io.gfile.glob( filenames = tf.io.gfile.glob(
output_tfrecord + '-?????-of-?????') output_tfrecord + '-?????-of-?????')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment