Unverified commit fd7b6887 authored by Jonathan Huang, committed by GitHub

Merge pull request #3293 from pkulzc/master

Internal changes of object_detection 
parents f98ec55e 1efe98bb
object_detection/core/matcher_test.py
@@ -20,7 +20,7 @@ import tensorflow as tf
from object_detection.core import matcher


-class AnchorMatcherTest(tf.test.TestCase):
+class MatchTest(tf.test.TestCase):

  def test_get_correct_matched_columnIndices(self):
    match_results = tf.constant([3, 1, -1, 0, -1, 5, -2])

@@ -145,6 +145,48 @@ class AnchorMatcherTest(tf.test.TestCase):
      self.assertAllEqual(all_indices_sorted,
                          np.arange(num_matches, dtype=np.int32))

  def test_scalar_gather_based_on_match(self):
    match_results = tf.constant([3, 1, -1, 0, -1, 5, -2])
    input_tensor = tf.constant([0, 1, 2, 3, 4, 5, 6, 7], dtype=tf.float32)
    expected_gathered_tensor = [3, 1, 100, 0, 100, 5, 200]
    match = matcher.Match(match_results)
    gathered_tensor = match.gather_based_on_match(input_tensor,
                                                  unmatched_value=100.,
                                                  ignored_value=200.)
    self.assertEquals(gathered_tensor.dtype, tf.float32)
    with self.test_session():
      gathered_tensor_out = gathered_tensor.eval()
    self.assertAllEqual(expected_gathered_tensor, gathered_tensor_out)

  def test_multidimensional_gather_based_on_match(self):
    match_results = tf.constant([1, -1, -2])
    input_tensor = tf.constant([[0, 0.5, 0, 0.5], [0, 0, 0.5, 0.5]],
                               dtype=tf.float32)
    expected_gathered_tensor = [[0, 0, 0.5, 0.5], [0, 0, 0, 0], [0, 0, 0, 0]]
    match = matcher.Match(match_results)
    gathered_tensor = match.gather_based_on_match(input_tensor,
                                                  unmatched_value=tf.zeros(4),
                                                  ignored_value=tf.zeros(4))
    self.assertEquals(gathered_tensor.dtype, tf.float32)
    with self.test_session():
      gathered_tensor_out = gathered_tensor.eval()
    self.assertAllEqual(expected_gathered_tensor, gathered_tensor_out)

  def test_multidimensional_gather_based_on_match_with_matmul_gather_op(self):
    match_results = tf.constant([1, -1, -2])
    input_tensor = tf.constant([[0, 0.5, 0, 0.5], [0, 0, 0.5, 0.5]],
                               dtype=tf.float32)
    expected_gathered_tensor = [[0, 0, 0.5, 0.5], [0, 0, 0, 0], [0, 0, 0, 0]]
    match = matcher.Match(match_results, use_matmul_gather=True)
    gathered_tensor = match.gather_based_on_match(input_tensor,
                                                  unmatched_value=tf.zeros(4),
                                                  ignored_value=tf.zeros(4))
    self.assertEquals(gathered_tensor.dtype, tf.float32)
    with self.test_session() as sess:
      self.assertTrue(
          all([op.name is not 'Gather' for op in sess.graph.get_operations()]))
      gathered_tensor_out = gathered_tensor.eval()
    self.assertAllEqual(expected_gathered_tensor, gathered_tensor_out)

if __name__ == '__main__':
  tf.test.main()
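
The `use_matmul_gather=True` path exercised in the last test above avoids emitting a `Gather` op. As a minimal sketch of that idea (an illustration only, not the code added by this PR), a gather over the rows of a 2-D tensor can be expressed as a one-hot matrix multiplication:

```python
import tensorflow as tf

def matmul_gather_sketch(params, indices):
  """Gathers rows of a 2-D `params` tensor without a Gather op.

  Builds a one-hot selection matrix from `indices` and multiplies it with
  `params`; for non-negative indices the result matches
  tf.gather(params, indices).
  """
  one_hot = tf.one_hot(indices, depth=tf.shape(params)[0], dtype=params.dtype)
  return tf.matmul(one_hot, params)
```

This is why the test asserts that no operation named `'Gather'` appears in the graph when `use_matmul_gather=True` is passed to `matcher.Match`.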

object_detection/core/model.py
@@ -39,6 +39,17 @@ resize/reshaping necessary (see docstring for the preprocess function).
Output classes are always integers in the range [0, num_classes). Any mapping
of these integers to semantic labels is to be handled outside of this class.

Images are resized in the `preprocess` method. All of `preprocess`, `predict`,
and `postprocess` should be reentrant.
The `preprocess` method runs `image_resizer_fn` that returns resized_images and
`true_image_shapes`. Since `image_resizer_fn` can pad the images with zeros,
true_image_shapes indicate the slices that contain the image without padding.
This is useful for padding images to be a fixed size for batching.
The `postprocess` method uses the true image shapes to clip predictions that lie
outside of images.

By default, DetectionModels produce bounding box detections; However, we support
a handful of auxiliary annotations associated with each bounding box, namely,
instance masks and keypoints.
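
To make the resize-and-pad contract described above concrete, here is a rough sketch (an illustration only, not part of this change; the fixed canvas size and helper name are invented) of how a resize step could produce both the padded image and its true shape:

```python
import tensorflow as tf

def resize_and_pad_sketch(image, canvas_size=300):
  """Resizes `image` to fit a square canvas and records its true shape.

  true_image_shape marks the [height, width, channels] slice of the padded
  canvas that holds real pixels, so later stages can ignore the zero padding.
  """
  shape = tf.shape(image)
  longer_side = tf.cast(tf.maximum(shape[0], shape[1]), tf.float32)
  scale = tf.cast(canvas_size, tf.float32) / longer_side
  new_height = tf.cast(tf.cast(shape[0], tf.float32) * scale, tf.int32)
  new_width = tf.cast(tf.cast(shape[1], tf.float32) * scale, tf.int32)
  resized = tf.image.resize_images(image, tf.stack([new_height, new_width]))
  padded = tf.image.pad_to_bounding_box(resized, 0, 0, canvas_size, canvas_size)
  true_image_shape = tf.stack([new_height, new_width, shape[2]])
  return padded, true_image_shape
```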

@@ -106,12 +117,12 @@ class DetectionModel(object):
    This function is responsible for any scaling/shifting of input values that
    is necessary prior to running the detector on an input image.

-   It is also responsible for any resizing that might be necessary as images
-   are assumed to arrive in arbitrary sizes. While this function could
-   conceivably be part of the predict method (below), it is often convenient
-   to keep these separate --- for example, we may want to preprocess on one
-   device, place onto a queue, and let another device (e.g., the GPU) handle
-   prediction.
+   It is also responsible for any resizing, padding that might be necessary
+   as images are assumed to arrive in arbitrary sizes. While this function
+   could conceivably be part of the predict method (below), it is often
+   convenient to keep these separate --- for example, we may want to preprocess
+   on one device, place onto a queue, and let another device (e.g., the GPU)
+   handle prediction.

    A few important notes about the preprocess function:
    + We assume that this operation does not have any trainable variables nor

@@ -134,11 +145,15 @@ class DetectionModel(object):
    Returns:
      preprocessed_inputs: a [batch, height_out, width_out, channels] float32
        tensor representing a batch of images.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.
    """
    pass

  @abstractmethod
-  def predict(self, preprocessed_inputs):
+  def predict(self, preprocessed_inputs, true_image_shapes):
    """Predict prediction tensors from inputs tensor.

    Outputs of this function can be passed to loss or postprocess functions.

@@ -146,6 +161,10 @@ class DetectionModel(object):
    Args:
      preprocessed_inputs: a [batch, height, width, channels] float32 tensor
        representing a batch of images.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.

    Returns:
      prediction_dict: a dictionary holding prediction tensors to be

@@ -154,7 +173,7 @@ class DetectionModel(object):
    pass

  @abstractmethod
-  def postprocess(self, prediction_dict, **params):
+  def postprocess(self, prediction_dict, true_image_shapes, **params):
    """Convert predicted output tensors to final detections.

    Outputs adhere to the following conventions:

@@ -172,6 +191,10 @@ class DetectionModel(object):
    Args:
      prediction_dict: a dictionary holding prediction tensors.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.
      **params: Additional keyword arguments for specific implementations of
        DetectionModel.
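
As a hedged illustration of how `true_image_shapes` can be consumed at this stage (a sketch only, not the actual `postprocess` implementation; boxes are assumed to be `[ymin, xmin, ymax, xmax]` in absolute pixel coordinates), clipping detections to the unpadded region might look like:

```python
import tensorflow as tf

def clip_boxes_to_true_shape(boxes, true_image_shape):
  """Clips [N, 4] absolute boxes to the unpadded image region.

  Coordinates beyond the true height/width recorded in true_image_shape are
  pulled back inside the image, so padded areas contribute no detections.
  """
  height = tf.cast(true_image_shape[0], tf.float32)
  width = tf.cast(true_image_shape[1], tf.float32)
  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=-1)
  return tf.stack([
      tf.clip_by_value(ymin, 0.0, height),
      tf.clip_by_value(xmin, 0.0, width),
      tf.clip_by_value(ymax, 0.0, height),
      tf.clip_by_value(xmax, 0.0, width)
  ], axis=-1)
```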

@@ -190,7 +213,7 @@ class DetectionModel(object):
    pass

  @abstractmethod
-  def loss(self, prediction_dict):
+  def loss(self, prediction_dict, true_image_shapes):
    """Compute scalar loss tensors with respect to provided groundtruth.

    Calling this function requires that groundtruth tensors have been

@@ -198,6 +221,10 @@ class DetectionModel(object):
    Args:
      prediction_dict: a dictionary holding predicted tensors
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.

    Returns:
      a dictionary mapping strings (loss names) to scalar tensors representing

@@ -209,7 +236,8 @@ class DetectionModel(object):
                          groundtruth_boxes_list,
                          groundtruth_classes_list,
                          groundtruth_masks_list=None,
-                         groundtruth_keypoints_list=None):
+                         groundtruth_keypoints_list=None,
+                         groundtruth_weights_list=None):
    """Provide groundtruth tensors.

    Args:

@@ -230,10 +258,15 @@ class DetectionModel(object):
        shape [num_boxes, num_keypoints, 2] containing keypoints.
        Keypoints are assumed to be provided in normalized coordinates and
        missing keypoints should be encoded as NaN.
      groundtruth_weights_list: A list of 1-D tf.float32 tensors of shape
        [num_boxes] containing weights for groundtruth boxes.
    """
    self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
    self._groundtruth_lists[
        fields.BoxListFields.classes] = groundtruth_classes_list
    if groundtruth_weights_list:
      self._groundtruth_lists[fields.BoxListFields.
                              weights] = groundtruth_weights_list
    if groundtruth_masks_list:
      self._groundtruth_lists[
          fields.BoxListFields.masks] = groundtruth_masks_list
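
A hypothetical call showing the new `groundtruth_weights_list` argument next to the existing ones (the `detection_model` instance and the single-box values here are invented for the example; one weight is supplied per groundtruth box):

```python
import tensorflow as tf

# detection_model is assumed to be an instance of a concrete DetectionModel
# subclass; boxes use normalized [ymin, xmin, ymax, xmax] coordinates.
detection_model.provide_groundtruth(
    groundtruth_boxes_list=[tf.constant([[0.1, 0.1, 0.6, 0.6]])],
    groundtruth_classes_list=[tf.constant([[0., 1., 0.]])],
    groundtruth_weights_list=[tf.constant([1.0])])
```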

Bazel BUILD file:

package(
    default_visibility = ["//visibility:public"],
)

licenses(["notice"])

exports_files([
    "pet_label_map.pbtxt",
])

object_detection/dataset_tools/create_kitti_tf_record.py
@@ -120,7 +120,7 @@ def convert_kitti_to_tfrecords(data_dir, output_path, classes_to_use,
    # Filter all bounding boxes of this frame that are of a legal class, and
    # don't overlap with a dontcare region.
-   # TODO(talremez) filter out targets that are truncated or heavily occluded.
+   # TODO filter out targets that are truncated or heavily occluded.
    annotation_for_image = filter_annotations(img_anno, classes_to_use)
    example = prepare_example(image_path, annotation_for_image, label_map_dict)

object_detection/dataset_tools/create_kitti_tf_record_test.py
@@ -24,7 +24,7 @@ import tensorflow as tf
from object_detection.dataset_tools import create_kitti_tf_record


-class DictToTFExampleTest(tf.test.TestCase):
+class CreateKittiTFRecordTest(tf.test.TestCase):

  def _assertProtoEqual(self, proto_field, expectation):
    """Helper function to assert if a proto field equals some value.