Commit f6b4cbcd authored by Sara Beery, committed by TF Object Detection Team

Piping image id lists for context features for attention visualization

DEFAULT_VALUE_OK=Sets the num_prefetch_batches default to 2, since the autotune setting (-1) can cause memory issues for some models, such as Context R-CNN, that use sequence examples.

PiperOrigin-RevId: 345682268
parent 78f0e355
@@ -65,7 +65,8 @@ def build(input_reader_config):
   elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
     decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
         label_map_proto_file=label_map_proto_file,
-        load_context_features=input_reader_config.load_context_features)
+        load_context_features=input_reader_config.load_context_features,
+        load_context_image_ids=input_reader_config.load_context_image_ids)
     return decoder
   raise ValueError('Unsupported input_type in config.')
@@ -85,7 +85,8 @@ def build(input_reader_config):
     elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
       decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
           label_map_proto_file=label_map_proto_file,
-          load_context_features=input_reader_config.load_context_features)
+          load_context_features=input_reader_config.load_context_features,
+          load_context_image_ids=input_reader_config.load_context_image_ids)
       return decoder.decode(string_tensor)
     raise ValueError('Unsupported input_type.')
   raise ValueError('Unsupported input_reader_config.')
@@ -89,6 +89,8 @@ class InputDataFields(object):
       context_features, used for reshaping.
     valid_context_size: the valid context size, used in filtering the padded
       context features.
+    context_features_image_id_list: the list of image source ids corresponding
+      to the features in context_features
     image_format: format for the images, used to decode
     image_height: height of images, used to decode
     image_width: width of images, used to decode
@@ -136,6 +138,7 @@ class InputDataFields(object):
   context_features = 'context_features'
   context_feature_length = 'context_feature_length'
   valid_context_size = 'valid_context_size'
+  context_features_image_id_list = 'context_features_image_id_list'
   image_timestamps = 'image_timestamps'
   image_format = 'image_format'
   image_height = 'image_height'
@@ -117,11 +117,13 @@ class TfSequenceExampleDecoder(data_decoder.DataDecoder):
   Context R-CNN (see https://arxiv.org/abs/1912.03538):
     'image/context_features'
     'image/context_feature_length'
+    'image/context_features_image_id_list'
   """

   def __init__(self,
                label_map_proto_file,
                load_context_features=False,
+               load_context_image_ids=False,
                use_display_name=False,
                fully_annotated=False):
     """Constructs `TfSequenceExampleDecoder` object.
@@ -134,6 +136,8 @@ class TfSequenceExampleDecoder(data_decoder.DataDecoder):
       load_context_features: Whether to load information from context_features,
         to provide additional context to a detection model for training and/or
         inference
+      load_context_image_ids: Whether to load the corresponding image ids for
+        the context_features in order to visualize attention.
       use_display_name: whether or not to use the `display_name` for label
         mapping (instead of `name`). Only used if label_map_proto_file is
         provided.
@@ -207,6 +211,16 @@ class TfSequenceExampleDecoder(data_decoder.DataDecoder):
           tf.FixedLenFeature((), tf.int64))
       self._items_to_handlers[fields.InputDataFields.context_feature_length] = (
           slim_example_decoder.Tensor('image/context_feature_length'))
+    if load_context_image_ids:
+      self._context_keys_to_features['image/context_features_image_id_list'] = (
+          tf.VarLenFeature(dtype=tf.string))
+      self._items_to_handlers[
+          fields.InputDataFields.context_features_image_id_list] = (
+              slim_example_decoder.Tensor(
+                  'image/context_features_image_id_list',
+                  default_value=''))
     self._fully_annotated = fully_annotated

   def decode(self, tf_seq_example_string_tensor):
@@ -239,6 +253,8 @@ class TfSequenceExampleDecoder(data_decoder.DataDecoder):
         the length of each feature in context_features
       fields.InputDataFields.image: a [num_frames] string tensor with
         the encoded images.
+      fields.InputDataFields.context_features_image_id_list: a 1D vector
+        of shape [num_context_features] containing string tensors.
     """
     serialized_example = tf.reshape(tf_seq_example_string_tensor, shape=[])
     decoder = slim_example_decoder.TFSequenceExampleDecoder(
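For orientation, a rough usage sketch (not part of this commit) of how the new flag surfaces once a serialized tf.train.SequenceExample is decoded; the import paths follow the usual TF Object Detection layout and the helper name is made up:

from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_sequence_example_decoder


def decode_with_context_image_ids(serialized_seq_example, label_map_proto_file):
  """Decodes a scalar string tensor holding a serialized SequenceExample."""
  decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
      label_map_proto_file=label_map_proto_file,
      load_context_features=True,
      load_context_image_ids=True)
  tensor_dict = decoder.decode(serialized_seq_example)
  # context_features: [num_context_features, context_feature_length] floats.
  # context_features_image_id_list: [num_context_features] strings naming the
  # source image of each context row, which is what attention visualization
  # keys off.
  return (tensor_dict[fields.InputDataFields.context_features],
          tensor_dict[fields.InputDataFields.context_features_image_id_list])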
@@ -120,6 +120,145 @@ class TfSequenceExampleDecoderTest(test_case.TestCase):
     self.assertAllEqual(expected_groundtruth_classes,
                         tensor_dict_out[flds.groundtruth_classes])

+  def test_decode_sequence_example_context(self):
+    num_frames = 4
+    image_height = 20
+    image_width = 30
+
+    expected_groundtruth_boxes = [
+        [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
+        [[0.2, 0.2, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]],
+        [[0.0, 0.0, 1.0, 1.0], [0.1, 0.1, 0.2, 0.2]],
+        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
+    ]
+    expected_groundtruth_classes = [
+        [-1, -1],
+        [-1, 1],
+        [1, 2],
+        [-1, -1]
+    ]
+
+    expected_context_features = np.array(
+        [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]], dtype=np.float32)
+
+    flds = fields.InputDataFields
+    encoded_images = self._make_random_serialized_jpeg_images(
+        num_frames, image_height, image_width)
+
+    def graph_fn():
+      label_map_proto_file = os.path.join(self.get_temp_dir(), 'labelmap.pbtxt')
+      self._create_label_map(label_map_proto_file)
+      decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
+          label_map_proto_file=label_map_proto_file,
+          load_context_features=True)
+      sequence_example_serialized = seq_example_util.make_sequence_example(
+          dataset_name='video_dataset',
+          video_id='video',
+          encoded_images=encoded_images,
+          image_height=image_height,
+          image_width=image_width,
+          image_format='JPEG',
+          image_source_ids=[str(i) for i in range(num_frames)],
+          is_annotated=[[1], [1], [1], [1]],
+          bboxes=[
+              [[0., 0., 1., 1.]],  # Frame 0.
+              [[0.2, 0.2, 1., 1.],
+               [0., 0., 1., 1.]],  # Frame 1.
+              [[0., 0., 1., 1.],  # Frame 2.
+               [0.1, 0.1, 0.2, 0.2]],
+              [[]],  # Frame 3.
+          ],
+          label_strings=[
+              ['fox'],  # Frame 0. Fox will be filtered out.
+              ['fox', 'dog'],  # Frame 1. Fox will be filtered out.
+              ['dog', 'cat'],  # Frame 2.
+              [],  # Frame 3
+          ],
+          context_features=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
+          context_feature_length=[3],
+          context_features_image_id_list=[b'im_1', b'im_2']
+      ).SerializeToString()
+
+      example_string_tensor = tf.convert_to_tensor(sequence_example_serialized)
+      return decoder.decode(example_string_tensor)
+
+    tensor_dict_out = self.execute(graph_fn, [])
+    self.assertAllClose(expected_groundtruth_boxes,
+                        tensor_dict_out[flds.groundtruth_boxes])
+    self.assertAllEqual(expected_groundtruth_classes,
+                        tensor_dict_out[flds.groundtruth_classes])
+    self.assertAllClose(expected_context_features,
+                        tensor_dict_out[flds.context_features])
+
+  def test_decode_sequence_example_context_image_id_list(self):
+    num_frames = 4
+    image_height = 20
+    image_width = 30
+
+    expected_groundtruth_boxes = [
+        [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
+        [[0.2, 0.2, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]],
+        [[0.0, 0.0, 1.0, 1.0], [0.1, 0.1, 0.2, 0.2]],
+        [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
+    ]
+    expected_groundtruth_classes = [
+        [-1, -1],
+        [-1, 1],
+        [1, 2],
+        [-1, -1]
+    ]
+
+    expected_context_image_ids = [b'im_1', b'im_2']
+
+    flds = fields.InputDataFields
+    encoded_images = self._make_random_serialized_jpeg_images(
+        num_frames, image_height, image_width)
+
+    def graph_fn():
+      label_map_proto_file = os.path.join(self.get_temp_dir(), 'labelmap.pbtxt')
+      self._create_label_map(label_map_proto_file)
+      decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
+          label_map_proto_file=label_map_proto_file,
+          load_context_image_ids=True)
+      sequence_example_serialized = seq_example_util.make_sequence_example(
+          dataset_name='video_dataset',
+          video_id='video',
+          encoded_images=encoded_images,
+          image_height=image_height,
+          image_width=image_width,
+          image_format='JPEG',
+          image_source_ids=[str(i) for i in range(num_frames)],
+          is_annotated=[[1], [1], [1], [1]],
+          bboxes=[
+              [[0., 0., 1., 1.]],  # Frame 0.
+              [[0.2, 0.2, 1., 1.],
+               [0., 0., 1., 1.]],  # Frame 1.
+              [[0., 0., 1., 1.],  # Frame 2.
+               [0.1, 0.1, 0.2, 0.2]],
+              [[]],  # Frame 3.
+          ],
+          label_strings=[
+              ['fox'],  # Frame 0. Fox will be filtered out.
+              ['fox', 'dog'],  # Frame 1. Fox will be filtered out.
+              ['dog', 'cat'],  # Frame 2.
+              [],  # Frame 3
+          ],
+          context_features=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
+          context_feature_length=[3],
+          context_features_image_id_list=[b'im_1', b'im_2']
+      ).SerializeToString()
+
+      example_string_tensor = tf.convert_to_tensor(sequence_example_serialized)
+      return decoder.decode(example_string_tensor)
+
+    tensor_dict_out = self.execute(graph_fn, [])
+    self.assertAllClose(expected_groundtruth_boxes,
+                        tensor_dict_out[flds.groundtruth_boxes])
+    self.assertAllEqual(expected_groundtruth_classes,
+                        tensor_dict_out[flds.groundtruth_classes])
+    self.assertAllEqual(expected_context_image_ids,
+                        tensor_dict_out[flds.context_features_image_id_list])
+
   def test_decode_sequence_example_negative_clip(self):
     num_frames = 4
     image_height = 20
@@ -171,7 +171,10 @@ def make_sequence_example(dataset_name,
                           detection_bboxes=None,
                           detection_classes=None,
                           detection_scores=None,
-                          use_strs_for_source_id=False):
+                          use_strs_for_source_id=False,
+                          context_features=None,
+                          context_feature_length=None,
+                          context_features_image_id_list=None):
   """Constructs tf.SequenceExamples.

   Args:
@@ -203,6 +206,12 @@ def make_sequence_example(dataset_name,
       each frame.
     use_strs_for_source_id: (Optional) Whether to write the source IDs as
       strings rather than byte lists of characters.
+    context_features: (Optional) A list or numpy array of features to use in
+      Context R-CNN, of length num_context_features * context_feature_length.
+    context_feature_length: (Optional) The length of each context feature, used
+      for reshaping.
+    context_features_image_id_list: (Optional) A list of image ids of length
+      num_context_features corresponding to the context features.

   Returns:
     A tf.train.SequenceExample.
@@ -273,6 +282,16 @@ def make_sequence_example(dataset_name,
     feature_list['predicted/region/label/confidence'] = sequence_float_feature(
         detection_scores)

+  if context_features is not None:
+    context_dict['image/context_features'] = context_float_feature(
+        context_features)
+  if context_feature_length is not None:
+    context_dict['image/context_feature_length'] = context_int64_feature(
+        context_feature_length)
+  if context_features_image_id_list is not None:
+    context_dict['image/context_features_image_id_list'] = (
+        context_bytes_feature(context_features_image_id_list))
+
   context = tf.train.Features(feature=context_dict)
   feature_lists = tf.train.FeatureLists(feature_list=feature_list)
@@ -204,6 +204,123 @@ class SeqExampleUtilTest(tf.test.TestCase):
         [],
         seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:])

+  def test_make_labeled_example_with_context_features(self):
+    num_frames = 2
+    image_height = 100
+    image_width = 200
+    dataset_name = b'unlabeled_dataset'
+    video_id = b'video_000'
+    labels = [b'dog', b'cat']
+    images = tf.cast(tf.random.uniform(
+        [num_frames, image_height, image_width, 3],
+        maxval=256,
+        dtype=tf.int32), dtype=tf.uint8)
+    images_list = tf.unstack(images, axis=0)
+    encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
+    encoded_images = self.materialize_tensors(encoded_images_list)
+    timestamps = [100000, 110000]
+    is_annotated = [1, 0]
+    bboxes = [
+        np.array([[0., 0., 0., 0.],
+                  [0., 0., 1., 1.]], dtype=np.float32),
+        np.zeros([0, 4], dtype=np.float32)
+    ]
+    label_strings = [
+        np.array(labels),
+        np.array([])
+    ]
+    context_features = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
+    context_feature_length = [3]
+    context_features_image_id_list = [b'im_1', b'im_2']
+
+    seq_example = seq_example_util.make_sequence_example(
+        dataset_name=dataset_name,
+        video_id=video_id,
+        encoded_images=encoded_images,
+        image_height=image_height,
+        image_width=image_width,
+        timestamps=timestamps,
+        is_annotated=is_annotated,
+        bboxes=bboxes,
+        label_strings=label_strings,
+        context_features=context_features,
+        context_feature_length=context_feature_length,
+        context_features_image_id_list=context_features_image_id_list)
+
+    context_feature_dict = seq_example.context.feature
+    self.assertEqual(
+        dataset_name,
+        context_feature_dict['example/dataset_name'].bytes_list.value[0])
+    self.assertEqual(
+        timestamps[0],
+        context_feature_dict['clip/start/timestamp'].int64_list.value[0])
+    self.assertEqual(
+        timestamps[-1],
+        context_feature_dict['clip/end/timestamp'].int64_list.value[0])
+    self.assertEqual(
+        num_frames,
+        context_feature_dict['clip/frames'].int64_list.value[0])
+    self.assertAllClose(
+        context_features,
+        context_feature_dict['image/context_features'].float_list.value[:])
+    self.assertEqual(
+        context_feature_length[0],
+        context_feature_dict[
+            'image/context_feature_length'].int64_list.value[0])
+    self.assertEqual(
+        context_features_image_id_list,
+        context_feature_dict[
+            'image/context_features_image_id_list'].bytes_list.value[:])
+
+    seq_feature_dict = seq_example.feature_lists.feature_list
+    self.assertLen(
+        seq_feature_dict['image/encoded'].feature[:],
+        num_frames)
+    actual_timestamps = [
+        feature.int64_list.value[0] for feature
+        in seq_feature_dict['image/timestamp'].feature]
+    self.assertAllEqual(timestamps, actual_timestamps)
+    # Frame 0.
+    self.assertAllEqual(
+        is_annotated[0],
+        seq_feature_dict['region/is_annotated'].feature[0].int64_list.value[0])
+    self.assertAllClose(
+        [0., 0.],
+        seq_feature_dict['region/bbox/ymin'].feature[0].float_list.value[:])
+    self.assertAllClose(
+        [0., 0.],
+        seq_feature_dict['region/bbox/xmin'].feature[0].float_list.value[:])
+    self.assertAllClose(
+        [0., 1.],
+        seq_feature_dict['region/bbox/ymax'].feature[0].float_list.value[:])
+    self.assertAllClose(
+        [0., 1.],
+        seq_feature_dict['region/bbox/xmax'].feature[0].float_list.value[:])
+    self.assertAllEqual(
+        labels,
+        seq_feature_dict['region/label/string'].feature[0].bytes_list.value[:])
+
+    # Frame 1.
+    self.assertAllEqual(
+        is_annotated[1],
+        seq_feature_dict['region/is_annotated'].feature[1].int64_list.value[0])
+    self.assertAllClose(
+        [],
+        seq_feature_dict['region/bbox/ymin'].feature[1].float_list.value[:])
+    self.assertAllClose(
+        [],
+        seq_feature_dict['region/bbox/xmin'].feature[1].float_list.value[:])
+    self.assertAllClose(
+        [],
+        seq_feature_dict['region/bbox/ymax'].feature[1].float_list.value[:])
+    self.assertAllClose(
+        [],
+        seq_feature_dict['region/bbox/xmax'].feature[1].float_list.value[:])
+    self.assertAllEqual(
+        [],
+        seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:])
+
   def test_make_labeled_example_with_predictions(self):
     num_frames = 2
     image_height = 100
@@ -419,7 +419,6 @@ def pad_input_data_to_static_shapes(tensor_dict,
       max_num_context_features is not specified and context_features is in the
       tensor dict.
   """
-
   if not spatial_image_shape or spatial_image_shape == [-1, -1]:
     height, width = None, None
   else:
@@ -539,11 +538,14 @@ def pad_input_data_to_static_shapes(tensor_dict,
     padding_shapes[input_fields.context_features] = padding_shape

     tensor_shape = tf.shape(
-        tensor_dict[input_fields.context_features])
-    tensor_dict[input_fields.valid_context_size] = tensor_shape[0]
-    padding_shapes[input_fields.valid_context_size] = []
-  if input_fields.context_feature_length in tensor_dict:
-    padding_shapes[input_fields.context_feature_length] = []
+        tensor_dict[fields.InputDataFields.context_features])
+    tensor_dict[fields.InputDataFields.valid_context_size] = tensor_shape[0]
+    padding_shapes[fields.InputDataFields.valid_context_size] = []
+  if fields.InputDataFields.context_feature_length in tensor_dict:
+    padding_shapes[fields.InputDataFields.context_feature_length] = []
+  if fields.InputDataFields.context_features_image_id_list in tensor_dict:
+    padding_shapes[fields.InputDataFields.context_features_image_id_list] = [
+        max_num_context_features]
   if input_fields.is_annotated in tensor_dict:
     padding_shapes[input_fields.is_annotated] = []
@@ -709,6 +711,9 @@ def _get_features_dict(input_dict, include_source_id=False):
   if fields.InputDataFields.valid_context_size in input_dict:
     features[fields.InputDataFields.valid_context_size] = input_dict[
         fields.InputDataFields.valid_context_size]
+  if fields.InputDataFields.context_features_image_id_list in input_dict:
+    features[fields.InputDataFields.context_features_image_id_list] = (
+        input_dict[fields.InputDataFields.context_features_image_id_list])
   return features
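To make the padding shapes above concrete, a small illustrative sketch (plain numpy, made-up values): with two real context rows and max_num_context_features of 4, both the feature matrix and the image id list are padded out to four rows, and valid_context_size records how many rows are real.

import numpy as np

max_num_context_features = 4

context_features = np.array([[0.0, 0.1, 0.2],
                             [0.3, 0.4, 0.5]], dtype=np.float32)  # shape [2, 3]
image_id_list = np.array([b'im_1', b'im_2'])                      # shape [2]
valid_context_size = context_features.shape[0]                    # 2

pad = max_num_context_features - valid_context_size
# Zero-pad the features to [max_num_context_features, context_feature_length]
# and pad the id list (here with empty strings) to [max_num_context_features],
# mirroring padding_shapes[context_features_image_id_list] above.
padded_features = np.pad(context_features, [(0, pad), (0, 0)])       # [4, 3]
padded_ids = np.concatenate([image_id_list, np.array([b''] * pad)])  # [4]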
@@ -313,6 +313,47 @@ class InputFnTest(test_case.TestCase, parameterized.TestCase):
         tf.float32,
         labels[fields.InputDataFields.groundtruth_weights].dtype)

+  def test_context_rcnn_resnet50_eval_input_with_sequence_example_image_id_list(
+      self, eval_batch_size=8):
+    """Tests the eval input function for FasterRcnnResnet50."""
+    configs = _get_configs_for_model_sequence_example(
+        'context_rcnn_camera_trap')
+    model_config = configs['model']
+    eval_config = configs['eval_config']
+    eval_config.batch_size = eval_batch_size
+    eval_input_config = configs['eval_input_configs'][0]
+    eval_input_config.load_context_image_ids = True
+    eval_input_fn = inputs.create_eval_input_fn(
+        eval_config, eval_input_config, model_config)
+    features, labels = _make_initializable_iterator(eval_input_fn()).get_next()
+
+    self.assertAllEqual([eval_batch_size, 640, 640, 3],
+                        features[fields.InputDataFields.image].shape.as_list())
+    self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
+    self.assertAllEqual(
+        [eval_batch_size, 640, 640, 3],
+        features[fields.InputDataFields.original_image].shape.as_list())
+    self.assertEqual(tf.uint8,
+                     features[fields.InputDataFields.original_image].dtype)
+    self.assertAllEqual([eval_batch_size],
+                        features[inputs.HASH_KEY].shape.as_list())
+    self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
+    self.assertAllEqual(
+        [eval_batch_size, 100, 4],
+        labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
+    self.assertEqual(tf.float32,
+                     labels[fields.InputDataFields.groundtruth_boxes].dtype)
+    self.assertAllEqual(
+        [eval_batch_size, 100, model_config.faster_rcnn.num_classes],
+        labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
+    self.assertEqual(tf.float32,
+                     labels[fields.InputDataFields.groundtruth_classes].dtype)
+    self.assertAllEqual(
+        [eval_batch_size, 100],
+        labels[fields.InputDataFields.groundtruth_weights].shape.as_list())
+    self.assertEqual(
+        tf.float32,
+        labels[fields.InputDataFields.groundtruth_weights].dtype)
+
   def test_context_rcnn_resnet50_train_input_with_sequence_example_frame_index(
       self, train_batch_size=8):
     """Tests the training input function for FasterRcnnResnet50."""
@@ -30,7 +30,7 @@ enum InputType {
   TF_SEQUENCE_EXAMPLE = 2;  // TfSequenceExample Input
 }

-// Next id: 36
+// Next id: 37
 message InputReader {
   // Name of input reader. Typically used to describe the dataset that is read
   // by this input reader.
@@ -75,7 +75,7 @@ message InputReader {
   // to a small constant and increment linearly until the improvements become
   // marginal or you exceed your cpu memory budget. Setting this to -1,
   // automatically tunes this value for you.
-  optional int32 num_prefetch_batches = 20 [default = -1];
+  optional int32 num_prefetch_batches = 20 [default = 2];

   // Maximum number of records to keep in reader queue.
   optional uint32 queue_capacity = 3 [default = 2000, deprecated = true];
@@ -118,6 +118,9 @@ message InputReader {
   // Whether to load context features from the dataset.
   optional bool load_context_features = 25 [default = false];

+  // Whether to load context image ids from the dataset.
+  optional bool load_context_image_ids = 36 [default = false];
+
   // Whether to load groundtruth instance masks.
   optional bool load_instance_masks = 7 [default = false];
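As a minimal sketch (not from this commit) of how these options might be enabled from Python: the load_* and num_prefetch_batches fields come from the proto above, while label_map_path, tf_record_input_reader/input_path, and the example paths are the usual InputReader fields, assumed here for illustration.

from google.protobuf import text_format
from object_detection.protos import input_reader_pb2

# Hypothetical eval input reader for a Context R-CNN style model that reads
# tf.train.SequenceExamples and pipes context image ids through so attention
# can be visualized against the images they came from.
input_reader_config = text_format.Parse(
    """
    input_type: TF_SEQUENCE_EXAMPLE
    load_context_features: true
    load_context_image_ids: true
    # num_prefetch_batches now defaults to 2; -1 (autotune) can exhaust memory
    # for sequence-example models.
    num_prefetch_batches: 2
    label_map_path: "/path/to/label_map.pbtxt"
    tf_record_input_reader {
      input_path: "/path/to/sequence_examples.tfrecord"
    }
    """,
    input_reader_pb2.InputReader())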