"...text-generation-inference.git" did not exist on "9192de57cc6802508db41c489a9e0ee9df569de5"
Commit f6b4cbcd authored by Sara Beery, committed by TF Object Detection Team

Pipe image id lists for context features through the input pipeline, for use in attention visualization.

DEFAULT_VALUE_OK=Set the num_prefetch_batches default to 2, since the previous default of -1 (autotune) can cause memory issues for models such as Context R-CNN that use sequence examples.

PiperOrigin-RevId: 345682268
parent 78f0e355
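For reference, the new flag is consumed through the InputReader proto (see the input_reader.proto hunk at the end of this diff). Below is a minimal, illustrative sketch of building such a config in Python; it assumes the standard object_detection.protos package layout, and the input path is a placeholder.

from google.protobuf import text_format

from object_detection.protos import input_reader_pb2

# Illustrative InputReader fragment exercising the fields touched by this
# change: load_context_image_ids (new), plus the existing
# load_context_features and num_prefetch_batches fields.
input_reader_config = text_format.Parse(
    """
    input_type: TF_SEQUENCE_EXAMPLE
    load_context_features: true
    load_context_image_ids: true
    num_prefetch_batches: 2
    tf_record_input_reader { input_path: "/path/to/sequence_examples.tfrecord" }
    """, input_reader_pb2.InputReader())
print(input_reader_config.load_context_image_ids)  # True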
@@ -65,7 +65,8 @@ def build(input_reader_config):
  elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
    decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
        label_map_proto_file=label_map_proto_file,
        load_context_features=input_reader_config.load_context_features,
        load_context_image_ids=input_reader_config.load_context_image_ids)
    return decoder
  raise ValueError('Unsupported input_type in config.')
@@ -85,7 +85,8 @@ def build(input_reader_config):
    elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
      decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
          label_map_proto_file=label_map_proto_file,
          load_context_features=input_reader_config.load_context_features,
          load_context_image_ids=input_reader_config.load_context_image_ids)
      return decoder.decode(string_tensor)
    raise ValueError('Unsupported input_type.')
  raise ValueError('Unsupported input_reader_config.')
@@ -89,6 +89,8 @@ class InputDataFields(object):
      context_features, used for reshaping.
    valid_context_size: the valid context size, used in filtering the padded
      context features.
    context_features_image_id_list: the list of image source ids corresponding
      to the features in context_features
    image_format: format for the images, used to decode
    image_height: height of images, used to decode
    image_width: width of images, used to decode
@@ -136,6 +138,7 @@ class InputDataFields(object):
  context_features = 'context_features'
  context_feature_length = 'context_feature_length'
  valid_context_size = 'valid_context_size'
  context_features_image_id_list = 'context_features_image_id_list'
  image_timestamps = 'image_timestamps'
  image_format = 'image_format'
  image_height = 'image_height'
@@ -117,11 +117,13 @@ class TfSequenceExampleDecoder(data_decoder.DataDecoder):
  Context R-CNN (see https://arxiv.org/abs/1912.03538):
    'image/context_features'
    'image/context_feature_length'
    'image/context_features_image_id_list'
  """

  def __init__(self,
               label_map_proto_file,
               load_context_features=False,
               load_context_image_ids=False,
               use_display_name=False,
               fully_annotated=False):
    """Constructs `TfSequenceExampleDecoder` object.
@@ -134,6 +136,8 @@ class TfSequenceExampleDecoder(data_decoder.DataDecoder):
      load_context_features: Whether to load information from context_features,
        to provide additional context to a detection model for training and/or
        inference
      load_context_image_ids: Whether to load the corresponding image ids for
        the context_features in order to visualize attention.
      use_display_name: whether or not to use the `display_name` for label
        mapping (instead of `name`). Only used if label_map_proto_file is
        provided.
@@ -207,6 +211,16 @@ class TfSequenceExampleDecoder(data_decoder.DataDecoder):
          tf.FixedLenFeature((), tf.int64))
      self._items_to_handlers[fields.InputDataFields.context_feature_length] = (
          slim_example_decoder.Tensor('image/context_feature_length'))
    if load_context_image_ids:
      self._context_keys_to_features['image/context_features_image_id_list'] = (
          tf.VarLenFeature(dtype=tf.string))
      self._items_to_handlers[
          fields.InputDataFields.context_features_image_id_list] = (
              slim_example_decoder.Tensor(
                  'image/context_features_image_id_list',
                  default_value=''))
    self._fully_annotated = fully_annotated

  def decode(self, tf_seq_example_string_tensor):
@@ -239,6 +253,8 @@ class TfSequenceExampleDecoder(data_decoder.DataDecoder):
        the length of each feature in context_features
      fields.InputDataFields.image: a [num_frames] string tensor with
        the encoded images.
      fields.InputDataFields.context_features_image_id_list: a 1D vector
        of shape [num_context_features] containing string tensors.
    """
    serialized_example = tf.reshape(tf_seq_example_string_tensor, shape=[])
    decoder = slim_example_decoder.TFSequenceExampleDecoder(
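The decoded id list exists so that downstream tooling can tell which context image each context feature came from when visualizing Context R-CNN attention. A hypothetical sketch of that pairing (not part of this change; attention_scores is an assumed model output, and the ids stand in for the decoded context_features_image_id_list values):

import numpy as np

def top_attended_context_images(attention_scores, context_image_ids, k=3):
  """Returns (image_id, score) pairs for the k highest attention scores."""
  order = np.argsort(attention_scores)[::-1][:k]
  return [(context_image_ids[i], float(attention_scores[i])) for i in order]

# Toy usage with made-up values.
print(top_attended_context_images(
    np.array([0.1, 0.7, 0.2]), [b'im_1', b'im_2', b'im_3'], k=2))
# [(b'im_2', 0.7), (b'im_3', 0.2)]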
@@ -120,6 +120,145 @@ class TfSequenceExampleDecoderTest(test_case.TestCase):
self.assertAllEqual(expected_groundtruth_classes,
tensor_dict_out[flds.groundtruth_classes])
def test_decode_sequence_example_context(self):
num_frames = 4
image_height = 20
image_width = 30
expected_groundtruth_boxes = [
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
[[0.2, 0.2, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]],
[[0.0, 0.0, 1.0, 1.0], [0.1, 0.1, 0.2, 0.2]],
[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
]
expected_groundtruth_classes = [
[-1, -1],
[-1, 1],
[1, 2],
[-1, -1]
]
expected_context_features = np.array(
[[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]], dtype=np.float32)
flds = fields.InputDataFields
encoded_images = self._make_random_serialized_jpeg_images(
num_frames, image_height, image_width)
def graph_fn():
label_map_proto_file = os.path.join(self.get_temp_dir(), 'labelmap.pbtxt')
self._create_label_map(label_map_proto_file)
decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
label_map_proto_file=label_map_proto_file,
load_context_features=True)
sequence_example_serialized = seq_example_util.make_sequence_example(
dataset_name='video_dataset',
video_id='video',
encoded_images=encoded_images,
image_height=image_height,
image_width=image_width,
image_format='JPEG',
image_source_ids=[str(i) for i in range(num_frames)],
is_annotated=[[1], [1], [1], [1]],
bboxes=[
[[0., 0., 1., 1.]], # Frame 0.
[[0.2, 0.2, 1., 1.],
[0., 0., 1., 1.]], # Frame 1.
[[0., 0., 1., 1.], # Frame 2.
[0.1, 0.1, 0.2, 0.2]],
[[]], # Frame 3.
],
label_strings=[
['fox'], # Frame 0. Fox will be filtered out.
['fox', 'dog'], # Frame 1. Fox will be filtered out.
['dog', 'cat'], # Frame 2.
[], # Frame 3
],
context_features=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
context_feature_length=[3],
context_features_image_id_list=[b'im_1', b'im_2']
).SerializeToString()
example_string_tensor = tf.convert_to_tensor(sequence_example_serialized)
return decoder.decode(example_string_tensor)
tensor_dict_out = self.execute(graph_fn, [])
self.assertAllClose(expected_groundtruth_boxes,
tensor_dict_out[flds.groundtruth_boxes])
self.assertAllEqual(expected_groundtruth_classes,
tensor_dict_out[flds.groundtruth_classes])
self.assertAllClose(expected_context_features,
tensor_dict_out[flds.context_features])
def test_decode_sequence_example_context_image_id_list(self):
num_frames = 4
image_height = 20
image_width = 30
expected_groundtruth_boxes = [
[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
[[0.2, 0.2, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]],
[[0.0, 0.0, 1.0, 1.0], [0.1, 0.1, 0.2, 0.2]],
[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
]
expected_groundtruth_classes = [
[-1, -1],
[-1, 1],
[1, 2],
[-1, -1]
]
expected_context_image_ids = [b'im_1', b'im_2']
flds = fields.InputDataFields
encoded_images = self._make_random_serialized_jpeg_images(
num_frames, image_height, image_width)
def graph_fn():
label_map_proto_file = os.path.join(self.get_temp_dir(), 'labelmap.pbtxt')
self._create_label_map(label_map_proto_file)
decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
label_map_proto_file=label_map_proto_file,
load_context_image_ids=True)
sequence_example_serialized = seq_example_util.make_sequence_example(
dataset_name='video_dataset',
video_id='video',
encoded_images=encoded_images,
image_height=image_height,
image_width=image_width,
image_format='JPEG',
image_source_ids=[str(i) for i in range(num_frames)],
is_annotated=[[1], [1], [1], [1]],
bboxes=[
[[0., 0., 1., 1.]], # Frame 0.
[[0.2, 0.2, 1., 1.],
[0., 0., 1., 1.]], # Frame 1.
[[0., 0., 1., 1.], # Frame 2.
[0.1, 0.1, 0.2, 0.2]],
[[]], # Frame 3.
],
label_strings=[
['fox'], # Frame 0. Fox will be filtered out.
['fox', 'dog'], # Frame 1. Fox will be filtered out.
['dog', 'cat'], # Frame 2.
[], # Frame 3
],
context_features=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
context_feature_length=[3],
context_features_image_id_list=[b'im_1', b'im_2']
).SerializeToString()
example_string_tensor = tf.convert_to_tensor(sequence_example_serialized)
return decoder.decode(example_string_tensor)
tensor_dict_out = self.execute(graph_fn, [])
self.assertAllClose(expected_groundtruth_boxes,
tensor_dict_out[flds.groundtruth_boxes])
self.assertAllEqual(expected_groundtruth_classes,
tensor_dict_out[flds.groundtruth_classes])
self.assertAllEqual(expected_context_image_ids,
tensor_dict_out[flds.context_features_image_id_list])
def test_decode_sequence_example_negative_clip(self):
num_frames = 4
image_height = 20
@@ -171,7 +171,10 @@ def make_sequence_example(dataset_name,
                          detection_bboxes=None,
                          detection_classes=None,
                          detection_scores=None,
                          use_strs_for_source_id=False,
                          context_features=None,
                          context_feature_length=None,
                          context_features_image_id_list=None):
  """Constructs tf.SequenceExamples.

  Args:
@@ -203,6 +206,12 @@ def make_sequence_example(dataset_name,
      each frame.
    use_strs_for_source_id: (Optional) Whether to write the source IDs as
      strings rather than byte lists of characters.
    context_features: (Optional) A list or numpy array of features to use in
      Context R-CNN, of length num_context_features * context_feature_length.
    context_feature_length: (Optional) The length of each context feature, used
      for reshaping.
    context_features_image_id_list: (Optional) A list of image ids of length
      num_context_features corresponding to the context features.

  Returns:
    A tf.train.SequenceExample.
@@ -273,6 +282,16 @@ def make_sequence_example(dataset_name,
    feature_list['predicted/region/label/confidence'] = sequence_float_feature(
        detection_scores)

  if context_features is not None:
    context_dict['image/context_features'] = context_float_feature(
        context_features)
  if context_feature_length is not None:
    context_dict['image/context_feature_length'] = context_int64_feature(
        context_feature_length)
  if context_features_image_id_list is not None:
    context_dict['image/context_features_image_id_list'] = (
        context_bytes_feature(context_features_image_id_list))

  context = tf.train.Features(feature=context_dict)
  feature_lists = tf.train.FeatureLists(feature_list=feature_list)
@@ -204,6 +204,123 @@ class SeqExampleUtilTest(tf.test.TestCase):
[],
seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:])
def test_make_labeled_example_with_context_features(self):
num_frames = 2
image_height = 100
image_width = 200
dataset_name = b'unlabeled_dataset'
video_id = b'video_000'
labels = [b'dog', b'cat']
images = tf.cast(tf.random.uniform(
[num_frames, image_height, image_width, 3],
maxval=256,
dtype=tf.int32), dtype=tf.uint8)
images_list = tf.unstack(images, axis=0)
encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
encoded_images = self.materialize_tensors(encoded_images_list)
timestamps = [100000, 110000]
is_annotated = [1, 0]
bboxes = [
np.array([[0., 0., 0., 0.],
[0., 0., 1., 1.]], dtype=np.float32),
np.zeros([0, 4], dtype=np.float32)
]
label_strings = [
np.array(labels),
np.array([])
]
context_features = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
context_feature_length = [3]
context_features_image_id_list = [b'im_1', b'im_2']
seq_example = seq_example_util.make_sequence_example(
dataset_name=dataset_name,
video_id=video_id,
encoded_images=encoded_images,
image_height=image_height,
image_width=image_width,
timestamps=timestamps,
is_annotated=is_annotated,
bboxes=bboxes,
label_strings=label_strings,
context_features=context_features,
context_feature_length=context_feature_length,
context_features_image_id_list=context_features_image_id_list)
context_feature_dict = seq_example.context.feature
self.assertEqual(
dataset_name,
context_feature_dict['example/dataset_name'].bytes_list.value[0])
self.assertEqual(
timestamps[0],
context_feature_dict['clip/start/timestamp'].int64_list.value[0])
self.assertEqual(
timestamps[-1],
context_feature_dict['clip/end/timestamp'].int64_list.value[0])
self.assertEqual(
num_frames,
context_feature_dict['clip/frames'].int64_list.value[0])
self.assertAllClose(
context_features,
context_feature_dict['image/context_features'].float_list.value[:])
self.assertEqual(
context_feature_length[0],
context_feature_dict[
'image/context_feature_length'].int64_list.value[0])
self.assertEqual(
context_features_image_id_list,
context_feature_dict[
'image/context_features_image_id_list'].bytes_list.value[:])
seq_feature_dict = seq_example.feature_lists.feature_list
self.assertLen(
seq_feature_dict['image/encoded'].feature[:],
num_frames)
actual_timestamps = [
feature.int64_list.value[0] for feature
in seq_feature_dict['image/timestamp'].feature]
self.assertAllEqual(timestamps, actual_timestamps)
# Frame 0.
self.assertAllEqual(
is_annotated[0],
seq_feature_dict['region/is_annotated'].feature[0].int64_list.value[0])
self.assertAllClose(
[0., 0.],
seq_feature_dict['region/bbox/ymin'].feature[0].float_list.value[:])
self.assertAllClose(
[0., 0.],
seq_feature_dict['region/bbox/xmin'].feature[0].float_list.value[:])
self.assertAllClose(
[0., 1.],
seq_feature_dict['region/bbox/ymax'].feature[0].float_list.value[:])
self.assertAllClose(
[0., 1.],
seq_feature_dict['region/bbox/xmax'].feature[0].float_list.value[:])
self.assertAllEqual(
labels,
seq_feature_dict['region/label/string'].feature[0].bytes_list.value[:])
# Frame 1.
self.assertAllEqual(
is_annotated[1],
seq_feature_dict['region/is_annotated'].feature[1].int64_list.value[0])
self.assertAllClose(
[],
seq_feature_dict['region/bbox/ymin'].feature[1].float_list.value[:])
self.assertAllClose(
[],
seq_feature_dict['region/bbox/xmin'].feature[1].float_list.value[:])
self.assertAllClose(
[],
seq_feature_dict['region/bbox/ymax'].feature[1].float_list.value[:])
self.assertAllClose(
[],
seq_feature_dict['region/bbox/xmax'].feature[1].float_list.value[:])
self.assertAllEqual(
[],
seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:])
def test_make_labeled_example_with_predictions(self):
num_frames = 2
image_height = 100
@@ -419,7 +419,6 @@ def pad_input_data_to_static_shapes(tensor_dict,
      max_num_context_features is not specified and context_features is in the
      tensor dict.
  """
  if not spatial_image_shape or spatial_image_shape == [-1, -1]:
    height, width = None, None
  else:
@@ -539,11 +538,14 @@ def pad_input_data_to_static_shapes(tensor_dict,
    padding_shapes[input_fields.context_features] = padding_shape
    tensor_shape = tf.shape(
        tensor_dict[fields.InputDataFields.context_features])
    tensor_dict[fields.InputDataFields.valid_context_size] = tensor_shape[0]
    padding_shapes[fields.InputDataFields.valid_context_size] = []
  if fields.InputDataFields.context_feature_length in tensor_dict:
    padding_shapes[fields.InputDataFields.context_feature_length] = []
  if fields.InputDataFields.context_features_image_id_list in tensor_dict:
    padding_shapes[fields.InputDataFields.context_features_image_id_list] = [
        max_num_context_features]
  if input_fields.is_annotated in tensor_dict:
    padding_shapes[input_fields.is_annotated] = []
@@ -709,6 +711,9 @@ def _get_features_dict(input_dict, include_source_id=False):
  if fields.InputDataFields.valid_context_size in input_dict:
    features[fields.InputDataFields.valid_context_size] = input_dict[
        fields.InputDataFields.valid_context_size]
  if fields.InputDataFields.context_features_image_id_list in input_dict:
    features[fields.InputDataFields.context_features_image_id_list] = (
        input_dict[fields.InputDataFields.context_features_image_id_list])
  return features
@@ -313,6 +313,47 @@ class InputFnTest(test_case.TestCase, parameterized.TestCase):
tf.float32,
labels[fields.InputDataFields.groundtruth_weights].dtype)
def test_context_rcnn_resnet50_eval_input_with_sequence_example_image_id_list(
self, eval_batch_size=8):
"""Tests the eval input function for FasterRcnnResnet50."""
configs = _get_configs_for_model_sequence_example(
'context_rcnn_camera_trap')
model_config = configs['model']
eval_config = configs['eval_config']
eval_config.batch_size = eval_batch_size
eval_input_config = configs['eval_input_configs'][0]
eval_input_config.load_context_image_ids = True
eval_input_fn = inputs.create_eval_input_fn(
eval_config, eval_input_config, model_config)
features, labels = _make_initializable_iterator(eval_input_fn()).get_next()
self.assertAllEqual([eval_batch_size, 640, 640, 3],
features[fields.InputDataFields.image].shape.as_list())
self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
self.assertAllEqual(
[eval_batch_size, 640, 640, 3],
features[fields.InputDataFields.original_image].shape.as_list())
self.assertEqual(tf.uint8,
features[fields.InputDataFields.original_image].dtype)
self.assertAllEqual([eval_batch_size],
features[inputs.HASH_KEY].shape.as_list())
self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
self.assertAllEqual(
[eval_batch_size, 100, 4],
labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_boxes].dtype)
self.assertAllEqual(
[eval_batch_size, 100, model_config.faster_rcnn.num_classes],
labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
self.assertEqual(tf.float32,
labels[fields.InputDataFields.groundtruth_classes].dtype)
self.assertAllEqual(
[eval_batch_size, 100],
labels[fields.InputDataFields.groundtruth_weights].shape.as_list())
self.assertEqual(
tf.float32,
labels[fields.InputDataFields.groundtruth_weights].dtype)
def test_context_rcnn_resnet50_train_input_with_sequence_example_frame_index(
self, train_batch_size=8):
"""Tests the training input function for FasterRcnnResnet50."""
@@ -30,7 +30,7 @@ enum InputType {
  TF_SEQUENCE_EXAMPLE = 2;  // TfSequenceExample Input
}

// Next id: 37
message InputReader {
  // Name of input reader. Typically used to describe the dataset that is read
  // by this input reader.
@@ -75,7 +75,7 @@ message InputReader {
  // to a small constant and increment linearly until the improvements become
  // marginal or you exceed your cpu memory budget. Setting this to -1,
  // automatically tunes this value for you.
  optional int32 num_prefetch_batches = 20 [default = 2];

  // Maximum number of records to keep in reader queue.
  optional uint32 queue_capacity = 3 [default = 2000, deprecated = true];
@@ -118,6 +118,9 @@ message InputReader {
  // Whether to load context features from the dataset.
  optional bool load_context_features = 25 [default = false];

  // Whether to load context image ids from the dataset.
  optional bool load_context_image_ids = 36 [default = false];

  // Whether to load groundtruth instance masks.
  optional bool load_instance_masks = 7 [default = false];