Commit dbc211f2 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

adding option to select specific frame index

parent 57c08a08
...@@ -1094,8 +1094,12 @@ def get_reduce_to_frame_fn(input_reader_config, is_training): ...@@ -1094,8 +1094,12 @@ def get_reduce_to_frame_fn(input_reader_config, is_training):
num_frames = tf.cast( num_frames = tf.cast(
tf.shape(tensor_dict[fields.InputDataFields.source_id])[0], tf.shape(tensor_dict[fields.InputDataFields.source_id])[0],
dtype=tf.int32) dtype=tf.int32)
frame_index = tf.random.uniform((), minval=0, maxval=num_frames, if input_reader_config.frame_index == -1:
dtype=tf.int32) frame_index = tf.random.uniform((), minval=0, maxval=num_frames,
dtype=tf.int32)
else:
frame_index = tf.constant(input_reader_config.frame_index,
dtype=tf.int32)
out_tensor_dict = {} out_tensor_dict = {}
for key in tensor_dict: for key in tensor_dict:
if key in fields.SEQUENCE_FIELDS: if key in fields.SEQUENCE_FIELDS:
......
...@@ -312,6 +312,46 @@ class InputFnTest(test_case.TestCase, parameterized.TestCase): ...@@ -312,6 +312,46 @@ class InputFnTest(test_case.TestCase, parameterized.TestCase):
tf.float32, tf.float32,
labels[fields.InputDataFields.groundtruth_weights].dtype) labels[fields.InputDataFields.groundtruth_weights].dtype)
def test_context_rcnn_resnet50_train_input_with_sequence_example_frame_index(
self, train_batch_size=8):
"""Tests the training input function for FasterRcnnResnet50."""
configs = _get_configs_for_model_sequence_example(
'context_rcnn_camera_trap', frame_index=2)
model_config = configs['model']
train_config = configs['train_config']
train_config.batch_size = train_batch_size
train_input_fn = inputs.create_train_input_fn(
train_config, configs['train_input_config'], model_config)
features, labels = _make_initializable_iterator(train_input_fn()).get_next()
self.assertAllEqual([train_batch_size, 640, 640, 3],
features[fields.InputDataFields.image].shape.as_list())
self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
self.assertAllEqual([train_batch_size],
features[inputs.HASH_KEY].shape.as_list())
self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
self.assertAllEqual(
[train_batch_size, 100, 4],
labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_boxes].dtype)
self.assertAllEqual(
[train_batch_size, 100, model_config.faster_rcnn.num_classes],
labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_classes].dtype)
self.assertAllEqual(
[train_batch_size, 100],
labels[fields.InputDataFields.groundtruth_weights].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_weights].dtype)
self.assertAllEqual(
[train_batch_size, 100, model_config.faster_rcnn.num_classes],
labels[fields.InputDataFields.groundtruth_confidences].shape.as_list())
self.assertEqual(
tf.float32,
labels[fields.InputDataFields.groundtruth_confidences].dtype)
def test_ssd_inceptionV2_train_input(self): def test_ssd_inceptionV2_train_input(self):
"""Tests the training input function for SSDInceptionV2.""" """Tests the training input function for SSDInceptionV2."""
configs = _get_configs_for_model('ssd_inception_v2_pets') configs = _get_configs_for_model('ssd_inception_v2_pets')
......
syntax = "proto2"; syntax = "proto2";
package object_detection.protos; package object_detection.protos;
...@@ -31,7 +32,7 @@ enum InputType { ...@@ -31,7 +32,7 @@ enum InputType {
TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input
} }
// Next id: 32 // Next id: 33
message InputReader { message InputReader {
// Name of input reader. Typically used to describe the dataset that is read // Name of input reader. Typically used to describe the dataset that is read
// by this input reader. // by this input reader.
...@@ -133,6 +134,9 @@ message InputReader { ...@@ -133,6 +134,9 @@ message InputReader {
// Whether input data type is tf.Examples or tf.SequenceExamples // Whether input data type is tf.Examples or tf.SequenceExamples
optional InputType input_type = 30 [default = TF_EXAMPLE]; optional InputType input_type = 30 [default = TF_EXAMPLE];
// Which frame to choose from the input if Sequence Example. -1 indicates random choice.
optional int32 frame_index = 32 [default = -1];
oneof input_reader { oneof input_reader {
TFRecordInputReader tf_record_input_reader = 8; TFRecordInputReader tf_record_input_reader = 8;
ExternalInputReader external_input_reader = 9; ExternalInputReader external_input_reader = 9;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment