Commit df103208 authored by TF Object Detection Team's avatar TF Object Detection Team
Browse files

Merge pull request #8909 from kmindspark:singleframe2

PiperOrigin-RevId: 322234001
parents 363a36cd dbc211f2
......@@ -1094,8 +1094,12 @@ def get_reduce_to_frame_fn(input_reader_config, is_training):
num_frames = tf.cast(
tf.shape(tensor_dict[fields.InputDataFields.source_id])[0],
dtype=tf.int32)
frame_index = tf.random.uniform((), minval=0, maxval=num_frames,
dtype=tf.int32)
if input_reader_config.frame_index == -1:
frame_index = tf.random.uniform((), minval=0, maxval=num_frames,
dtype=tf.int32)
else:
frame_index = tf.constant(input_reader_config.frame_index,
dtype=tf.int32)
out_tensor_dict = {}
for key in tensor_dict:
if key in fields.SEQUENCE_FIELDS:
......
......@@ -61,7 +61,7 @@ def _get_configs_for_model(model_name):
configs, kwargs_dict=override_dict)
def _get_configs_for_model_sequence_example(model_name):
def _get_configs_for_model_sequence_example(model_name, frame_index=-1):
"""Returns configurations for model."""
fname = os.path.join(tf.resource_loader.get_data_files_path(),
'test_data/' + model_name + '.config')
......@@ -74,7 +74,8 @@ def _get_configs_for_model_sequence_example(model_name):
override_dict = {
'train_input_path': data_path,
'eval_input_path': data_path,
'label_map_path': label_map_path
'label_map_path': label_map_path,
'frame_index': frame_index
}
return config_util.merge_external_params_with_configs(
configs, kwargs_dict=override_dict)
......@@ -312,6 +313,46 @@ class InputFnTest(test_case.TestCase, parameterized.TestCase):
tf.float32,
labels[fields.InputDataFields.groundtruth_weights].dtype)
def test_context_rcnn_resnet50_train_input_with_sequence_example_frame_index(
self, train_batch_size=8):
"""Tests the training input function for FasterRcnnResnet50."""
configs = _get_configs_for_model_sequence_example(
'context_rcnn_camera_trap', frame_index=2)
model_config = configs['model']
train_config = configs['train_config']
train_config.batch_size = train_batch_size
train_input_fn = inputs.create_train_input_fn(
train_config, configs['train_input_config'], model_config)
features, labels = _make_initializable_iterator(train_input_fn()).get_next()
self.assertAllEqual([train_batch_size, 640, 640, 3],
features[fields.InputDataFields.image].shape.as_list())
self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
self.assertAllEqual([train_batch_size],
features[inputs.HASH_KEY].shape.as_list())
self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
self.assertAllEqual(
[train_batch_size, 100, 4],
labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_boxes].dtype)
self.assertAllEqual(
[train_batch_size, 100, model_config.faster_rcnn.num_classes],
labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_classes].dtype)
self.assertAllEqual(
[train_batch_size, 100],
labels[fields.InputDataFields.groundtruth_weights].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_weights].dtype)
self.assertAllEqual(
[train_batch_size, 100, model_config.faster_rcnn.num_classes],
labels[fields.InputDataFields.groundtruth_confidences].shape.as_list())
self.assertEqual(
tf.float32,
labels[fields.InputDataFields.groundtruth_confidences].dtype)
def test_ssd_inceptionV2_train_input(self):
"""Tests the training input function for SSDInceptionV2."""
configs = _get_configs_for_model('ssd_inception_v2_pets')
......
......@@ -31,7 +31,7 @@ enum InputType {
TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input
}
// Next id: 32
// Next id: 33
message InputReader {
// Name of input reader. Typically used to describe the dataset that is read
// by this input reader.
......@@ -133,6 +133,10 @@ message InputReader {
// Whether input data type is tf.Examples or tf.SequenceExamples
optional InputType input_type = 30 [default = TF_EXAMPLE];
// Which frame to choose from the input if Sequence Example. -1 indicates
// random choice.
optional int32 frame_index = 32 [default = -1];
oneof input_reader {
TFRecordInputReader tf_record_input_reader = 8;
ExternalInputReader external_input_reader = 9;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment