Commit df103208 authored by TF Object Detection Team's avatar TF Object Detection Team
Browse files

Merge pull request #8909 from kmindspark:singleframe2

PiperOrigin-RevId: 322234001
parents 363a36cd dbc211f2
...@@ -1094,8 +1094,12 @@ def get_reduce_to_frame_fn(input_reader_config, is_training): ...@@ -1094,8 +1094,12 @@ def get_reduce_to_frame_fn(input_reader_config, is_training):
num_frames = tf.cast( num_frames = tf.cast(
tf.shape(tensor_dict[fields.InputDataFields.source_id])[0], tf.shape(tensor_dict[fields.InputDataFields.source_id])[0],
dtype=tf.int32) dtype=tf.int32)
if input_reader_config.frame_index == -1:
frame_index = tf.random.uniform((), minval=0, maxval=num_frames, frame_index = tf.random.uniform((), minval=0, maxval=num_frames,
dtype=tf.int32) dtype=tf.int32)
else:
frame_index = tf.constant(input_reader_config.frame_index,
dtype=tf.int32)
out_tensor_dict = {} out_tensor_dict = {}
for key in tensor_dict: for key in tensor_dict:
if key in fields.SEQUENCE_FIELDS: if key in fields.SEQUENCE_FIELDS:
......
...@@ -61,7 +61,7 @@ def _get_configs_for_model(model_name): ...@@ -61,7 +61,7 @@ def _get_configs_for_model(model_name):
configs, kwargs_dict=override_dict) configs, kwargs_dict=override_dict)
def _get_configs_for_model_sequence_example(model_name): def _get_configs_for_model_sequence_example(model_name, frame_index=-1):
"""Returns configurations for model.""" """Returns configurations for model."""
fname = os.path.join(tf.resource_loader.get_data_files_path(), fname = os.path.join(tf.resource_loader.get_data_files_path(),
'test_data/' + model_name + '.config') 'test_data/' + model_name + '.config')
...@@ -74,7 +74,8 @@ def _get_configs_for_model_sequence_example(model_name): ...@@ -74,7 +74,8 @@ def _get_configs_for_model_sequence_example(model_name):
override_dict = { override_dict = {
'train_input_path': data_path, 'train_input_path': data_path,
'eval_input_path': data_path, 'eval_input_path': data_path,
'label_map_path': label_map_path 'label_map_path': label_map_path,
'frame_index': frame_index
} }
return config_util.merge_external_params_with_configs( return config_util.merge_external_params_with_configs(
configs, kwargs_dict=override_dict) configs, kwargs_dict=override_dict)
...@@ -312,6 +313,46 @@ class InputFnTest(test_case.TestCase, parameterized.TestCase): ...@@ -312,6 +313,46 @@ class InputFnTest(test_case.TestCase, parameterized.TestCase):
tf.float32, tf.float32,
labels[fields.InputDataFields.groundtruth_weights].dtype) labels[fields.InputDataFields.groundtruth_weights].dtype)
def test_context_rcnn_resnet50_train_input_with_sequence_example_frame_index(
self, train_batch_size=8):
"""Tests the training input function for FasterRcnnResnet50."""
configs = _get_configs_for_model_sequence_example(
'context_rcnn_camera_trap', frame_index=2)
model_config = configs['model']
train_config = configs['train_config']
train_config.batch_size = train_batch_size
train_input_fn = inputs.create_train_input_fn(
train_config, configs['train_input_config'], model_config)
features, labels = _make_initializable_iterator(train_input_fn()).get_next()
self.assertAllEqual([train_batch_size, 640, 640, 3],
features[fields.InputDataFields.image].shape.as_list())
self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
self.assertAllEqual([train_batch_size],
features[inputs.HASH_KEY].shape.as_list())
self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
self.assertAllEqual(
[train_batch_size, 100, 4],
labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_boxes].dtype)
self.assertAllEqual(
[train_batch_size, 100, model_config.faster_rcnn.num_classes],
labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_classes].dtype)
self.assertAllEqual(
[train_batch_size, 100],
labels[fields.InputDataFields.groundtruth_weights].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_weights].dtype)
self.assertAllEqual(
[train_batch_size, 100, model_config.faster_rcnn.num_classes],
labels[fields.InputDataFields.groundtruth_confidences].shape.as_list())
self.assertEqual(
tf.float32,
labels[fields.InputDataFields.groundtruth_confidences].dtype)
def test_ssd_inceptionV2_train_input(self): def test_ssd_inceptionV2_train_input(self):
"""Tests the training input function for SSDInceptionV2.""" """Tests the training input function for SSDInceptionV2."""
configs = _get_configs_for_model('ssd_inception_v2_pets') configs = _get_configs_for_model('ssd_inception_v2_pets')
......
...@@ -31,7 +31,7 @@ enum InputType { ...@@ -31,7 +31,7 @@ enum InputType {
TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input TF_SEQUENCE_EXAMPLE = 2; // TfSequenceExample Input
} }
// Next id: 32 // Next id: 33
message InputReader { message InputReader {
// Name of input reader. Typically used to describe the dataset that is read // Name of input reader. Typically used to describe the dataset that is read
// by this input reader. // by this input reader.
...@@ -133,6 +133,10 @@ message InputReader { ...@@ -133,6 +133,10 @@ message InputReader {
// Whether input data type is tf.Examples or tf.SequenceExamples // Whether input data type is tf.Examples or tf.SequenceExamples
optional InputType input_type = 30 [default = TF_EXAMPLE]; optional InputType input_type = 30 [default = TF_EXAMPLE];
// Which frame to choose from the input if Sequence Example. -1 indicates
// random choice.
optional int32 frame_index = 32 [default = -1];
oneof input_reader { oneof input_reader {
TFRecordInputReader tf_record_input_reader = 8; TFRecordInputReader tf_record_input_reader = 8;
ExternalInputReader external_input_reader = 9; ExternalInputReader external_input_reader = 9;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment