Commit 4f14cb62 authored by Derek Chow

Enable inference with dynamic batch size in SSD.

* Create a new _batch_decode method in the SSD meta architecture that can
  handle a dynamic batch size.
* Use combined_static_and_dynamic_shape in the _get_feature_maps_spatial_dims
  method to handle dynamically shaped input images (a sketch of this helper
  follows the commit header below).
* Add dynamic-shape tests covering the preprocess, predict and postprocess
  methods of the SSD meta architecture.
parent 5d5fb7cc
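
The changes below lean on shape_utils.combined_static_and_dynamic_shape, whose
implementation is not part of this diff. As a rough sketch (an assumption about
the helper, not code from this commit), it returns a per-dimension list that
uses the static size wherever TensorFlow knows it and falls back to the
dynamic tf.shape() value otherwise:

    import tensorflow as tf

    def combined_static_and_dynamic_shape(tensor):
      # Use the statically known dimension where available; otherwise fall
      # back to the corresponding element of the dynamic shape tensor.
      static_shape = tensor.shape.as_list()
      dynamic_shape = tf.shape(tensor)
      return [dim if dim is not None else dynamic_shape[index]
              for index, dim in enumerate(static_shape)]

    # Example: a batch of images whose batch dimension is unknown at graph
    # construction time.
    images = tf.placeholder(tf.float32, shape=[None, 2, 2, 3])
    print(combined_static_and_dynamic_shape(images))
    # -> [<tf.Tensor ...>, 2, 2, 3]
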
@@ -13,12 +13,11 @@ py_library(
     srcs = ["ssd_meta_arch.py"],
     deps = [
         "//tensorflow",
-        "//tensorflow_models/object_detection/core:box_coder",
         "//tensorflow_models/object_detection/core:box_list",
         "//tensorflow_models/object_detection/core:box_predictor",
         "//tensorflow_models/object_detection/core:model",
         "//tensorflow_models/object_detection/core:target_assigner",
-        "//tensorflow_models/object_detection/utils:variables_helper",
+        "//tensorflow_models/object_detection/utils:shape_utils",
     ],
 )
...
@@ -23,12 +23,12 @@ from abc import abstractmethod
 import re
 import tensorflow as tf
 
-from object_detection.core import box_coder as bcoder
 from object_detection.core import box_list
 from object_detection.core import box_predictor as bpredictor
 from object_detection.core import model
 from object_detection.core import standard_fields as fields
 from object_detection.core import target_assigner
+from object_detection.utils import shape_utils
 
 slim = tf.contrib.slim
@@ -323,7 +323,8 @@ class SSDMetaArch(model.DetectionModel):
       a list of pairs (height, width) for each feature map in feature_maps
     """
     feature_map_shapes = [
-        feature_map.get_shape().as_list() for feature_map in feature_maps
+        shape_utils.combined_static_and_dynamic_shape(
+            feature_map) for feature_map in feature_maps
     ]
     return [(shape[1], shape[2]) for shape in feature_map_shapes]
@@ -364,8 +365,7 @@ class SSDMetaArch(model.DetectionModel):
     with tf.name_scope('Postprocessor'):
       box_encodings = prediction_dict['box_encodings']
       class_predictions = prediction_dict['class_predictions_with_background']
-      detection_boxes = bcoder.batch_decode(box_encodings, self._box_coder,
-                                            self.anchors)
+      detection_boxes = self._batch_decode(box_encodings)
       detection_boxes = tf.expand_dims(detection_boxes, axis=2)
       class_predictions_without_background = tf.slice(class_predictions,
@@ -549,8 +549,7 @@ class SSDMetaArch(model.DetectionModel):
         tf.slice(prediction_dict['class_predictions_with_background'],
                  [0, 0, 1], class_pred_shape), class_pred_shape)
 
-    decoded_boxes = bcoder.batch_decode(prediction_dict['box_encodings'],
-                                        self._box_coder, self.anchors)
+    decoded_boxes = self._batch_decode(prediction_dict['box_encodings'])
     decoded_box_tensors_list = tf.unstack(decoded_boxes)
     class_prediction_list = tf.unstack(class_predictions)
     decoded_boxlist_list = []
@@ -565,6 +564,31 @@ class SSDMetaArch(model.DetectionModel):
         decoded_boxlist_list=decoded_boxlist_list,
         match_list=match_list)
 
+  def _batch_decode(self, box_encodings):
+    """Decodes a batch of box encodings with respect to the anchors.
+
+    Args:
+      box_encodings: A float32 tensor of shape
+        [batch_size, num_anchors, box_code_size] containing box encodings.
+
+    Returns:
+      decoded_boxes: A float32 tensor of shape
+        [batch_size, num_anchors, 4] containing the decoded boxes.
+    """
+    combined_shape = shape_utils.combined_static_and_dynamic_shape(
+        box_encodings)
+    batch_size = combined_shape[0]
+    tiled_anchor_boxes = tf.tile(
+        tf.expand_dims(self.anchors.get(), 0), [batch_size, 1, 1])
+    tiled_anchors_boxlist = box_list.BoxList(
+        tf.reshape(tiled_anchor_boxes, [-1, self._box_coder.code_size]))
+    decoded_boxes = self._box_coder.decode(
+        tf.reshape(box_encodings, [-1, self._box_coder.code_size]),
+        tiled_anchors_boxlist)
+    return tf.reshape(decoded_boxes.get(),
+                      tf.stack([combined_shape[0], combined_shape[1],
+                                4]))
+
   def restore_map(self, from_detection_checkpoint=True):
     """Returns a map of variables to load from a foreign checkpoint.
...
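
The core of _batch_decode above is tiling the static anchor set across a batch
dimension that is only known at run time. The standalone snippet below is an
illustrative sketch of that pattern (the anchor values and shapes here are made
up, not taken from the commit):

    import numpy as np
    import tensorflow as tf

    # Two illustrative anchor boxes, shape [num_anchors, 4].
    anchors = tf.constant([[0., 0., .5, .5],
                           [.5, .5, 1., 1.]])
    # Box encodings with a dynamic (None) batch dimension.
    box_encodings = tf.placeholder(tf.float32, shape=[None, 2, 4])

    # Tile the anchors along a new leading axis whose size is read from
    # tf.shape() at run time.
    batch_size = tf.shape(box_encodings)[0]
    tiled_anchors = tf.tile(tf.expand_dims(anchors, 0), [batch_size, 1, 1])

    with tf.Session() as sess:
      tiled_shape = sess.run(tf.shape(tiled_anchors),
                             feed_dict={box_encodings: np.zeros((3, 2, 4))})
      print(tiled_shape)  # [3 2 4]
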
@@ -116,24 +116,46 @@ class SsdMetaArchTest(tf.test.TestCase):
         localization_loss_weight, normalize_loss_by_num_matches,
         hard_example_miner)
 
+  def test_preprocess_preserves_input_shapes(self):
+    image_shapes = [(3, None, None, 3),
+                    (None, 10, 10, 3),
+                    (None, None, None, 3)]
+    for image_shape in image_shapes:
+      image_placeholder = tf.placeholder(tf.float32, shape=image_shape)
+      preprocessed_inputs = self._model.preprocess(image_placeholder)
+      self.assertAllEqual(preprocessed_inputs.shape.as_list(), image_shape)
+
   def test_predict_results_have_correct_keys_and_shapes(self):
     batch_size = 3
-    preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
-                                           dtype=tf.float32)
-    prediction_dict = self._model.predict(preprocessed_input)
-
-    self.assertTrue('box_encodings' in prediction_dict)
-    self.assertTrue('class_predictions_with_background' in prediction_dict)
-    self.assertTrue('feature_maps' in prediction_dict)
-
+    image_size = 2
+    input_shapes = [(batch_size, image_size, image_size, 3),
+                    (None, image_size, image_size, 3),
+                    (batch_size, None, None, 3),
+                    (None, None, None, 3)]
     expected_box_encodings_shape_out = (
         batch_size, self._num_anchors, self._code_size)
     expected_class_predictions_with_background_shape_out = (
         batch_size, self._num_anchors, self._num_classes+1)
-    init_op = tf.global_variables_initializer()
-    with self.test_session() as sess:
-      sess.run(init_op)
-      prediction_out = sess.run(prediction_dict)
+
+    for input_shape in input_shapes:
+      tf_graph = tf.Graph()
+      with tf_graph.as_default():
+        preprocessed_input_placeholder = tf.placeholder(tf.float32,
+                                                        shape=input_shape)
+        prediction_dict = self._model.predict(preprocessed_input_placeholder)
+
+        self.assertTrue('box_encodings' in prediction_dict)
+        self.assertTrue('class_predictions_with_background' in prediction_dict)
+        self.assertTrue('feature_maps' in prediction_dict)
+
+        init_op = tf.global_variables_initializer()
+      with self.test_session(graph=tf_graph) as sess:
+        sess.run(init_op)
+        prediction_out = sess.run(prediction_dict,
+                                  feed_dict={
+                                      preprocessed_input_placeholder:
+                                      np.random.uniform(
+                                          size=(batch_size, 2, 2, 3))})
     self.assertAllEqual(prediction_out['box_encodings'].shape,
                         expected_box_encodings_shape_out)
     self.assertAllEqual(
@@ -142,10 +164,11 @@ class SsdMetaArchTest(tf.test.TestCase):
 
   def test_postprocess_results_are_correct(self):
     batch_size = 2
-    preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
-                                           dtype=tf.float32)
-    prediction_dict = self._model.predict(preprocessed_input)
-    detections = self._model.postprocess(prediction_dict)
+    image_size = 2
+    input_shapes = [(batch_size, image_size, image_size, 3),
+                    (None, image_size, image_size, 3),
+                    (batch_size, None, None, 3),
+                    (None, None, None, 3)]
 
     expected_boxes = np.array([[[0, 0, .5, .5],
                                 [0, .5, .5, 1],
@@ -163,15 +186,25 @@ class SsdMetaArchTest(tf.test.TestCase):
                                 [0, 0, 0, 0, 0]])
     expected_num_detections = np.array([4, 4])
 
-    self.assertTrue('detection_boxes' in detections)
-    self.assertTrue('detection_scores' in detections)
-    self.assertTrue('detection_classes' in detections)
-    self.assertTrue('num_detections' in detections)
-
-    init_op = tf.global_variables_initializer()
-    with self.test_session() as sess:
-      sess.run(init_op)
-      detections_out = sess.run(detections)
+    for input_shape in input_shapes:
+      tf_graph = tf.Graph()
+      with tf_graph.as_default():
+        preprocessed_input_placeholder = tf.placeholder(tf.float32,
+                                                        shape=input_shape)
+        prediction_dict = self._model.predict(preprocessed_input_placeholder)
+        detections = self._model.postprocess(prediction_dict)
+        self.assertTrue('detection_boxes' in detections)
+        self.assertTrue('detection_scores' in detections)
+        self.assertTrue('detection_classes' in detections)
+        self.assertTrue('num_detections' in detections)
+        init_op = tf.global_variables_initializer()
+      with self.test_session(graph=tf_graph) as sess:
+        sess.run(init_op)
+        detections_out = sess.run(detections,
+                                  feed_dict={
+                                      preprocessed_input_placeholder:
+                                      np.random.uniform(
+                                          size=(batch_size, 2, 2, 3))})
     self.assertAllClose(detections_out['detection_boxes'], expected_boxes)
     self.assertAllClose(detections_out['detection_scores'], expected_scores)
     self.assertAllClose(detections_out['detection_classes'], expected_classes)
...
@@ -120,6 +120,7 @@ py_library(
         "//tensorflow_models/object_detection/core:box_list",
         "//tensorflow_models/object_detection/core:box_predictor",
         "//tensorflow_models/object_detection/core:matcher",
+        "//tensorflow_models/object_detection/utils:shape_utils"
     ],
 )
...
@@ -22,6 +22,7 @@ from object_detection.core import box_coder
 from object_detection.core import box_list
 from object_detection.core import box_predictor
 from object_detection.core import matcher
+from object_detection.utils import shape_utils
 
 
 class MockBoxCoder(box_coder.BoxCoder):
@@ -45,9 +46,10 @@ class MockBoxPredictor(box_predictor.BoxPredictor):
     super(MockBoxPredictor, self).__init__(is_training, num_classes)
 
   def _predict(self, image_features, num_predictions_per_location):
-    batch_size = image_features.get_shape().as_list()[0]
-    num_anchors = (image_features.get_shape().as_list()[1]
-                   * image_features.get_shape().as_list()[2])
+    combined_feature_shape = shape_utils.combined_static_and_dynamic_shape(
+        image_features)
+    batch_size = combined_feature_shape[0]
+    num_anchors = (combined_feature_shape[1] * combined_feature_shape[2])
     code_size = 4
     zero = tf.reduce_sum(0 * image_features)
     box_encodings = zero + tf.zeros(
...
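
Taken together, these changes let a single SSD graph be built against a
placeholder whose batch dimension is None, in the spirit of the new tests
above. The sketch below assumes ssd_model is an already constructed
SSDMetaArch instance and uses an illustrative 300x300 input size; it is not
code from this commit:

    import numpy as np
    import tensorflow as tf

    # Build the inference graph once with a dynamic batch dimension.
    image_placeholder = tf.placeholder(tf.float32, shape=[None, 300, 300, 3])
    preprocessed = ssd_model.preprocess(image_placeholder)
    prediction_dict = ssd_model.predict(preprocessed)
    detections = ssd_model.postprocess(prediction_dict)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # Any batch size can now be fed at run time.
      out = sess.run(detections,
                     feed_dict={image_placeholder:
                                np.random.uniform(size=(5, 300, 300, 3))})
      print(out['num_detections'])  # one entry per image in the batch
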