Commit 582bf927 authored by derekjchow, committed by GitHub

Merge pull request #2053 from derekjchow/master

object_detection exporter updates
parents ecf5edf1 a2cb67c2
@@ -270,6 +270,7 @@ py_library(
     deps = [
         "//tensorflow",
         "//tensorflow_models/object_detection/utils:ops",
+        "//tensorflow_models/object_detection/utils:shape_utils",
         "//tensorflow_models/object_detection/utils:static_shape",
     ],
 )
...
@@ -29,6 +29,7 @@ few box predictor architectures are shared across many models.
 from abc import abstractmethod
 import tensorflow as tf
 from object_detection.utils import ops
+from object_detection.utils import shape_utils
 from object_detection.utils import static_shape

 slim = tf.contrib.slim
@@ -316,6 +317,8 @@ class MaskRCNNBoxPredictor(BoxPredictor):
     self._predict_instance_masks = predict_instance_masks
     self._mask_prediction_conv_depth = mask_prediction_conv_depth
     self._predict_keypoints = predict_keypoints
+    if self._predict_instance_masks:
+      raise ValueError('Mask prediction is unimplemented.')
     if self._predict_keypoints:
       raise ValueError('Keypoint prediction is unimplemented.')
     if ((self._predict_instance_masks or self._predict_keypoints) and
@@ -524,23 +527,21 @@ class ConvolutionalBoxPredictor(BoxPredictor):
       class_predictions_with_background = tf.sigmoid(
           class_predictions_with_background)
-    batch_size = static_shape.get_batch_size(image_features.get_shape())
-    if batch_size is None:
-      features_height = static_shape.get_height(image_features.get_shape())
-      features_width = static_shape.get_width(image_features.get_shape())
-      flattened_predictions_size = (features_height * features_width *
-                                    num_predictions_per_location)
-      box_encodings = tf.reshape(
-          box_encodings,
-          [-1, flattened_predictions_size, 1, self._box_code_size])
-      class_predictions_with_background = tf.reshape(
-          class_predictions_with_background,
-          [-1, flattened_predictions_size, num_class_slots])
-    else:
-      box_encodings = tf.reshape(
-          box_encodings, [batch_size, -1, 1, self._box_code_size])
-      class_predictions_with_background = tf.reshape(
-          class_predictions_with_background, [batch_size, -1, num_class_slots])
+    combined_feature_map_shape = shape_utils.combined_static_and_dynamic_shape(
+        image_features)
+    box_encodings = tf.reshape(
+        box_encodings, tf.stack([combined_feature_map_shape[0],
+                                 combined_feature_map_shape[1] *
+                                 combined_feature_map_shape[2] *
+                                 num_predictions_per_location,
+                                 1, self._box_code_size]))
+    class_predictions_with_background = tf.reshape(
+        class_predictions_with_background,
+        tf.stack([combined_feature_map_shape[0],
+                  combined_feature_map_shape[1] *
+                  combined_feature_map_shape[2] *
+                  num_predictions_per_location,
+                  num_class_slots]))
     return {BOX_ENCODINGS: box_encodings,
             CLASS_PREDICTIONS_WITH_BACKGROUND:
                 class_predictions_with_background}
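The new reshape path relies on shape_utils.combined_static_and_dynamic_shape, which returns one entry per dimension: a Python int where the static shape is known, and a slice of tf.shape() where it is not. A minimal sketch of such a helper (an illustrative reconstruction, not necessarily the exact implementation added by this PR):

import tensorflow as tf

def combined_static_and_dynamic_shape(tensor):
  # Use the static size where the graph knows it; fall back to the
  # dynamic size (a scalar tensor) for unknown dimensions.
  static_shape = tensor.get_shape().as_list()
  dynamic_shape = tf.shape(tensor)
  return [static_shape[i] if static_shape[i] is not None else dynamic_shape[i]
          for i in range(len(static_shape))]

Passing this mixed list through tf.stack lets tf.reshape preserve as much static shape information as possible while still supporting an unknown batch or spatial size at graph-construction time.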
@@ -228,25 +228,24 @@ class DetectionModel(object):
           fields.BoxListFields.keypoints] = groundtruth_keypoints_list

   @abstractmethod
-  def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
-    """Return callable for loading a foreign checkpoint into tensorflow graph.
+  def restore_map(self, from_detection_checkpoint=True):
+    """Returns a map of variables to load from a foreign checkpoint.

-    Loads variables from a different tensorflow graph (typically feature
-    extractor variables). This enables the model to initialize based on weights
-    from another task. For example, the feature extractor variables from a
+    Returns a map of variable names to load from a checkpoint to variables in
+    the model graph. This enables the model to initialize based on weights from
+    another task. For example, the feature extractor variables from a
     classification model can be used to bootstrap training of an object
     detector. When loading from an object detection model, the checkpoint model
     should have the same parameters as this detection model with exception of
     the num_classes parameter.

     Args:
-      checkpoint_path: path to checkpoint to restore.
       from_detection_checkpoint: whether to restore from a full detection
         checkpoint (with compatible variable names) or to restore from a
         classification checkpoint for initialization prior to training.

     Returns:
-      a callable which takes a tf.Session as input and loads a checkpoint when
-      run.
+      A dict mapping variable names (to load from a checkpoint) to variables in
+      the model graph.
     """
     pass
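The restore_map contract moves session handling out of the model: callers build a tf.train.Saver from the returned dict and run the restore themselves. A hedged sketch of the intended call pattern (detection_model and checkpoint_path are assumed to already exist in the caller):

import tensorflow as tf

# tf.train.Saver accepts a dict mapping checkpoint variable names to
# variables in the current graph.
var_map = detection_model.restore_map(from_detection_checkpoint=True)
saver = tf.train.Saver(var_map)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, checkpoint_path)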
@@ -174,7 +174,8 @@ def batch_multiclass_non_max_suppression(boxes,
                                          change_coordinate_frame=False,
                                          num_valid_boxes=None,
                                          masks=None,
-                                         scope=None):
+                                         scope=None,
+                                         parallel_iterations=32):
   """Multi-class version of non maximum suppression that operates on a batch.

   This op is similar to `multiclass_non_max_suppression` but operates on a batch
@@ -208,26 +209,28 @@ def batch_multiclass_non_max_suppression(boxes,
       float32 tensor containing box masks. `q` can be either number of classes
       or 1 depending on whether a separate mask is predicted per class.
     scope: tf scope name.
+    parallel_iterations: (optional) number of batch items to process in
+      parallel.

   Returns:
-    A dictionary containing the following entries:
-    'detection_boxes': A [batch_size, max_detections, 4] float32 tensor
+    'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
       containing the non-max suppressed boxes.
-    'detection_scores': A [bath_size, max_detections] float32 tensor containing
+    'nmsed_scores': A [batch_size, max_detections] float32 tensor containing
       the scores for the boxes.
-    'detection_classes': A [batch_size, max_detections] float32 tensor
+    'nmsed_classes': A [batch_size, max_detections] float32 tensor
       containing the class for boxes.
-    'num_detections': A [batchsize] float32 tensor indicating the number of
+    'nmsed_masks': (optional) a
+      [batch_size, max_detections, mask_height, mask_width] float32 tensor
+      containing masks for each selected box. This is set to None if input
+      `masks` is None.
+    'num_detections': A [batch_size] int32 tensor indicating the number of
       valid detections per batch item. Only the top num_detections[i] entries in
       nms_boxes[i], nms_scores[i] and nms_class[i] are valid. the rest of the
       entries are zero paddings.
-    'detection_masks': (optional) a
-      [batch_size, max_detections, mask_height, mask_width] float32 tensor
-      containing masks for each selected box.

   Raises:
-    ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
-    a valid scores field.
+    ValueError: if `q` in boxes.shape is not 1 or not equal to number of
+      classes as inferred from scores.shape.
   """
   q = boxes.shape[2].value
   num_classes = scores.shape[2].value
@@ -235,36 +238,45 @@ def batch_multiclass_non_max_suppression(boxes,
     raise ValueError('third dimension of boxes must be either 1 or equal '
                      'to the third dimension of scores')

+  original_masks = masks
   with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'):
-    per_image_boxes_list = tf.unstack(boxes)
-    per_image_scores_list = tf.unstack(scores)
-    num_valid_boxes_list = len(per_image_boxes_list) * [None]
-    per_image_masks_list = len(per_image_boxes_list) * [None]
-    if num_valid_boxes is not None:
-      num_valid_boxes_list = tf.unstack(num_valid_boxes)
-    if masks is not None:
-      per_image_masks_list = tf.unstack(masks)
-
-    detection_boxes_list = []
-    detection_scores_list = []
-    detection_classes_list = []
-    num_detections_list = []
-    detection_masks_list = []
-    for (per_image_boxes, per_image_scores, per_image_masks, num_valid_boxes
-        ) in zip(per_image_boxes_list, per_image_scores_list,
-                 per_image_masks_list, num_valid_boxes_list):
-      if num_valid_boxes is not None:
-        per_image_boxes = tf.reshape(
-            tf.slice(per_image_boxes, 3*[0],
-                     tf.stack([num_valid_boxes, -1, -1])), [-1, q, 4])
-        per_image_scores = tf.reshape(
-            tf.slice(per_image_scores, [0, 0],
-                     tf.stack([num_valid_boxes, -1])), [-1, num_classes])
-        if masks is not None:
-          per_image_masks = tf.reshape(
-              tf.slice(per_image_masks, 4*[0],
-                       tf.stack([num_valid_boxes, -1, -1, -1])),
-              [-1, q, masks.shape[3].value, masks.shape[4].value])
+    boxes_shape = boxes.shape
+    batch_size = boxes_shape[0].value
+    num_anchors = boxes_shape[1].value
+
+    if batch_size is None:
+      batch_size = tf.shape(boxes)[0]
+    if num_anchors is None:
+      num_anchors = tf.shape(boxes)[1]
+
+    # If num valid boxes aren't provided, create one and mark all boxes as
+    # valid.
+    if num_valid_boxes is None:
+      num_valid_boxes = tf.ones([batch_size], dtype=tf.int32) * num_anchors
+
+    # If masks aren't provided, create dummy masks so we can only have one copy
+    # of single_image_nms_fn and discard the dummy masks after map_fn.
+    if masks is None:
+      masks_shape = tf.stack([batch_size, num_anchors, 1, 0, 0])
+      masks = tf.zeros(masks_shape)
+
+    def single_image_nms_fn(args):
+      """Runs NMS on a single image and returns padded output."""
+      (per_image_boxes, per_image_scores, per_image_masks,
+       per_image_num_valid_boxes) = args
+      per_image_boxes = tf.reshape(
+          tf.slice(per_image_boxes, 3 * [0],
+                   tf.stack([per_image_num_valid_boxes, -1, -1])), [-1, q, 4])
+      per_image_scores = tf.reshape(
+          tf.slice(per_image_scores, [0, 0],
+                   tf.stack([per_image_num_valid_boxes, -1])),
+          [-1, num_classes])
+      per_image_masks = tf.reshape(
+          tf.slice(per_image_masks, 4 * [0],
+                   tf.stack([per_image_num_valid_boxes, -1, -1, -1])),
+          [-1, q, per_image_masks.shape[2].value,
+           per_image_masks.shape[3].value])
       nmsed_boxlist = multiclass_non_max_suppression(
           per_image_boxes,
           per_image_scores,
@@ -275,24 +287,26 @@ def batch_multiclass_non_max_suppression(boxes,
           masks=per_image_masks,
           clip_window=clip_window,
           change_coordinate_frame=change_coordinate_frame)
-      num_detections_list.append(tf.to_float(nmsed_boxlist.num_boxes()))
       padded_boxlist = box_list_ops.pad_or_clip_box_list(nmsed_boxlist,
                                                          max_total_size)
-      detection_boxes_list.append(padded_boxlist.get())
-      detection_scores_list.append(
-          padded_boxlist.get_field(fields.BoxListFields.scores))
-      detection_classes_list.append(
-          padded_boxlist.get_field(fields.BoxListFields.classes))
-      if masks is not None:
-        detection_masks_list.append(
-            padded_boxlist.get_field(fields.BoxListFields.masks))
+      num_detections = nmsed_boxlist.num_boxes()
+      nmsed_boxes = padded_boxlist.get()
+      nmsed_scores = padded_boxlist.get_field(fields.BoxListFields.scores)
+      nmsed_classes = padded_boxlist.get_field(fields.BoxListFields.classes)
+      nmsed_masks = padded_boxlist.get_field(fields.BoxListFields.masks)
+      return [nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+              num_detections]

-    nms_dict = {
-        'detection_boxes': tf.stack(detection_boxes_list),
-        'detection_scores': tf.stack(detection_scores_list),
-        'detection_classes': tf.stack(detection_classes_list),
-        'num_detections': tf.stack(num_detections_list)
-    }
-    if masks is not None:
-      nms_dict['detection_masks'] = tf.stack(detection_masks_list)
-    return nms_dict
+    (batch_nmsed_boxes, batch_nmsed_scores,
+     batch_nmsed_classes, batch_nmsed_masks,
+     batch_num_detections) = tf.map_fn(
+         single_image_nms_fn,
+         elems=[boxes, scores, masks, num_valid_boxes],
+         dtype=[tf.float32, tf.float32, tf.float32, tf.float32, tf.int32],
+         parallel_iterations=parallel_iterations)
+
+    if original_masks is None:
+      batch_nmsed_masks = None
+
+    return (batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
+            batch_nmsed_masks, batch_num_detections)
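Call sites now unpack a five-tuple instead of indexing into a dict. A minimal sketch of the new calling convention (the random input tensors are purely for illustration):

import tensorflow as tf
from object_detection.core import post_processing

# boxes: [batch_size, num_anchors, q, 4];
# scores: [batch_size, num_anchors, num_classes].
boxes = tf.random_uniform([2, 8, 1, 4])
scores = tf.random_uniform([2, 8, 3])

(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
 num_detections) = post_processing.batch_multiclass_non_max_suppression(
     boxes, scores, score_thresh=0.1, iou_thresh=0.5,
     max_size_per_class=4, max_total_size=4, parallel_iterations=32)
# nmsed_masks is None here because no `masks` argument was supplied.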
@@ -496,15 +496,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
     exp_nms_scores = [[.95, .9, .85, .3]]
     exp_nms_classes = [[0, 0, 1, 0]]
-    nms_dict = post_processing.batch_multiclass_non_max_suppression(
+    (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+     num_detections) = post_processing.batch_multiclass_non_max_suppression(
         boxes, scores, score_thresh, iou_thresh,
         max_size_per_class=max_output_size, max_total_size=max_output_size)
+    self.assertIsNone(nmsed_masks)
+
     with self.test_session() as sess:
-      nms_output = sess.run(nms_dict)
-      self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
-      self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
-      self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
-      self.assertEqual(nms_output['num_detections'], [4])
+      (nmsed_boxes, nmsed_scores, nmsed_classes,
+       num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
+                                   num_detections])
+      self.assertAllClose(nmsed_boxes, exp_nms_corners)
+      self.assertAllClose(nmsed_scores, exp_nms_scores)
+      self.assertAllClose(nmsed_classes, exp_nms_classes)
+      self.assertEqual(num_detections, [4])

   def test_batch_multiclass_nms_with_batch_size_2(self):
     boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
@@ -524,28 +530,42 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
     iou_thresh = .5
     max_output_size = 4
-    exp_nms_corners = [[[0, 10, 1, 11],
-                        [0, 0, 1, 1],
-                        [0, 0, 0, 0],
-                        [0, 0, 0, 0]],
-                       [[0, 999, 2, 1004],
-                        [0, 10.1, 1, 11.1],
-                        [0, 100, 1, 101],
-                        [0, 0, 0, 0]]]
-    exp_nms_scores = [[.95, .9, 0, 0],
-                      [.85, .5, .3, 0]]
-    exp_nms_classes = [[0, 0, 0, 0],
-                       [1, 0, 0, 0]]
-    nms_dict = post_processing.batch_multiclass_non_max_suppression(
+    exp_nms_corners = np.array([[[0, 10, 1, 11],
+                                 [0, 0, 1, 1],
+                                 [0, 0, 0, 0],
+                                 [0, 0, 0, 0]],
+                                [[0, 999, 2, 1004],
+                                 [0, 10.1, 1, 11.1],
+                                 [0, 100, 1, 101],
+                                 [0, 0, 0, 0]]])
+    exp_nms_scores = np.array([[.95, .9, 0, 0],
+                               [.85, .5, .3, 0]])
+    exp_nms_classes = np.array([[0, 0, 0, 0],
+                                [1, 0, 0, 0]])
+    (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+     num_detections) = post_processing.batch_multiclass_non_max_suppression(
         boxes, scores, score_thresh, iou_thresh,
         max_size_per_class=max_output_size, max_total_size=max_output_size)
+    self.assertIsNone(nmsed_masks)
+
+    # Check static shapes
+    self.assertAllEqual(nmsed_boxes.shape.as_list(),
+                        exp_nms_corners.shape)
+    self.assertAllEqual(nmsed_scores.shape.as_list(),
+                        exp_nms_scores.shape)
+    self.assertAllEqual(nmsed_classes.shape.as_list(),
+                        exp_nms_classes.shape)
+    self.assertEqual(num_detections.shape.as_list(), [2])
+
     with self.test_session() as sess:
-      nms_output = sess.run(nms_dict)
-      self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
-      self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
-      self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
-      self.assertAllClose(nms_output['num_detections'], [2, 3])
+      (nmsed_boxes, nmsed_scores, nmsed_classes,
+       num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
+                                   num_detections])
+      self.assertAllClose(nmsed_boxes, exp_nms_corners)
+      self.assertAllClose(nmsed_scores, exp_nms_scores)
+      self.assertAllClose(nmsed_classes, exp_nms_classes)
+      self.assertAllClose(num_detections, [2, 3])

   def test_batch_multiclass_nms_with_masks(self):
     boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
@@ -574,38 +594,126 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
     iou_thresh = .5
     max_output_size = 4
-    exp_nms_corners = [[[0, 10, 1, 11],
-                        [0, 0, 1, 1],
-                        [0, 0, 0, 0],
-                        [0, 0, 0, 0]],
-                       [[0, 999, 2, 1004],
-                        [0, 10.1, 1, 11.1],
-                        [0, 100, 1, 101],
-                        [0, 0, 0, 0]]]
-    exp_nms_scores = [[.95, .9, 0, 0],
-                      [.85, .5, .3, 0]]
-    exp_nms_classes = [[0, 0, 0, 0],
-                       [1, 0, 0, 0]]
-    exp_nms_masks = [[[[6, 7], [8, 9]],
-                      [[0, 1], [2, 3]],
-                      [[0, 0], [0, 0]],
-                      [[0, 0], [0, 0]]],
-                     [[[13, 14], [15, 16]],
-                      [[8, 9], [10, 11]],
-                      [[10, 11], [12, 13]],
-                      [[0, 0], [0, 0]]]]
-    nms_dict = post_processing.batch_multiclass_non_max_suppression(
+    exp_nms_corners = np.array([[[0, 10, 1, 11],
+                                 [0, 0, 1, 1],
+                                 [0, 0, 0, 0],
+                                 [0, 0, 0, 0]],
+                                [[0, 999, 2, 1004],
+                                 [0, 10.1, 1, 11.1],
+                                 [0, 100, 1, 101],
+                                 [0, 0, 0, 0]]])
+    exp_nms_scores = np.array([[.95, .9, 0, 0],
+                               [.85, .5, .3, 0]])
+    exp_nms_classes = np.array([[0, 0, 0, 0],
+                                [1, 0, 0, 0]])
+    exp_nms_masks = np.array([[[[6, 7], [8, 9]],
+                               [[0, 1], [2, 3]],
+                               [[0, 0], [0, 0]],
+                               [[0, 0], [0, 0]]],
+                              [[[13, 14], [15, 16]],
+                               [[8, 9], [10, 11]],
+                               [[10, 11], [12, 13]],
+                               [[0, 0], [0, 0]]]])
+    (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+     num_detections) = post_processing.batch_multiclass_non_max_suppression(
         boxes, scores, score_thresh, iou_thresh,
         max_size_per_class=max_output_size, max_total_size=max_output_size,
         masks=masks)
+
+    # Check static shapes
+    self.assertAllEqual(nmsed_boxes.shape.as_list(), exp_nms_corners.shape)
+    self.assertAllEqual(nmsed_scores.shape.as_list(), exp_nms_scores.shape)
+    self.assertAllEqual(nmsed_classes.shape.as_list(), exp_nms_classes.shape)
+    self.assertAllEqual(nmsed_masks.shape.as_list(), exp_nms_masks.shape)
+    self.assertEqual(num_detections.shape.as_list(), [2])
+
     with self.test_session() as sess:
-      nms_output = sess.run(nms_dict)
-      self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
-      self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
-      self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
-      self.assertAllClose(nms_output['num_detections'], [2, 3])
-      self.assertAllClose(nms_output['detection_masks'], exp_nms_masks)
+      (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+       num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
+                                   nmsed_masks, num_detections])
+      self.assertAllClose(nmsed_boxes, exp_nms_corners)
+      self.assertAllClose(nmsed_scores, exp_nms_scores)
+      self.assertAllClose(nmsed_classes, exp_nms_classes)
+      self.assertAllClose(num_detections, [2, 3])
+      self.assertAllClose(nmsed_masks, exp_nms_masks)
+  def test_batch_multiclass_nms_with_dynamic_batch_size(self):
+    boxes_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 4))
+    scores_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2))
+    masks_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 2, 2))
+
+    boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
+                       [[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
+                       [[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
+                       [[0, 10, 1, 11], [0, 10, 1, 11]]],
+                      [[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
+                       [[0, 100, 1, 101], [0, 100, 1, 101]],
+                       [[0, 1000, 1, 1002], [0, 999, 2, 1004]],
+                       [[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]])
+    scores = np.array([[[.9, 0.01], [.75, 0.05],
+                        [.6, 0.01], [.95, 0]],
+                       [[.5, 0.01], [.3, 0.01],
+                        [.01, .85], [.01, .5]]])
+    masks = np.array([[[[[0, 1], [2, 3]], [[1, 2], [3, 4]]],
+                       [[[2, 3], [4, 5]], [[3, 4], [5, 6]]],
+                       [[[4, 5], [6, 7]], [[5, 6], [7, 8]]],
+                       [[[6, 7], [8, 9]], [[7, 8], [9, 10]]]],
+                      [[[[8, 9], [10, 11]], [[9, 10], [11, 12]]],
+                       [[[10, 11], [12, 13]], [[11, 12], [13, 14]]],
+                       [[[12, 13], [14, 15]], [[13, 14], [15, 16]]],
+                       [[[14, 15], [16, 17]], [[15, 16], [17, 18]]]]])
+    score_thresh = 0.1
+    iou_thresh = .5
+    max_output_size = 4
+
+    exp_nms_corners = np.array([[[0, 10, 1, 11],
+                                 [0, 0, 1, 1],
+                                 [0, 0, 0, 0],
+                                 [0, 0, 0, 0]],
+                                [[0, 999, 2, 1004],
+                                 [0, 10.1, 1, 11.1],
+                                 [0, 100, 1, 101],
+                                 [0, 0, 0, 0]]])
+    exp_nms_scores = np.array([[.95, .9, 0, 0],
+                               [.85, .5, .3, 0]])
+    exp_nms_classes = np.array([[0, 0, 0, 0],
+                                [1, 0, 0, 0]])
+    exp_nms_masks = np.array([[[[6, 7], [8, 9]],
+                               [[0, 1], [2, 3]],
+                               [[0, 0], [0, 0]],
+                               [[0, 0], [0, 0]]],
+                              [[[13, 14], [15, 16]],
+                               [[8, 9], [10, 11]],
+                               [[10, 11], [12, 13]],
+                               [[0, 0], [0, 0]]]])
+
+    (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+     num_detections) = post_processing.batch_multiclass_non_max_suppression(
+         boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
+         max_size_per_class=max_output_size, max_total_size=max_output_size,
+         masks=masks_placeholder)
+
+    # Check static shapes
+    self.assertAllEqual(nmsed_boxes.shape.as_list(), [None, 4, 4])
+    self.assertAllEqual(nmsed_scores.shape.as_list(), [None, 4])
+    self.assertAllEqual(nmsed_classes.shape.as_list(), [None, 4])
+    self.assertAllEqual(nmsed_masks.shape.as_list(), [None, 4, 2, 2])
+    self.assertEqual(num_detections.shape.as_list(), [None])
+
+    with self.test_session() as sess:
+      (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+       num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
+                                   nmsed_masks, num_detections],
+                                  feed_dict={boxes_placeholder: boxes,
+                                             scores_placeholder: scores,
+                                             masks_placeholder: masks})
+      self.assertAllClose(nmsed_boxes, exp_nms_corners)
+      self.assertAllClose(nmsed_scores, exp_nms_scores)
+      self.assertAllClose(nmsed_classes, exp_nms_classes)
+      self.assertAllClose(num_detections, [2, 3])
+      self.assertAllClose(nmsed_masks, exp_nms_masks)
   def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self):
     boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
@@ -656,17 +764,21 @@ class MulticlassNonMaxSuppressionTest(tf.test.TestCase):
                       [[0, 0], [0, 0]],
                       [[0, 0], [0, 0]]]]
-    nms_dict = post_processing.batch_multiclass_non_max_suppression(
+    (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+     num_detections) = post_processing.batch_multiclass_non_max_suppression(
         boxes, scores, score_thresh, iou_thresh,
         max_size_per_class=max_output_size, max_total_size=max_output_size,
         num_valid_boxes=num_valid_boxes, masks=masks)
+
     with self.test_session() as sess:
-      nms_output = sess.run(nms_dict)
-      self.assertAllClose(nms_output['detection_boxes'], exp_nms_corners)
-      self.assertAllClose(nms_output['detection_scores'], exp_nms_scores)
-      self.assertAllClose(nms_output['detection_classes'], exp_nms_classes)
-      self.assertAllClose(nms_output['num_detections'], [1, 1])
-      self.assertAllClose(nms_output['detection_masks'], exp_nms_masks)
+      (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+       num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
+                                   nmsed_masks, num_detections])
+      self.assertAllClose(nmsed_boxes, exp_nms_corners)
+      self.assertAllClose(nmsed_scores, exp_nms_scores)
+      self.assertAllClose(nmsed_classes, exp_nms_classes)
+      self.assertAllClose(num_detections, [1, 1])
+      self.assertAllClose(nmsed_masks, exp_nms_masks)

 if __name__ == '__main__':
...
@@ -1255,6 +1255,82 @@ def random_resize_method(image, target_size):
   return resized_image

+
+def _compute_new_static_size(image,
+                             min_dimension,
+                             max_dimension):
+  """Compute new static shape for resize_to_range method."""
+  image_shape = image.get_shape().as_list()
+  orig_height = image_shape[0]
+  orig_width = image_shape[1]
+  orig_min_dim = min(orig_height, orig_width)
+  # Calculates the larger of the possible sizes
+  large_scale_factor = min_dimension / float(orig_min_dim)
+  # Scaling orig_(height|width) by large_scale_factor will make the smaller
+  # dimension equal to min_dimension, save for floating point rounding errors.
+  # For reasonably-sized images, taking the nearest integer will reliably
+  # eliminate this error.
+  large_height = int(round(orig_height * large_scale_factor))
+  large_width = int(round(orig_width * large_scale_factor))
+  large_size = [large_height, large_width]
+  if max_dimension:
+    # Calculates the smaller of the possible sizes, use that if the larger
+    # is too big.
+    orig_max_dim = max(orig_height, orig_width)
+    small_scale_factor = max_dimension / float(orig_max_dim)
+    # Scaling orig_(height|width) by small_scale_factor will make the larger
+    # dimension equal to max_dimension, save for floating point rounding
+    # errors. For reasonably-sized images, taking the nearest integer will
+    # reliably eliminate this error.
+    small_height = int(round(orig_height * small_scale_factor))
+    small_width = int(round(orig_width * small_scale_factor))
+    small_size = [small_height, small_width]
+    new_size = large_size
+    if max(large_size) > max_dimension:
+      new_size = small_size
+  else:
+    new_size = large_size
+  return tf.constant(new_size)
+
+
+def _compute_new_dynamic_size(image,
+                              min_dimension,
+                              max_dimension):
+  """Compute new dynamic shape for resize_to_range method."""
+  image_shape = tf.shape(image)
+  orig_height = tf.to_float(image_shape[0])
+  orig_width = tf.to_float(image_shape[1])
+  orig_min_dim = tf.minimum(orig_height, orig_width)
+  # Calculates the larger of the possible sizes
+  min_dimension = tf.constant(min_dimension, dtype=tf.float32)
+  large_scale_factor = min_dimension / orig_min_dim
+  # Scaling orig_(height|width) by large_scale_factor will make the smaller
+  # dimension equal to min_dimension, save for floating point rounding errors.
+  # For reasonably-sized images, taking the nearest integer will reliably
+  # eliminate this error.
+  large_height = tf.to_int32(tf.round(orig_height * large_scale_factor))
+  large_width = tf.to_int32(tf.round(orig_width * large_scale_factor))
+  large_size = tf.stack([large_height, large_width])
+  if max_dimension:
+    # Calculates the smaller of the possible sizes, use that if the larger
+    # is too big.
+    orig_max_dim = tf.maximum(orig_height, orig_width)
+    max_dimension = tf.constant(max_dimension, dtype=tf.float32)
+    small_scale_factor = max_dimension / orig_max_dim
+    # Scaling orig_(height|width) by small_scale_factor will make the larger
+    # dimension equal to max_dimension, save for floating point rounding
+    # errors. For reasonably-sized images, taking the nearest integer will
+    # reliably eliminate this error.
+    small_height = tf.to_int32(tf.round(orig_height * small_scale_factor))
+    small_width = tf.to_int32(tf.round(orig_width * small_scale_factor))
+    small_size = tf.stack([small_height, small_width])
+    new_size = tf.cond(
+        tf.to_float(tf.reduce_max(large_size)) > max_dimension,
+        lambda: small_size, lambda: large_size)
+  else:
+    new_size = large_size
+  return new_size
 def resize_to_range(image,
                     masks=None,
                     min_dimension=None,
@@ -1295,64 +1371,22 @@ def resize_to_range(image,
     raise ValueError('Image should be 3D tensor')

   with tf.name_scope('ResizeToRange', values=[image, min_dimension]):
-    image_shape = tf.shape(image)
-    orig_height = tf.to_float(image_shape[0])
-    orig_width = tf.to_float(image_shape[1])
-    orig_min_dim = tf.minimum(orig_height, orig_width)
-
-    # Calculates the larger of the possible sizes
-    min_dimension = tf.constant(min_dimension, dtype=tf.float32)
-    large_scale_factor = min_dimension / orig_min_dim
-    # Scaling orig_(height|width) by large_scale_factor will make the smaller
-    # dimension equal to min_dimension, save for floating point rounding errors.
-    # For reasonably-sized images, taking the nearest integer will reliably
-    # eliminate this error.
-    large_height = tf.to_int32(tf.round(orig_height * large_scale_factor))
-    large_width = tf.to_int32(tf.round(orig_width * large_scale_factor))
-    large_size = tf.stack([large_height, large_width])
-
-    if max_dimension:
-      # Calculates the smaller of the possible sizes, use that if the larger
-      # is too big.
-      orig_max_dim = tf.maximum(orig_height, orig_width)
-      max_dimension = tf.constant(max_dimension, dtype=tf.float32)
-      small_scale_factor = max_dimension / orig_max_dim
-      # Scaling orig_(height|width) by small_scale_factor will make the larger
-      # dimension equal to max_dimension, save for floating point rounding
-      # errors. For reasonably-sized images, taking the nearest integer will
-      # reliably eliminate this error.
-      small_height = tf.to_int32(tf.round(orig_height * small_scale_factor))
-      small_width = tf.to_int32(tf.round(orig_width * small_scale_factor))
-      small_size = tf.stack([small_height, small_width])
-
-      new_size = tf.cond(
-          tf.to_float(tf.reduce_max(large_size)) > max_dimension,
-          lambda: small_size, lambda: large_size)
-    else:
-      new_size = large_size
+    if image.get_shape().is_fully_defined():
+      new_size = _compute_new_static_size(image, min_dimension,
+                                          max_dimension)
+    else:
+      new_size = _compute_new_dynamic_size(image, min_dimension,
+                                           max_dimension)
     new_image = tf.image.resize_images(image, new_size,
                                        align_corners=align_corners)

     result = new_image
     if masks is not None:
-      num_instances = tf.shape(masks)[0]
-
-      def resize_masks_branch():
-        new_masks = tf.expand_dims(masks, 3)
-        new_masks = tf.image.resize_nearest_neighbor(
-            new_masks, new_size, align_corners=align_corners)
-        new_masks = tf.squeeze(new_masks, axis=3)
-        return new_masks
-
-      def reshape_masks_branch():
-        new_masks = tf.reshape(masks, [0, new_size[0], new_size[1]])
-        return new_masks
-
-      masks = tf.cond(num_instances > 0,
-                      resize_masks_branch,
-                      reshape_masks_branch)
-      result = [new_image, masks]
+      new_masks = tf.expand_dims(masks, 3)
+      new_masks = tf.image.resize_nearest_neighbor(new_masks, new_size,
+                                                   align_corners=align_corners)
+      new_masks = tf.squeeze(new_masks, 3)
+      result = [new_image, new_masks]

     return result
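As a concrete check of the scaling rule above, using the same shapes as the tests that follow: with min_dimension=50 and max_dimension=100, a 60x40 image scales by 50/40 to [75, 50], while a 15x50 image would scale by 50/15 to [50, 167], overshooting max_dimension, so the 100/50 scale applies instead and yields [30, 100]. The same arithmetic as a small self-contained sketch (plain Python, mirroring _compute_new_static_size):

def expected_resize_to_range_size(height, width, min_dim=50, max_dim=100):
  # Scale so the smaller side hits min_dim; if that pushes the larger
  # side past max_dim, scale so the larger side hits max_dim instead.
  large = [int(round(d * min_dim / float(min(height, width))))
           for d in (height, width)]
  small = [int(round(d * max_dim / float(max(height, width))))
           for d in (height, width)]
  return small if max(large) > max_dim else large

assert expected_resize_to_range_size(60, 40) == [75, 50]
assert expected_resize_to_range_size(15, 50) == [30, 100]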
...
@@ -1395,7 +1395,7 @@ class PreprocessorTest(tf.test.TestCase):
       self.assertAllEqual(expected_images_shape_,
                           resized_images_shape_)

-  def testResizeToRange(self):
+  def testResizeToRangePreservesStaticSpatialShape(self):
     """Tests image resizing, checking output sizes."""
     in_shape_list = [[60, 40, 3], [15, 30, 3], [15, 50, 3]]
     min_dim = 50
@@ -1406,13 +1406,27 @@ class PreprocessorTest(tf.test.TestCase):
       in_image = tf.random_uniform(in_shape)
       out_image = preprocessor.resize_to_range(
           in_image, min_dimension=min_dim, max_dimension=max_dim)
-      out_image_shape = tf.shape(out_image)
+      self.assertAllEqual(out_image.get_shape().as_list(), expected_shape)
+
+  def testResizeToRangeWithDynamicSpatialShape(self):
+    """Tests image resizing, checking output sizes."""
+    in_shape_list = [[60, 40, 3], [15, 30, 3], [15, 50, 3]]
+    min_dim = 50
+    max_dim = 100
+    expected_shape_list = [[75, 50, 3], [50, 100, 3], [30, 100, 3]]
+    for in_shape, expected_shape in zip(in_shape_list, expected_shape_list):
+      in_image = tf.placeholder(tf.float32, shape=(None, None, 3))
+      out_image = preprocessor.resize_to_range(
+          in_image, min_dimension=min_dim, max_dimension=max_dim)
+      out_image_shape = tf.shape(out_image)
       with self.test_session() as sess:
-        out_image_shape = sess.run(out_image_shape)
+        out_image_shape = sess.run(out_image_shape,
+                                   feed_dict={in_image:
+                                              np.random.randn(*in_shape)})
         self.assertAllEqual(out_image_shape, expected_shape)

-  def testResizeToRangeWithMasks(self):
+  def testResizeToRangeWithMasksPreservesStaticSpatialShape(self):
     """Tests image resizing, checking output sizes."""
     in_image_shape_list = [[60, 40, 3], [15, 30, 3]]
     in_masks_shape_list = [[15, 60, 40], [10, 15, 30]]
@@ -1430,30 +1444,25 @@ class PreprocessorTest(tf.test.TestCase):
       in_masks = tf.random_uniform(in_masks_shape)
       out_image, out_masks = preprocessor.resize_to_range(
           in_image, in_masks, min_dimension=min_dim, max_dimension=max_dim)
-      out_image_shape = tf.shape(out_image)
-      out_masks_shape = tf.shape(out_masks)
-
-      with self.test_session() as sess:
-        out_image_shape, out_masks_shape = sess.run(
-            [out_image_shape, out_masks_shape])
-        self.assertAllEqual(out_image_shape, expected_image_shape)
-        self.assertAllEqual(out_masks_shape, expected_mask_shape)
+      self.assertAllEqual(out_masks.get_shape().as_list(), expected_mask_shape)
+      self.assertAllEqual(out_image.get_shape().as_list(), expected_image_shape)

-  def testResizeToRangeWithNoInstanceMask(self):
+  def testResizeToRangeWithMasksAndDynamicSpatialShape(self):
     """Tests image resizing, checking output sizes."""
     in_image_shape_list = [[60, 40, 3], [15, 30, 3]]
-    in_masks_shape_list = [[0, 60, 40], [0, 15, 30]]
+    in_masks_shape_list = [[15, 60, 40], [10, 15, 30]]
     min_dim = 50
     max_dim = 100
     expected_image_shape_list = [[75, 50, 3], [50, 100, 3]]
-    expected_masks_shape_list = [[0, 75, 50], [0, 50, 100]]
+    expected_masks_shape_list = [[15, 75, 50], [10, 50, 100]]
     for (in_image_shape, expected_image_shape, in_masks_shape,
          expected_mask_shape) in zip(in_image_shape_list,
                                      expected_image_shape_list,
                                      in_masks_shape_list,
                                      expected_masks_shape_list):
-      in_image = tf.random_uniform(in_image_shape)
-      in_masks = tf.random_uniform(in_masks_shape)
+      in_image = tf.placeholder(tf.float32, shape=(None, None, 3))
+      in_masks = tf.placeholder(tf.float32, shape=(None, None, None))
       out_image, out_masks = preprocessor.resize_to_range(
           in_image, in_masks, min_dimension=min_dim, max_dimension=max_dim)
@@ -1462,38 +1471,15 @@ class PreprocessorTest(tf.test.TestCase):
       with self.test_session() as sess:
         out_image_shape, out_masks_shape = sess.run(
-            [out_image_shape, out_masks_shape])
-        self.assertAllEqual(out_image_shape, expected_image_shape)
-        self.assertAllEqual(out_masks_shape, expected_mask_shape)
-
-  def testResizeImageWithMasks(self):
-    """Tests image resizing, checking output sizes."""
-    in_image_shape_list = [[60, 40, 3], [15, 30, 3]]
-    in_masks_shape_list = [[15, 60, 40], [10, 15, 30]]
-    height = 50
-    width = 100
-    expected_image_shape_list = [[50, 100, 3], [50, 100, 3]]
-    expected_masks_shape_list = [[15, 50, 100], [10, 50, 100]]
-    for (in_image_shape, expected_image_shape, in_masks_shape,
-         expected_mask_shape) in zip(in_image_shape_list,
-                                     expected_image_shape_list,
-                                     in_masks_shape_list,
-                                     expected_masks_shape_list):
-      in_image = tf.random_uniform(in_image_shape)
-      in_masks = tf.random_uniform(in_masks_shape)
-      out_image, out_masks = preprocessor.resize_image(
-          in_image, in_masks, new_height=height, new_width=width)
-      out_image_shape = tf.shape(out_image)
-      out_masks_shape = tf.shape(out_masks)
-
-      with self.test_session() as sess:
-        out_image_shape, out_masks_shape = sess.run(
-            [out_image_shape, out_masks_shape])
+            [out_image_shape, out_masks_shape],
+            feed_dict={
+                in_image: np.random.randn(*in_image_shape),
+                in_masks: np.random.randn(*in_masks_shape)
+            })
         self.assertAllEqual(out_image_shape, expected_image_shape)
         self.assertAllEqual(out_masks_shape, expected_mask_shape)

-  def testResizeImageWithNoInstanceMask(self):
+  def testResizeToRangeWithInstanceMasksTensorOfSizeZero(self):
     """Tests image resizing, checking output sizes."""
     in_image_shape_list = [[60, 40, 3], [15, 30, 3]]
     in_masks_shape_list = [[0, 60, 40], [0, 15, 30]]
...
@@ -16,16 +16,19 @@
 r"""Tool to export an object detection model for inference.

 Prepares an object detection tensorflow graph for inference using model
-configuration and an optional trained checkpoint. Outputs either an inference
-graph or a SavedModel (https://tensorflow.github.io/serving/serving_basic.html).
+configuration and an optional trained checkpoint. Outputs inference
+graph, associated checkpoint files, a frozen inference graph and a
+SavedModel (https://tensorflow.github.io/serving/serving_basic.html).

 The inference graph contains one of three input nodes depending on the user
 specified option.
-  * `image_tensor`: Accepts a uint8 4-D tensor of shape [1, None, None, 3]
-  * `encoded_image_string_tensor`: Accepts a scalar string tensor of encoded PNG
-    or JPEG image.
-  * `tf_example`: Accepts a serialized TFExample proto. The batch size in this
-    case is always 1.
+  * `image_tensor`: Accepts a uint8 4-D tensor of shape [None, None, None, 3]
+  * `encoded_image_string_tensor`: Accepts a 1-D string tensor of shape [None]
+    containing encoded PNG or JPEG images. Image resolutions are expected to be
+    the same if more than 1 image is provided.
+  * `tf_example`: Accepts a 1-D string tensor of shape [None] containing
+    serialized TFExample protos. Image resolutions are expected to be the same
+    if more than 1 image is provided.

 and the following output nodes returned by the model.postprocess(..):
   * `num_detections`: Outputs float32 tensors of the form [batch]
@@ -41,23 +44,27 @@ and the following output nodes returned by the model.postprocess(..):
     masks for each box if its present in the dictionary of postprocessed
     tensors returned by the model.

-Note that currently `batch` is always 1, but we will support `batch` > 1 in
-the future.
-
-Optionally, one can freeze the graph by converting the weights in the provided
-checkpoint as graph constants thereby eliminating the need to use a checkpoint
-file during inference.
-
-Note that this tool uses `use_moving_averages` from eval_config to decide
-which weights to freeze.
+Notes:
+ * This tool uses `use_moving_averages` from eval_config to decide which
+   weights to freeze.

 Example Usage:
 --------------
 python export_inference_graph \
     --input_type image_tensor \
     --pipeline_config_path path/to/ssd_inception_v2.config \
-    --checkpoint_path path/to/model-ckpt \
-    --inference_graph_path path/to/inference_graph.pb
+    --trained_checkpoint_prefix path/to/model.ckpt \
+    --output_directory path/to/exported_model_directory
+
+The expected output would be in the directory
+path/to/exported_model_directory (which is created if it does not exist)
+with contents:
+ - graph.pbtxt
+ - model.ckpt.data-00000-of-00001
+ - model.ckpt.info
+ - model.ckpt.meta
+ - frozen_inference_graph.pb
+ + saved_model (a directory)
 """
 import tensorflow as tf
 from google.protobuf import text_format
@@ -70,31 +77,29 @@
 flags = tf.app.flags
 flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
                     'one of [`image_tensor`, `encoded_image_string_tensor`, '
                     '`tf_example`]')
-flags.DEFINE_string('pipeline_config_path', '',
+flags.DEFINE_string('pipeline_config_path', None,
                     'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                     'file.')
-flags.DEFINE_string('checkpoint_path', '', 'Optional path to checkpoint file. '
-                    'If provided, bakes the weights from the checkpoint into '
-                    'the graph.')
-flags.DEFINE_string('inference_graph_path', '', 'Path to write the output '
-                    'inference graph.')
-flags.DEFINE_bool('export_as_saved_model', False, 'Whether the exported graph '
-                  'should be saved as a SavedModel')
+flags.DEFINE_string('trained_checkpoint_prefix', None,
+                    'Path to trained checkpoint, typically of the form '
+                    'path/to/model.ckpt')
+flags.DEFINE_string('output_directory', None, 'Path to write outputs.')

 FLAGS = flags.FLAGS


 def main(_):
-  assert FLAGS.pipeline_config_path, 'TrainEvalPipelineConfig missing.'
-  assert FLAGS.inference_graph_path, 'Inference graph path missing.'
-  assert FLAGS.input_type, 'Input type missing.'
+  assert FLAGS.pipeline_config_path, '`pipeline_config_path` is missing'
+  assert FLAGS.trained_checkpoint_prefix, (
+      '`trained_checkpoint_prefix` is missing')
+  assert FLAGS.output_directory, '`output_directory` is missing'
   pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
   with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
     text_format.Merge(f.read(), pipeline_config)
-  exporter.export_inference_graph(FLAGS.input_type, pipeline_config,
-                                  FLAGS.checkpoint_path,
-                                  FLAGS.inference_graph_path,
-                                  FLAGS.export_as_saved_model)
+  exporter.export_inference_graph(
+      FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_prefix,
+      FLAGS.output_directory)

 if __name__ == '__main__':
...
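The exporter can equally be driven from Python instead of the command line; main() above is a thin wrapper around this call. A hedged sketch (paths are placeholders, and the pipeline_pb2 import path is assumed from this repo's layout):

import tensorflow as tf
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2

pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile('path/to/ssd_inception_v2.config', 'r') as f:
  text_format.Merge(f.read(), pipeline_config)

exporter.export_inference_graph(
    input_type='image_tensor',
    pipeline_config=pipeline_config,
    trained_checkpoint_prefix='path/to/model.ckpt',
    output_directory='path/to/exported_model_directory')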
@@ -17,6 +17,7 @@
 import logging
 import os
 import tensorflow as tf
+from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.client import session
 from tensorflow.python.framework import graph_util
@@ -42,6 +43,7 @@ def freeze_graph_with_def_protos(
     filename_tensor_name,
     clear_devices,
     initializer_nodes,
+    optimize_graph=False,
     variable_names_blacklist=''):
   """Converts all variables in a graph and checkpoint into constants."""
   del restore_op_name, filename_tensor_name  # Unused by updated loading code.
@@ -61,9 +63,23 @@ def freeze_graph_with_def_protos(
   for node in input_graph_def.node:
     node.device = ''

-  _ = importer.import_graph_def(input_graph_def, name='')
-
-  with session.Session() as sess:
+  with tf.Graph().as_default():
+    tf.import_graph_def(input_graph_def, name='')
+
+    if optimize_graph:
+      logging.info('Graph Rewriter optimizations enabled')
+      rewrite_options = rewriter_config_pb2.RewriterConfig(
+          optimize_tensor_layout=True)
+      rewrite_options.optimizers.append('pruning')
+      rewrite_options.optimizers.append('constfold')
+      rewrite_options.optimizers.append('layout')
+      graph_options = tf.GraphOptions(
+          rewrite_options=rewrite_options, infer_shapes=True)
+    else:
+      logging.info('Graph Rewriter optimizations disabled')
+      graph_options = tf.GraphOptions()
+    config = tf.ConfigProto(graph_options=graph_options)
+    with session.Session(config=config) as sess:
       if input_saver_def:
         saver = saver_lib.Saver(saver_def=input_saver_def)
         saver.restore(sess, input_checkpoint)
@@ -95,52 +111,58 @@ def freeze_graph_with_def_protos(
   return output_graph_def
-def get_frozen_graph_def(inference_graph_def, use_moving_averages,
-                         input_checkpoint, output_node_names):
-  """Freezes all variables in a graph definition."""
-  saver = None
-  if use_moving_averages:
-    variable_averages = tf.train.ExponentialMovingAverage(0.0)
-    variables_to_restore = variable_averages.variables_to_restore()
-    saver = tf.train.Saver(variables_to_restore)
-  else:
-    saver = tf.train.Saver()
-  frozen_graph_def = freeze_graph_with_def_protos(
-      input_graph_def=inference_graph_def,
-      input_saver_def=saver.as_saver_def(),
-      input_checkpoint=input_checkpoint,
-      output_node_names=output_node_names,
-      restore_op_name='save/restore_all',
-      filename_tensor_name='save/Const:0',
-      clear_devices=True,
-      initializer_nodes='')
-  return frozen_graph_def
+def _image_tensor_input_placeholder():
+  """Returns placeholder and input node that accepts a batch of uint8 images."""
+  input_tensor = tf.placeholder(dtype=tf.uint8,
+                                shape=(None, None, None, 3),
+                                name='image_tensor')
+  return input_tensor, input_tensor

+# TODO: Support batch tf example inputs.
 def _tf_example_input_placeholder():
-  tf_example_placeholder = tf.placeholder(
-      tf.string, shape=[], name='tf_example')
-  tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
-      tf_example_placeholder)
-  image = tensor_dict[fields.InputDataFields.image]
-  return tf.expand_dims(image, axis=0)
-
-
-def _image_tensor_input_placeholder():
-  return tf.placeholder(dtype=tf.uint8,
-                        shape=(1, None, None, 3),
-                        name='image_tensor')
+  """Returns input that accepts a batch of strings with tf examples.
+
+  Returns:
+    a tuple of placeholder and input nodes that output decoded images.
+  """
+  batch_tf_example_placeholder = tf.placeholder(
+      tf.string, shape=[None], name='tf_example')
+  def decode(tf_example_string_tensor):
+    tensor_dict = tf_example_decoder.TfExampleDecoder().decode(
+        tf_example_string_tensor)
+    image_tensor = tensor_dict[fields.InputDataFields.image]
+    return image_tensor
+  return (batch_tf_example_placeholder,
+          tf.map_fn(decode,
+                    elems=batch_tf_example_placeholder,
+                    dtype=tf.uint8,
+                    parallel_iterations=32,
+                    back_prop=False))

 def _encoded_image_string_tensor_input_placeholder():
-  image_str = tf.placeholder(dtype=tf.string,
-                             shape=[],
-                             name='encoded_image_string_tensor')
-  image_tensor = tf.image.decode_image(image_str, channels=3)
-  image_tensor.set_shape((None, None, 3))
-  return tf.expand_dims(image_tensor, axis=0)
+  """Returns input that accepts a batch of PNG or JPEG strings.
+
+  Returns:
+    a tuple of placeholder and input nodes that output decoded images.
+  """
+  batch_image_str_placeholder = tf.placeholder(
+      dtype=tf.string,
+      shape=[None],
+      name='encoded_image_string_tensor')
+  def decode(encoded_image_string_tensor):
+    image_tensor = tf.image.decode_image(encoded_image_string_tensor,
+                                         channels=3)
+    image_tensor.set_shape((None, None, 3))
+    return image_tensor
+  return (batch_image_str_placeholder,
+          tf.map_fn(
+              decode,
+              elems=batch_image_str_placeholder,
+              dtype=tf.uint8,
+              parallel_iterations=32,
+              back_prop=False))
 input_placeholder_fn_map = {
@@ -151,7 +173,8 @@ input_placeholder_fn_map = {
 }

-def _add_output_tensor_nodes(postprocessed_tensors):
+def _add_output_tensor_nodes(postprocessed_tensors,
+                             output_collection_name='inference_op'):
   """Adds output nodes for detection boxes and scores.

   Adds the following nodes for output tensors -
@@ -174,6 +197,7 @@ def _add_output_tensor_nodes(postprocessed_tensors,
       'detection_masks': [batch, max_detections, mask_height, mask_width]
         (optional).
       'num_detections': [batch]
+    output_collection_name: Name of collection to add output tensors to.

   Returns:
     A tensor dict containing the added output tensor nodes.
@@ -191,53 +215,29 @@ def _add_output_tensor_nodes(postprocessed_tensors,
   outputs['num_detections'] = tf.identity(num_detections, name='num_detections')
   if masks is not None:
     outputs['detection_masks'] = tf.identity(masks, name='detection_masks')
+  for output_key in outputs:
+    tf.add_to_collection(output_collection_name, outputs[output_key])
+  if masks is not None:
+    tf.add_to_collection(output_collection_name, outputs['detection_masks'])
   return outputs
-def _write_inference_graph(inference_graph_path,
-                           checkpoint_path=None,
-                           use_moving_averages=False,
-                           output_node_names=(
-                               'num_detections,detection_scores,'
-                               'detection_boxes,detection_classes')):
-  """Writes inference graph to disk with the option to bake in weights.
-
-  If checkpoint_path is not None bakes the weights into the graph thereby
-  eliminating the need of checkpoint files during inference. If the model
-  was trained with moving averages, setting use_moving_averages to true
-  restores the moving averages, otherwise the original set of variables
-  is restored.
+def _write_frozen_graph(frozen_graph_path, frozen_graph_def):
+  """Writes frozen graph to disk.

   Args:
-    inference_graph_path: Path to write inference graph.
-    checkpoint_path: Optional path to the checkpoint file.
-    use_moving_averages: Whether to export the original or the moving averages
-      of the trainable variables from the checkpoint.
-    output_node_names: Output tensor names, defaults are: num_detections,
-      detection_scores, detection_boxes, detection_classes.
+    frozen_graph_path: Path to write inference graph.
+    frozen_graph_def: tf.GraphDef holding frozen graph.
   """
-  inference_graph_def = tf.get_default_graph().as_graph_def()
-  if checkpoint_path:
-    output_graph_def = get_frozen_graph_def(
-        inference_graph_def=inference_graph_def,
-        use_moving_averages=use_moving_averages,
-        input_checkpoint=checkpoint_path,
-        output_node_names=output_node_names,
-    )
-    with gfile.GFile(inference_graph_path, 'wb') as f:
-      f.write(output_graph_def.SerializeToString())
-    logging.info('%d ops in the final graph.', len(output_graph_def.node))
-    return
-  tf.train.write_graph(inference_graph_def,
-                       os.path.dirname(inference_graph_path),
-                       os.path.basename(inference_graph_path),
-                       as_text=False)
+  with gfile.GFile(frozen_graph_path, 'wb') as f:
+    f.write(frozen_graph_def.SerializeToString())
+  logging.info('%d ops in the final graph.', len(frozen_graph_def.node))
-def _write_saved_model(inference_graph_path, inputs, outputs,
-                       checkpoint_path=None, use_moving_averages=False):
+def _write_saved_model(saved_model_path,
+                       frozen_graph_def,
+                       inputs,
+                       outputs):
   """Writes SavedModel to disk.

   If checkpoint_path is not None bakes the weights into the graph thereby
@@ -247,30 +247,17 @@ def _write_saved_model(saved_model_path,
   is restored.

   Args:
-    inference_graph_path: Path to write inference graph.
+    saved_model_path: Path to write SavedModel.
+    frozen_graph_def: tf.GraphDef holding frozen graph.
     inputs: The input image tensor to use for detection.
     outputs: A tensor dictionary containing the outputs of a DetectionModel.
-    checkpoint_path: Optional path to the checkpoint file.
-    use_moving_averages: Whether to export the original or the moving averages
-      of the trainable variables from the checkpoint.
   """
-  inference_graph_def = tf.get_default_graph().as_graph_def()
-  checkpoint_graph_def = None
-  if checkpoint_path:
-    output_node_names = ','.join(outputs.keys())
-    checkpoint_graph_def = get_frozen_graph_def(
-        inference_graph_def=inference_graph_def,
-        use_moving_averages=use_moving_averages,
-        input_checkpoint=checkpoint_path,
-        output_node_names=output_node_names
-    )
-
   with tf.Graph().as_default():
     with session.Session() as sess:
-      tf.import_graph_def(checkpoint_graph_def)
+      tf.import_graph_def(frozen_graph_def, name='')

-      builder = tf.saved_model.builder.SavedModelBuilder(inference_graph_path)
+      builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

       tensor_info_inputs = {
           'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
...@@ -294,46 +281,96 @@ def _write_saved_model(inference_graph_path, inputs, outputs, ...@@ -294,46 +281,96 @@ def _write_saved_model(inference_graph_path, inputs, outputs,
builder.save() builder.save()
def _write_graph_and_checkpoint(inference_graph_def,
model_path,
input_saver_def,
trained_checkpoint_prefix):
for node in inference_graph_def.node:
node.device = ''
with tf.Graph().as_default():
tf.import_graph_def(inference_graph_def, name='')
with session.Session() as sess:
saver = saver_lib.Saver(saver_def=input_saver_def,
save_relative_paths=True)
saver.restore(sess, trained_checkpoint_prefix)
saver.save(sess, model_path)
def _export_inference_graph(input_type,
                            detection_model,
                            use_moving_averages,
-                           checkpoint_path,
-                           inference_graph_path,
-                           export_as_saved_model=False):
+                           trained_checkpoint_prefix,
+                           output_directory,
+                           optimize_graph=False,
+                           output_collection_name='inference_op'):
  """Export helper."""
+  tf.gfile.MakeDirs(output_directory)
+  frozen_graph_path = os.path.join(output_directory,
+                                   'frozen_inference_graph.pb')
+  saved_model_path = os.path.join(output_directory, 'saved_model')
+  model_path = os.path.join(output_directory, 'model.ckpt')
  if input_type not in input_placeholder_fn_map:
    raise ValueError('Unknown input type: {}'.format(input_type))
-  inputs = tf.to_float(input_placeholder_fn_map[input_type]())
+  placeholder_tensor, input_tensors = input_placeholder_fn_map[input_type]()
+  inputs = tf.to_float(input_tensors)
  preprocessed_inputs = detection_model.preprocess(inputs)
  output_tensors = detection_model.predict(preprocessed_inputs)
  postprocessed_tensors = detection_model.postprocess(output_tensors)
-  outputs = _add_output_tensor_nodes(postprocessed_tensors)
-  out_node_names = list(outputs.keys())
-  if export_as_saved_model:
-    _write_saved_model(inference_graph_path, inputs, outputs, checkpoint_path,
-                       use_moving_averages)
-  else:
-    _write_inference_graph(inference_graph_path, checkpoint_path,
-                           use_moving_averages,
-                           output_node_names=','.join(out_node_names))
+  outputs = _add_output_tensor_nodes(postprocessed_tensors,
+                                     output_collection_name)
+
+  saver = None
+  if use_moving_averages:
+    variable_averages = tf.train.ExponentialMovingAverage(0.0)
+    variables_to_restore = variable_averages.variables_to_restore()
+    saver = tf.train.Saver(variables_to_restore)
+  else:
+    saver = tf.train.Saver()
+  input_saver_def = saver.as_saver_def()
+
+  _write_graph_and_checkpoint(
+      inference_graph_def=tf.get_default_graph().as_graph_def(),
+      model_path=model_path,
+      input_saver_def=input_saver_def,
+      trained_checkpoint_prefix=trained_checkpoint_prefix)
+
+  frozen_graph_def = freeze_graph_with_def_protos(
+      input_graph_def=tf.get_default_graph().as_graph_def(),
+      input_saver_def=input_saver_def,
+      input_checkpoint=trained_checkpoint_prefix,
+      output_node_names=','.join(outputs.keys()),
+      restore_op_name='save/restore_all',
+      filename_tensor_name='save/Const:0',
+      clear_devices=True,
+      optimize_graph=optimize_graph,
+      initializer_nodes='')
+  _write_frozen_graph(frozen_graph_path, frozen_graph_def)
+  _write_saved_model(saved_model_path, frozen_graph_def, placeholder_tensor,
+                     outputs)
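Note: when use_moving_averages is true, the Saver is built from ExponentialMovingAverage.variables_to_restore(), which maps shadow-variable names in the checkpoint onto the model variables, so the frozen weights are the averaged ones. A toy illustration of that mapping (standalone, not exporter code):

import tensorflow as tf

v = tf.Variable(0.0, name='weights')
ema = tf.train.ExponentialMovingAverage(0.0)
update_op = ema.apply([v])
# Maps the checkpoint name 'weights/ExponentialMovingAverage' to the graph
# variable v, so a Saver built from this dict restores averaged values into v.
print(ema.variables_to_restore())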
-def export_inference_graph(input_type, pipeline_config, checkpoint_path,
-                           inference_graph_path, export_as_saved_model=False):
+def export_inference_graph(input_type,
+                           pipeline_config,
+                           trained_checkpoint_prefix,
+                           output_directory,
+                           optimize_graph=False,
+                           output_collection_name='inference_op'):
  """Exports inference graph for the model specified in the pipeline config.

  Args:
    input_type: Type of input for the graph. Can be one of [`image_tensor`,
      `tf_example`].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
-    checkpoint_path: Path to the checkpoint file to freeze.
-    inference_graph_path: Path to write inference graph to.
-    export_as_saved_model: If the model should be exported as a SavedModel. If
-      false, it is saved as an inference graph.
+    trained_checkpoint_prefix: Path to the trained checkpoint file.
+    output_directory: Path to write outputs.
+    optimize_graph: Whether to optimize graph using Grappler.
+    output_collection_name: Name of collection to add output tensors to.
+      If None, does not add output tensors to a collection.
  """
  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)
  _export_inference_graph(input_type, detection_model,
                          pipeline_config.eval_config.use_moving_averages,
-                         checkpoint_path, inference_graph_path,
-                         export_as_saved_model)
+                         trained_checkpoint_prefix, output_directory,
+                         optimize_graph, output_collection_name)
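Note: a typical invocation of the new API from user code might look like the following sketch (config and checkpoint paths are illustrative):

from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2

pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with open('/tmp/faster_rcnn.config') as f:
  text_format.Merge(f.read(), pipeline_config)
exporter.export_inference_graph(
    input_type='image_tensor',
    pipeline_config=pipeline_config,
    trained_checkpoint_prefix='/tmp/train/model.ckpt-10000',
    output_directory='/tmp/exported')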
@@ -43,18 +43,22 @@ class FakeModel(model.DetectionModel):

  def postprocess(self, prediction_dict):
    with tf.control_dependencies(prediction_dict.values()):
      postprocessed_tensors = {
-          'detection_boxes': tf.constant([[0.0, 0.0, 0.5, 0.5],
-                                          [0.5, 0.5, 0.8, 0.8]], tf.float32),
-          'detection_scores': tf.constant([[0.7, 0.6]], tf.float32),
-          'detection_classes': tf.constant([[0, 1]], tf.float32),
-          'num_detections': tf.constant([2], tf.float32)
+          'detection_boxes': tf.constant([[[0.0, 0.0, 0.5, 0.5],
+                                           [0.5, 0.5, 0.8, 0.8]],
+                                          [[0.5, 0.5, 1.0, 1.0],
+                                           [0.0, 0.0, 0.0, 0.0]]], tf.float32),
+          'detection_scores': tf.constant([[0.7, 0.6],
+                                           [0.9, 0.0]], tf.float32),
+          'detection_classes': tf.constant([[0, 1],
+                                            [1, 0]], tf.float32),
+          'num_detections': tf.constant([2, 1], tf.float32)
      }
      if self._add_detection_masks:
        postprocessed_tensors['detection_masks'] = tf.constant(
-            np.arange(32).reshape([2, 4, 4]), tf.float32)
+            np.arange(64).reshape([2, 2, 4, 4]), tf.float32)
      return postprocessed_tensors

-  def restore_fn(self, checkpoint_path, from_detection_checkpoint):
+  def restore_map(self, checkpoint_path, from_detection_checkpoint):
    pass

  def loss(self, prediction_dict):
@@ -69,7 +73,7 @@ class ExportInferenceGraphTest(tf.test.TestCase):
    with g.as_default():
      mock_model = FakeModel()
      preprocessed_inputs = mock_model.preprocess(
-          tf.ones([1, 3, 4, 3], tf.float32))
+          tf.placeholder(tf.float32, shape=[None, None, None, 3]))
      predictions = mock_model.predict(preprocessed_inputs)
      mock_model.postprocess(predictions)
      if use_moving_averages:

@@ -103,71 +107,62 @@ class ExportInferenceGraphTest(tf.test.TestCase):
    return example
  def test_export_graph_with_image_tensor_input(self):
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=False)
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel()
-      inference_graph_path = os.path.join(self.get_temp_dir(),
-                                          'exported_graph.pbtxt')
+      output_directory = os.path.join(tmp_dir, 'output')
      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
      pipeline_config.eval_config.use_moving_averages = False
      exporter.export_inference_graph(
          input_type='image_tensor',
          pipeline_config=pipeline_config,
-          checkpoint_path=None,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
  def test_export_graph_with_tf_example_input(self):
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=False)
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel()
-      inference_graph_path = os.path.join(self.get_temp_dir(),
-                                          'exported_graph.pbtxt')
+      output_directory = os.path.join(tmp_dir, 'output')
      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
      pipeline_config.eval_config.use_moving_averages = False
      exporter.export_inference_graph(
          input_type='tf_example',
          pipeline_config=pipeline_config,
-          checkpoint_path=None,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
  def test_export_graph_with_encoded_image_string_input(self):
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=False)
-    with mock.patch.object(
-        model_builder, 'build', autospec=True) as mock_builder:
-      mock_builder.return_value = FakeModel()
-      inference_graph_path = os.path.join(self.get_temp_dir(),
-                                          'exported_graph.pbtxt')
-      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
-      pipeline_config.eval_config.use_moving_averages = False
-      exporter.export_inference_graph(
-          input_type='encoded_image_string_tensor',
-          pipeline_config=pipeline_config,
-          checkpoint_path=None,
-          inference_graph_path=inference_graph_path)
-
-  def test_export_frozen_graph(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
-    self._save_checkpoint_from_mock_model(checkpoint_path,
-                                          use_moving_averages=False)
-    inference_graph_path = os.path.join(self.get_temp_dir(),
-                                        'exported_graph.pb')
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel()
+      output_directory = os.path.join(tmp_dir, 'output')
      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
      pipeline_config.eval_config.use_moving_averages = False
      exporter.export_inference_graph(
-          input_type='image_tensor',
+          input_type='encoded_image_string_tensor',
          pipeline_config=pipeline_config,
-          checkpoint_path=checkpoint_path,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
-  def test_export_frozen_graph_with_moving_averages(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
-    self._save_checkpoint_from_mock_model(checkpoint_path,
+  def test_export_graph_with_moving_averages(self):
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
                                          use_moving_averages=True)
-    inference_graph_path = os.path.join(self.get_temp_dir(),
-                                        'exported_graph.pb')
+    output_directory = os.path.join(tmp_dir, 'output')
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel()

@@ -176,15 +171,17 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      exporter.export_inference_graph(
          input_type='image_tensor',
          pipeline_config=pipeline_config,
-          checkpoint_path=checkpoint_path,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
  def test_export_model_with_all_output_nodes(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
-    self._save_checkpoint_from_mock_model(checkpoint_path,
-                                          use_moving_averages=False)
-    inference_graph_path = os.path.join(self.get_temp_dir(),
-                                        'exported_graph.pb')
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=True)
+    output_directory = os.path.join(tmp_dir, 'output')
+    inference_graph_path = os.path.join(output_directory,
+                                        'frozen_inference_graph.pb')
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel(add_detection_masks=True)

@@ -192,8 +189,8 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      exporter.export_inference_graph(
          input_type='image_tensor',
          pipeline_config=pipeline_config,
-          checkpoint_path=checkpoint_path,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
    inference_graph = self._load_inference_graph(inference_graph_path)
    with self.test_session(graph=inference_graph):
      inference_graph.get_tensor_by_name('image_tensor:0')

@@ -204,11 +201,13 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      inference_graph.get_tensor_by_name('num_detections:0')
  def test_export_model_with_detection_only_nodes(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
-    self._save_checkpoint_from_mock_model(checkpoint_path,
-                                          use_moving_averages=False)
-    inference_graph_path = os.path.join(self.get_temp_dir(),
-                                        'exported_graph.pb')
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=True)
+    output_directory = os.path.join(tmp_dir, 'output')
+    inference_graph_path = os.path.join(output_directory,
+                                        'frozen_inference_graph.pb')
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel(add_detection_masks=False)

@@ -216,8 +215,8 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      exporter.export_inference_graph(
          input_type='image_tensor',
          pipeline_config=pipeline_config,
-          checkpoint_path=checkpoint_path,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
    inference_graph = self._load_inference_graph(inference_graph_path)
    with self.test_session(graph=inference_graph):
      inference_graph.get_tensor_by_name('image_tensor:0')

@@ -229,11 +228,13 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      inference_graph.get_tensor_by_name('detection_masks:0')
  def test_export_and_run_inference_with_image_tensor(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
-    self._save_checkpoint_from_mock_model(checkpoint_path,
-                                          use_moving_averages=False)
-    inference_graph_path = os.path.join(self.get_temp_dir(),
-                                        'exported_graph.pb')
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=True)
+    output_directory = os.path.join(tmp_dir, 'output')
+    inference_graph_path = os.path.join(output_directory,
+                                        'frozen_inference_graph.pb')
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel(add_detection_masks=True)

@@ -242,8 +243,8 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      exporter.export_inference_graph(
          input_type='image_tensor',
          pipeline_config=pipeline_config,
-          checkpoint_path=checkpoint_path,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
    inference_graph = self._load_inference_graph(inference_graph_path)
    with self.test_session(graph=inference_graph) as sess:

@@ -253,15 +254,19 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      classes = inference_graph.get_tensor_by_name('detection_classes:0')
      masks = inference_graph.get_tensor_by_name('detection_masks:0')
      num_detections = inference_graph.get_tensor_by_name('num_detections:0')
-      (boxes, scores, classes, masks, num_detections) = sess.run(
+      (boxes_np, scores_np, classes_np, masks_np, num_detections_np) = sess.run(
          [boxes, scores, classes, masks, num_detections],
-          feed_dict={image_tensor: np.ones((1, 4, 4, 3)).astype(np.uint8)})
-      self.assertAllClose(boxes, [[0.0, 0.0, 0.5, 0.5],
-                                  [0.5, 0.5, 0.8, 0.8]])
-      self.assertAllClose(scores, [[0.7, 0.6]])
-      self.assertAllClose(classes, [[1, 2]])
-      self.assertAllClose(masks, np.arange(32).reshape([2, 4, 4]))
-      self.assertAllClose(num_detections, [2])
+          feed_dict={image_tensor: np.ones((2, 4, 4, 3)).astype(np.uint8)})
+      self.assertAllClose(boxes_np, [[[0.0, 0.0, 0.5, 0.5],
+                                      [0.5, 0.5, 0.8, 0.8]],
+                                     [[0.5, 0.5, 1.0, 1.0],
+                                      [0.0, 0.0, 0.0, 0.0]]])
+      self.assertAllClose(scores_np, [[0.7, 0.6],
+                                      [0.9, 0.0]])
+      self.assertAllClose(classes_np, [[1, 2],
+                                       [2, 1]])
+      self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
+      self.assertAllClose(num_detections_np, [2, 1])
  def _create_encoded_image_string(self, image_array_np, encoding_format):
    od_graph = tf.Graph()

@@ -276,11 +281,13 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      return encoded_string.eval()

  def test_export_and_run_inference_with_encoded_image_string_tensor(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
-    self._save_checkpoint_from_mock_model(checkpoint_path,
-                                          use_moving_averages=False)
-    inference_graph_path = os.path.join(self.get_temp_dir(),
-                                        'exported_graph.pb')
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=True)
+    output_directory = os.path.join(tmp_dir, 'output')
+    inference_graph_path = os.path.join(output_directory,
+                                        'frozen_inference_graph.pb')
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel(add_detection_masks=True)

@@ -289,8 +296,8 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      exporter.export_inference_graph(
          input_type='encoded_image_string_tensor',
          pipeline_config=pipeline_config,
-          checkpoint_path=checkpoint_path,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
    inference_graph = self._load_inference_graph(inference_graph_path)
    jpg_image_str = self._create_encoded_image_string(

@@ -306,23 +313,69 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      masks = inference_graph.get_tensor_by_name('detection_masks:0')
      num_detections = inference_graph.get_tensor_by_name('num_detections:0')
      for image_str in [jpg_image_str, png_image_str]:
+        image_str_batch_np = np.hstack([image_str] * 2)
        (boxes_np, scores_np, classes_np, masks_np,
         num_detections_np) = sess.run(
             [boxes, scores, classes, masks, num_detections],
-            feed_dict={image_str_tensor: image_str})
-        self.assertAllClose(boxes_np, [[0.0, 0.0, 0.5, 0.5],
-                                       [0.5, 0.5, 0.8, 0.8]])
-        self.assertAllClose(scores_np, [[0.7, 0.6]])
-        self.assertAllClose(classes_np, [[1, 2]])
-        self.assertAllClose(masks_np, np.arange(32).reshape([2, 4, 4]))
-        self.assertAllClose(num_detections_np, [2])
+            feed_dict={image_str_tensor: image_str_batch_np})
+        self.assertAllClose(boxes_np, [[[0.0, 0.0, 0.5, 0.5],
+                                        [0.5, 0.5, 0.8, 0.8]],
+                                       [[0.5, 0.5, 1.0, 1.0],
+                                        [0.0, 0.0, 0.0, 0.0]]])
+        self.assertAllClose(scores_np, [[0.7, 0.6],
+                                        [0.9, 0.0]])
+        self.assertAllClose(classes_np, [[1, 2],
+                                         [2, 1]])
+        self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
+        self.assertAllClose(num_detections_np, [2, 1])
+  def test_raise_runtime_error_on_images_with_different_sizes(self):
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=True)
+    output_directory = os.path.join(tmp_dir, 'output')
+    inference_graph_path = os.path.join(output_directory,
+                                        'frozen_inference_graph.pb')
+    with mock.patch.object(
+        model_builder, 'build', autospec=True) as mock_builder:
+      mock_builder.return_value = FakeModel(add_detection_masks=True)
+      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+      pipeline_config.eval_config.use_moving_averages = False
+      exporter.export_inference_graph(
+          input_type='encoded_image_string_tensor',
+          pipeline_config=pipeline_config,
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
+    inference_graph = self._load_inference_graph(inference_graph_path)
+    large_image = self._create_encoded_image_string(
+        np.ones((4, 4, 3)).astype(np.uint8), 'jpg')
+    small_image = self._create_encoded_image_string(
+        np.ones((2, 2, 3)).astype(np.uint8), 'jpg')
+    image_str_batch_np = np.hstack([large_image, small_image])
+    with self.test_session(graph=inference_graph) as sess:
+      image_str_tensor = inference_graph.get_tensor_by_name(
+          'encoded_image_string_tensor:0')
+      boxes = inference_graph.get_tensor_by_name('detection_boxes:0')
+      scores = inference_graph.get_tensor_by_name('detection_scores:0')
+      classes = inference_graph.get_tensor_by_name('detection_classes:0')
+      masks = inference_graph.get_tensor_by_name('detection_masks:0')
+      num_detections = inference_graph.get_tensor_by_name('num_detections:0')
+      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+                                   '^TensorArray has inconsistent shapes.'):
+        sess.run([boxes, scores, classes, masks, num_detections],
+                 feed_dict={image_str_tensor: image_str_batch_np})
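Note: the error this test expects comes from stacking variable-shape elements in a TensorArray; a batch of encoded strings whose decoded images differ in size cannot form one dense tensor. A standalone reproduction of the same failure class, assuming the exporter decodes strings with something like tf.map_fn (this is an illustration, not the exporter's exact code):

import tensorflow as tf

encoded = tf.placeholder(tf.string, shape=[None])
images = tf.map_fn(
    lambda s: tf.image.decode_jpeg(s, channels=3), encoded, dtype=tf.uint8)
# Feeding JPEGs of different sizes raises InvalidArgumentError:
# 'TensorArray has inconsistent shapes.'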
  def test_export_and_run_inference_with_tf_example(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
-    self._save_checkpoint_from_mock_model(checkpoint_path,
-                                          use_moving_averages=False)
-    inference_graph_path = os.path.join(self.get_temp_dir(),
-                                        'exported_graph.pb')
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=True)
+    output_directory = os.path.join(tmp_dir, 'output')
+    inference_graph_path = os.path.join(output_directory,
+                                        'frozen_inference_graph.pb')
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:
      mock_builder.return_value = FakeModel(add_detection_masks=True)

@@ -331,10 +384,12 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      exporter.export_inference_graph(
          input_type='tf_example',
          pipeline_config=pipeline_config,
-          checkpoint_path=checkpoint_path,
-          inference_graph_path=inference_graph_path)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)
    inference_graph = self._load_inference_graph(inference_graph_path)
+    tf_example_np = np.expand_dims(self._create_tf_example(
+        np.ones((4, 4, 3)).astype(np.uint8)), axis=0)
    with self.test_session(graph=inference_graph) as sess:
      tf_example = inference_graph.get_tensor_by_name('tf_example:0')
      boxes = inference_graph.get_tensor_by_name('detection_boxes:0')

@@ -342,23 +397,27 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      classes = inference_graph.get_tensor_by_name('detection_classes:0')
      masks = inference_graph.get_tensor_by_name('detection_masks:0')
      num_detections = inference_graph.get_tensor_by_name('num_detections:0')
-      (boxes, scores, classes, masks, num_detections) = sess.run(
+      (boxes_np, scores_np, classes_np, masks_np, num_detections_np) = sess.run(
          [boxes, scores, classes, masks, num_detections],
-          feed_dict={tf_example: self._create_tf_example(
-              np.ones((4, 4, 3)).astype(np.uint8))})
-      self.assertAllClose(boxes, [[0.0, 0.0, 0.5, 0.5],
-                                  [0.5, 0.5, 0.8, 0.8]])
-      self.assertAllClose(scores, [[0.7, 0.6]])
-      self.assertAllClose(classes, [[1, 2]])
-      self.assertAllClose(masks, np.arange(32).reshape([2, 4, 4]))
-      self.assertAllClose(num_detections, [2])
+          feed_dict={tf_example: tf_example_np})
+      self.assertAllClose(boxes_np, [[[0.0, 0.0, 0.5, 0.5],
+                                      [0.5, 0.5, 0.8, 0.8]],
+                                     [[0.5, 0.5, 1.0, 1.0],
+                                      [0.0, 0.0, 0.0, 0.0]]])
+      self.assertAllClose(scores_np, [[0.7, 0.6],
+                                      [0.9, 0.0]])
+      self.assertAllClose(classes_np, [[1, 2],
+                                       [2, 1]])
+      self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
+      self.assertAllClose(num_detections_np, [2, 1])
  def test_export_saved_model_and_run_inference(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model-ckpt')
-    self._save_checkpoint_from_mock_model(checkpoint_path,
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
                                          use_moving_averages=False)
-    inference_graph_path = os.path.join(self.get_temp_dir(),
-                                        'saved_model')
+    output_directory = os.path.join(tmp_dir, 'output')
+    saved_model_path = os.path.join(output_directory, 'saved_model')
    with mock.patch.object(
        model_builder, 'build', autospec=True) as mock_builder:

@@ -368,30 +427,84 @@ class ExportInferenceGraphTest(tf.test.TestCase):
      exporter.export_inference_graph(
          input_type='tf_example',
          pipeline_config=pipeline_config,
-          checkpoint_path=checkpoint_path,
-          inference_graph_path=inference_graph_path,
-          export_as_saved_model=True)
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)

+    tf_example_np = np.hstack([self._create_tf_example(
+        np.ones((4, 4, 3)).astype(np.uint8))] * 2)
    with tf.Graph().as_default() as od_graph:
      with self.test_session(graph=od_graph) as sess:
        tf.saved_model.loader.load(
-            sess, [tf.saved_model.tag_constants.SERVING], inference_graph_path)
-        tf_example = od_graph.get_tensor_by_name('import/tf_example:0')
-        boxes = od_graph.get_tensor_by_name('import/detection_boxes:0')
-        scores = od_graph.get_tensor_by_name('import/detection_scores:0')
-        classes = od_graph.get_tensor_by_name('import/detection_classes:0')
-        masks = od_graph.get_tensor_by_name('import/detection_masks:0')
-        num_detections = od_graph.get_tensor_by_name('import/num_detections:0')
-        (boxes, scores, classes, masks, num_detections) = sess.run(
-            [boxes, scores, classes, masks, num_detections],
-            feed_dict={tf_example: self._create_tf_example(
-                np.ones((4, 4, 3)).astype(np.uint8))})
-        self.assertAllClose(boxes, [[0.0, 0.0, 0.5, 0.5],
-                                    [0.5, 0.5, 0.8, 0.8]])
-        self.assertAllClose(scores, [[0.7, 0.6]])
-        self.assertAllClose(classes, [[1, 2]])
-        self.assertAllClose(masks, np.arange(32).reshape([2, 4, 4]))
-        self.assertAllClose(num_detections, [2])
+            sess, [tf.saved_model.tag_constants.SERVING], saved_model_path)
+        tf_example = od_graph.get_tensor_by_name('tf_example:0')
+        boxes = od_graph.get_tensor_by_name('detection_boxes:0')
+        scores = od_graph.get_tensor_by_name('detection_scores:0')
+        classes = od_graph.get_tensor_by_name('detection_classes:0')
+        masks = od_graph.get_tensor_by_name('detection_masks:0')
+        num_detections = od_graph.get_tensor_by_name('num_detections:0')
+        (boxes_np, scores_np, classes_np, masks_np,
+         num_detections_np) = sess.run(
+             [boxes, scores, classes, masks, num_detections],
+             feed_dict={tf_example: tf_example_np})
+        self.assertAllClose(boxes_np, [[[0.0, 0.0, 0.5, 0.5],
+                                        [0.5, 0.5, 0.8, 0.8]],
+                                       [[0.5, 0.5, 1.0, 1.0],
+                                        [0.0, 0.0, 0.0, 0.0]]])
+        self.assertAllClose(scores_np, [[0.7, 0.6],
+                                        [0.9, 0.0]])
+        self.assertAllClose(classes_np, [[1, 2],
+                                         [2, 1]])
+        self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
+        self.assertAllClose(num_detections_np, [2, 1])

+  def test_export_checkpoint_and_run_inference(self):
+    tmp_dir = self.get_temp_dir()
+    trained_checkpoint_prefix = os.path.join(tmp_dir, 'model.ckpt')
+    self._save_checkpoint_from_mock_model(trained_checkpoint_prefix,
+                                          use_moving_averages=False)
+    output_directory = os.path.join(tmp_dir, 'output')
+    model_path = os.path.join(output_directory, 'model.ckpt')
+    meta_graph_path = model_path + '.meta'
+    with mock.patch.object(
+        model_builder, 'build', autospec=True) as mock_builder:
+      mock_builder.return_value = FakeModel(add_detection_masks=True)
+      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
+      pipeline_config.eval_config.use_moving_averages = False
+      exporter.export_inference_graph(
+          input_type='tf_example',
+          pipeline_config=pipeline_config,
+          trained_checkpoint_prefix=trained_checkpoint_prefix,
+          output_directory=output_directory)

+    tf_example_np = np.hstack([self._create_tf_example(
+        np.ones((4, 4, 3)).astype(np.uint8))] * 2)
+    with tf.Graph().as_default() as od_graph:
+      with self.test_session(graph=od_graph) as sess:
+        new_saver = tf.train.import_meta_graph(meta_graph_path)
+        new_saver.restore(sess, model_path)
+        tf_example = od_graph.get_tensor_by_name('tf_example:0')
+        boxes = od_graph.get_tensor_by_name('detection_boxes:0')
+        scores = od_graph.get_tensor_by_name('detection_scores:0')
+        classes = od_graph.get_tensor_by_name('detection_classes:0')
+        masks = od_graph.get_tensor_by_name('detection_masks:0')
+        num_detections = od_graph.get_tensor_by_name('num_detections:0')
+        (boxes_np, scores_np, classes_np, masks_np,
+         num_detections_np) = sess.run(
+             [boxes, scores, classes, masks, num_detections],
+             feed_dict={tf_example: tf_example_np})
+        self.assertAllClose(boxes_np, [[[0.0, 0.0, 0.5, 0.5],
+                                        [0.5, 0.5, 0.8, 0.8]],
+                                       [[0.5, 0.5, 1.0, 1.0],
+                                        [0.0, 0.0, 0.0, 0.0]]])
+        self.assertAllClose(scores_np, [[0.7, 0.6],
+                                        [0.9, 0.0]])
+        self.assertAllClose(classes_np, [[1, 2],
+                                         [2, 1]])
+        self.assertAllClose(masks_np, np.arange(64).reshape([2, 2, 4, 4]))
+        self.assertAllClose(num_detections_np, [2, 1])
if __name__ == '__main__':
  tf.test.main()
@@ -13,12 +13,11 @@ py_library(
    srcs = ["ssd_meta_arch.py"],
    deps = [
        "//tensorflow",
-        "//tensorflow_models/object_detection/core:box_coder",
        "//tensorflow_models/object_detection/core:box_list",
        "//tensorflow_models/object_detection/core:box_predictor",
        "//tensorflow_models/object_detection/core:model",
        "//tensorflow_models/object_detection/core:target_assigner",
-        "//tensorflow_models/object_detection/utils:variables_helper",
+        "//tensorflow_models/object_detection/utils:shape_utils",
    ],
)

@@ -56,7 +55,7 @@ py_library(
        "//tensorflow_models/object_detection/core:standard_fields",
        "//tensorflow_models/object_detection/core:target_assigner",
        "//tensorflow_models/object_detection/utils:ops",
-        "//tensorflow_models/object_detection/utils:variables_helper",
+        "//tensorflow_models/object_detection/utils:shape_utils",
    ],
)
@@ -80,7 +80,7 @@ from object_detection.core import post_processing
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import ops
-from object_detection.utils import variables_helper
+from object_detection.utils import shape_utils

slim = tf.contrib.slim
@@ -159,21 +159,19 @@ class FasterRCNNFeatureExtractor(object):

  def restore_from_classification_checkpoint_fn(
      self,
-      checkpoint_path,
      first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
-    """Returns callable for loading a checkpoint into the tensorflow graph.
+    """Returns a map of variables to load from a foreign checkpoint.

    Args:
-      checkpoint_path: path to checkpoint to restore.
      first_stage_feature_extractor_scope: A scope name for the first stage
        feature extractor.
      second_stage_feature_extractor_scope: A scope name for the second stage
        feature extractor.

    Returns:
-      a callable which takes a tf.Session as input and loads a checkpoint when
-        run.
+      A dict mapping variable names (to load from a checkpoint) to variables in
+        the model graph.
    """
    variables_to_restore = {}
    for variable in tf.global_variables():

@@ -182,13 +180,7 @@ class FasterRCNNFeatureExtractor(object):
      if variable.op.name.startswith(scope_name):
        var_name = variable.op.name.replace(scope_name + '/', '')
        variables_to_restore[var_name] = variable
-    variables_to_restore = (
-        variables_helper.get_variables_available_in_checkpoint(
-            variables_to_restore, checkpoint_path))
-    saver = tf.train.Saver(variables_to_restore)
-    def restore(sess):
-      saver.restore(sess, checkpoint_path)
-    return restore
+    return variables_to_restore
class FasterRCNNMetaArch(model.DetectionModel):

@@ -774,10 +766,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
      A float tensor with shape [A * B, ..., depth] (where the first and last
        dimension are statically defined.
    """
-    inputs_shape = inputs.get_shape().as_list()
-    flattened_shape = tf.concat([
-        [inputs_shape[0]*inputs_shape[1]], tf.shape(inputs)[2:-1],
-        [inputs_shape[-1]]], 0)
+    combined_shape = shape_utils.combined_static_and_dynamic_shape(inputs)
+    flattened_shape = tf.stack([combined_shape[0] * combined_shape[1]] +
+                               combined_shape[2:])
    return tf.reshape(inputs, flattened_shape)
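Note: shape_utils.combined_static_and_dynamic_shape returns a Python list that mixes static ints with dynamic scalar tensors; conceptually it is roughly the following sketch:

import tensorflow as tf

def combined_static_and_dynamic_shape(tensor):
  # Prefer the statically known dimension; fall back to the runtime
  # tf.shape() entry where the static dimension is None.
  static_shape = tensor.get_shape().as_list()
  dynamic_shape = tf.shape(tensor)
  return [static_dim if static_dim is not None else dynamic_shape[index]
          for index, static_dim in enumerate(static_shape)]

This is what lets the reshape above preserve every statically known dimension while tolerating an unknown batch size.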
  def postprocess(self, prediction_dict):

@@ -875,52 +866,128 @@ class FasterRCNNMetaArch(model.DetectionModel):
        representing the number of proposals predicted for each image in
        the batch.
    """
+    rpn_box_encodings_batch = tf.expand_dims(rpn_box_encodings_batch, axis=2)
+    rpn_encodings_shape = shape_utils.combined_static_and_dynamic_shape(
+        rpn_box_encodings_batch)
+    tiled_anchor_boxes = tf.tile(
+        tf.expand_dims(anchors, 0), [rpn_encodings_shape[0], 1, 1])
+    proposal_boxes = self._batch_decode_boxes(rpn_box_encodings_batch,
+                                              tiled_anchor_boxes)
+    proposal_boxes = tf.squeeze(proposal_boxes, axis=2)
+    rpn_objectness_softmax_without_background = tf.nn.softmax(
+        rpn_objectness_predictions_with_background_batch)[:, :, 1]
    clip_window = tf.to_float(tf.stack([0, 0, image_shape[1], image_shape[2]]))
-    if self._is_training:
-      (groundtruth_boxlists, groundtruth_classes_with_background_list
-      ) = self._format_groundtruth_data(image_shape)
-
-    proposal_boxes_list = []
-    proposal_scores_list = []
-    num_proposals_list = []
-    for (batch_index,
-         (rpn_box_encodings,
-          rpn_objectness_predictions_with_background)) in enumerate(zip(
-              tf.unstack(rpn_box_encodings_batch),
-              tf.unstack(rpn_objectness_predictions_with_background_batch))):
-      decoded_boxes = self._box_coder.decode(
-          rpn_box_encodings, box_list.BoxList(anchors))
-      objectness_scores = tf.unstack(
-          tf.nn.softmax(rpn_objectness_predictions_with_background), axis=1)[1]
-      proposal_boxlist = post_processing.multiclass_non_max_suppression(
-          tf.expand_dims(decoded_boxes.get(), 1),
-          tf.expand_dims(objectness_scores, 1),
-          self._first_stage_nms_score_threshold,
-          self._first_stage_nms_iou_threshold, self._first_stage_max_proposals,
-          clip_window=clip_window)
-      if self._is_training:
-        proposal_boxlist.set(tf.stop_gradient(proposal_boxlist.get()))
-        if not self._hard_example_miner:
-          proposal_boxlist = self._sample_box_classifier_minibatch(
-              proposal_boxlist, groundtruth_boxlists[batch_index],
-              groundtruth_classes_with_background_list[batch_index])
-
-      normalized_proposals = box_list_ops.to_normalized_coordinates(
-          proposal_boxlist, image_shape[1], image_shape[2],
-          check_range=False)
-
-      # pad proposals to max_num_proposals
-      padded_proposals = box_list_ops.pad_or_clip_box_list(
-          normalized_proposals, num_boxes=self.max_num_proposals)
-      proposal_boxes_list.append(padded_proposals.get())
-      proposal_scores_list.append(
-          padded_proposals.get_field(fields.BoxListFields.scores))
-      num_proposals_list.append(tf.minimum(normalized_proposals.num_boxes(),
-                                           self.max_num_proposals))
-
-    return (tf.stack(proposal_boxes_list), tf.stack(proposal_scores_list),
-            tf.stack(num_proposals_list))
+    (proposal_boxes, proposal_scores, _, _,
+     num_proposals) = post_processing.batch_multiclass_non_max_suppression(
+         tf.expand_dims(proposal_boxes, axis=2),
+         tf.expand_dims(rpn_objectness_softmax_without_background,
+                        axis=2),
+         self._first_stage_nms_score_threshold,
+         self._first_stage_nms_iou_threshold,
+         self._first_stage_max_proposals,
+         self._first_stage_max_proposals,
+         clip_window=clip_window)
+    if self._is_training:
+      proposal_boxes = tf.stop_gradient(proposal_boxes)
+      if not self._hard_example_miner:
+        (groundtruth_boxlists, groundtruth_classes_with_background_list,
+        ) = self._format_groundtruth_data(image_shape)
+        (proposal_boxes, proposal_scores,
+         num_proposals) = self._unpad_proposals_and_sample_box_classifier_batch(
+             proposal_boxes, proposal_scores, num_proposals,
+             groundtruth_boxlists, groundtruth_classes_with_background_list)
+    # normalize proposal boxes
+    proposal_boxes_reshaped = tf.reshape(proposal_boxes, [-1, 4])
+    normalized_proposal_boxes_reshaped = box_list_ops.to_normalized_coordinates(
+        box_list.BoxList(proposal_boxes_reshaped),
+        image_shape[1], image_shape[2], check_range=False).get()
+    proposal_boxes = tf.reshape(normalized_proposal_boxes_reshaped,
+                                [-1, proposal_boxes.shape[1].value, 4])
+    return proposal_boxes, proposal_scores, num_proposals
+
+  def _unpad_proposals_and_sample_box_classifier_batch(
+      self,
+      proposal_boxes,
+      proposal_scores,
+      num_proposals,
+      groundtruth_boxlists,
+      groundtruth_classes_with_background_list):
+    """Unpads proposals and samples a minibatch for second stage.
+
+    Args:
+      proposal_boxes: A float tensor with shape
+        [batch_size, num_proposals, 4] representing the (potentially zero
+        padded) proposal boxes for all images in the batch.  These boxes are
+        represented as normalized coordinates.
+      proposal_scores: A float tensor with shape
+        [batch_size, num_proposals] representing the (potentially zero
+        padded) proposal objectness scores for all images in the batch.
+      num_proposals: A Tensor of type `int32`. A 1-D tensor of shape [batch]
+        representing the number of proposals predicted for each image in
+        the batch.
+      groundtruth_boxlists: A list of BoxLists containing (absolute) coordinates
+        of the groundtruth boxes.
+      groundtruth_classes_with_background_list: A list of 2-D one-hot
+        (or k-hot) tensors of shape [num_boxes, num_classes+1] containing the
+        class targets with the 0th index assumed to map to the background class.
+
+    Returns:
+      proposal_boxes: A float tensor with shape
+        [batch_size, second_stage_batch_size, 4] representing the (potentially
+        zero padded) proposal boxes for all images in the batch.  These boxes
+        are represented as normalized coordinates.
+      proposal_scores: A float tensor with shape
+        [batch_size, second_stage_batch_size] representing the (potentially zero
+        padded) proposal objectness scores for all images in the batch.
+      num_proposals: A Tensor of type `int32`. A 1-D tensor of shape [batch]
+        representing the number of proposals predicted for each image in
+        the batch.
+    """
+    single_image_proposal_box_sample = []
+    single_image_proposal_score_sample = []
+    single_image_num_proposals_sample = []
+    for (single_image_proposal_boxes,
+         single_image_proposal_scores,
+         single_image_num_proposals,
+         single_image_groundtruth_boxlist,
+         single_image_groundtruth_classes_with_background) in zip(
+             tf.unstack(proposal_boxes),
+             tf.unstack(proposal_scores),
+             tf.unstack(num_proposals),
+             groundtruth_boxlists,
+             groundtruth_classes_with_background_list):
+      static_shape = single_image_proposal_boxes.get_shape()
+      sliced_static_shape = tf.TensorShape([tf.Dimension(None),
+                                            static_shape.dims[-1]])
+      single_image_proposal_boxes = tf.slice(
+          single_image_proposal_boxes,
+          [0, 0],
+          [single_image_num_proposals, -1])
+      single_image_proposal_boxes.set_shape(sliced_static_shape)
+      single_image_proposal_scores = tf.slice(single_image_proposal_scores,
+                                              [0],
+                                              [single_image_num_proposals])
+      single_image_boxlist = box_list.BoxList(single_image_proposal_boxes)
+      single_image_boxlist.add_field(fields.BoxListFields.scores,
+                                     single_image_proposal_scores)
+      sampled_boxlist = self._sample_box_classifier_minibatch(
+          single_image_boxlist,
+          single_image_groundtruth_boxlist,
+          single_image_groundtruth_classes_with_background)
+      sampled_padded_boxlist = box_list_ops.pad_or_clip_box_list(
+          sampled_boxlist,
+          num_boxes=self._second_stage_batch_size)
+      single_image_num_proposals_sample.append(tf.minimum(
+          sampled_boxlist.num_boxes(),
+          self._second_stage_batch_size))
+      bb = sampled_padded_boxlist.get()
+      single_image_proposal_box_sample.append(bb)
+      single_image_proposal_score_sample.append(
+          sampled_padded_boxlist.get_field(fields.BoxListFields.scores))
+    return (tf.stack(single_image_proposal_box_sample),
+            tf.stack(single_image_proposal_score_sample),
+            tf.stack(single_image_num_proposals_sample))
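Note: the unpad step hinges on a dynamic tf.slice followed by set_shape. In isolation (tensor names are illustrative):

import tensorflow as tf

padded_boxes = tf.placeholder(tf.float32, shape=[8, 4])  # zero-padded proposals
num_valid = tf.placeholder(tf.int32, shape=[])            # e.g. 3 at run time
valid_boxes = tf.slice(padded_boxes, [0, 0], [num_valid, -1])
# A tensor-valued slice size makes the static shape fully unknown, so
# reassert what is still known statically: the box coordinate dimension.
valid_boxes.set_shape([None, 4])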
  def _format_groundtruth_data(self, image_shape):
    """Helper function for preparing groundtruth data for target assignment.

@@ -1074,7 +1141,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
        class_predictions_with_background,
        [-1, self.max_num_proposals, self.num_classes + 1]
    )
-    refined_decoded_boxes_batch = self._batch_decode_refined_boxes(
+    refined_decoded_boxes_batch = self._batch_decode_boxes(
        refined_box_encodings_batch, proposal_boxes)
    class_predictions_with_background_batch = (
        self._second_stage_score_conversion_fn(
@@ -1092,19 +1159,26 @@ class FasterRCNNMetaArch(model.DetectionModel):
      mask_predictions_batch = tf.reshape(
          mask_predictions, [-1, self.max_num_proposals,
                             self.num_classes, mask_height, mask_width])
-    detections = self._second_stage_nms_fn(
+    (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
+     num_detections) = self._second_stage_nms_fn(
        refined_decoded_boxes_batch,
        class_predictions_batch,
        clip_window=clip_window,
        change_coordinate_frame=True,
        num_valid_boxes=num_proposals,
        masks=mask_predictions_batch)
+    detections = {'detection_boxes': nmsed_boxes,
+                  'detection_scores': nmsed_scores,
+                  'detection_classes': nmsed_classes,
+                  'num_detections': tf.to_float(num_detections)}
+    if nmsed_masks is not None:
+      detections['detection_masks'] = nmsed_masks
    if mask_predictions is not None:
      detections['detection_masks'] = tf.to_float(
          tf.greater_equal(detections['detection_masks'], mask_threshold))
    return detections
-  def _batch_decode_refined_boxes(self, refined_box_encodings, proposal_boxes):
+  def _batch_decode_boxes(self, box_encodings, anchor_boxes):
    """Decode tensor of refined box encodings.

    Args:

@@ -1119,15 +1193,33 @@ class FasterRCNNMetaArch(model.DetectionModel):
        float tensor representing (padded) refined bounding box predictions
        (for each image in batch, proposal and class).
    """
-    tiled_proposal_boxes = tf.tile(
-        tf.expand_dims(proposal_boxes, 2), [1, 1, self.num_classes, 1])
-    tiled_proposals_boxlist = box_list.BoxList(
-        tf.reshape(tiled_proposal_boxes, [-1, 4]))
+    """Decodes box encodings with respect to the anchor boxes.
+
+    Args:
+      box_encodings: a 4-D tensor with shape
+        [batch_size, num_anchors, num_classes, self._box_coder.code_size]
+        representing box encodings.
+      anchor_boxes: [batch_size, num_anchors, 4] representing
+        decoded bounding boxes.
+
+    Returns:
+      decoded_boxes: a [batch_size, num_anchors, num_classes, 4]
+        float tensor representing bounding box predictions
+        (for each image in batch, proposal and class).
+    """
+    combined_shape = shape_utils.combined_static_and_dynamic_shape(
+        box_encodings)
+    num_classes = combined_shape[2]
+    tiled_anchor_boxes = tf.tile(
+        tf.expand_dims(anchor_boxes, 2), [1, 1, num_classes, 1])
+    tiled_anchors_boxlist = box_list.BoxList(
+        tf.reshape(tiled_anchor_boxes, [-1, 4]))
    decoded_boxes = self._box_coder.decode(
-        tf.reshape(refined_box_encodings, [-1, self._box_coder.code_size]),
-        tiled_proposals_boxlist)
+        tf.reshape(box_encodings, [-1, self._box_coder.code_size]),
+        tiled_anchors_boxlist)
    return tf.reshape(decoded_boxes.get(),
-                      [-1, self.max_num_proposals, self.num_classes, 4])
+                      tf.stack([combined_shape[0], combined_shape[1],
+                                num_classes, 4]))
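Note: the tile-then-decode pattern above generalizes the old refined-box decoding to arbitrary anchor batches. A standalone sketch of the same pattern using the library's FasterRcnnBoxCoder (shapes chosen purely for illustration):

import tensorflow as tf
from object_detection.box_coders import faster_rcnn_box_coder
from object_detection.core import box_list

coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
# batch=1, anchors=2, classes=3; all-zero codes decode back to the anchors.
box_encodings = tf.zeros([1, 2, 3, coder.code_size])
anchors = tf.constant([[0., 0., 1., 1.], [.5, .5, 1., 1.]])
tiled_anchors = tf.tile(
    tf.expand_dims(tf.expand_dims(anchors, 0), 2), [1, 1, 3, 1])
decoded = coder.decode(
    tf.reshape(box_encodings, [-1, coder.code_size]),
    box_list.BoxList(tf.reshape(tiled_anchors, [-1, 4])))
decoded_boxes = tf.reshape(decoded.get(), [1, 2, 3, 4])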
  def loss(self, prediction_dict, scope=None):
    """Compute scalar loss tensors given prediction tensors.

@@ -1413,25 +1505,22 @@ class FasterRCNNMetaArch(model.DetectionModel):
          cls_losses=tf.expand_dims(single_image_cls_loss, 0),
          decoded_boxlist_list=[proposal_boxlist])

-  def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
-    """Returns callable for loading a checkpoint into the tensorflow graph.
+  def restore_map(self, from_detection_checkpoint=True):
+    """Returns a map of variables to load from a foreign checkpoint.
+
+    See parent class for details.

    Args:
-      checkpoint_path: path to checkpoint to restore.
-      from_detection_checkpoint: whether to restore from a detection checkpoint
-        (with compatible variable names) or to restore from a classification
-        checkpoint for initialization prior to training. Note that when
-        from_detection_checkpoint=True, the current implementation only
-        supports restoration from an (exactly) identical model (with exception
-        of the num_classes parameter).
+      from_detection_checkpoint: whether to restore from a full detection
+        checkpoint (with compatible variable names) or to restore from a
+        classification checkpoint for initialization prior to training.

    Returns:
-      a callable which takes a tf.Session as input and loads a checkpoint when
-        run.
+      A dict mapping variable names (to load from a checkpoint) to variables in
+        the model graph.
    """
    if not from_detection_checkpoint:
      return self._feature_extractor.restore_from_classification_checkpoint_fn(
-          checkpoint_path,
          self.first_stage_feature_extractor_scope,
          self.second_stage_feature_extractor_scope)
@@ -1439,13 +1528,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
    variables_to_restore.append(slim.get_or_create_global_step())
    # Only load feature extractor variables to be consistent with loading from
    # a classification checkpoint.
-    first_stage_variables = tf.contrib.framework.filter_variables(
+    feature_extractor_variables = tf.contrib.framework.filter_variables(
        variables_to_restore,
        include_patterns=[self.first_stage_feature_extractor_scope,
                          self.second_stage_feature_extractor_scope])
-
-    saver = tf.train.Saver(first_stage_variables)
-    def restore(sess):
-      saver.restore(sess, checkpoint_path)
-    return restore
+    return {var.op.name: var for var in feature_extractor_variables}
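Note: restore_map replaces the old restore_fn, moving session handling to the caller; the returned dict is exactly the var_list form tf.train.Saver accepts. A sketch of how a training script might consume it (checkpoint path illustrative):

import tensorflow as tf

var_map = detection_model.restore_map(from_detection_checkpoint=True)
init_saver = tf.train.Saver(var_map)

def init_fn(sess):
  init_saver.restore(sess, '/tmp/pretrained/model.ckpt')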
...@@ -226,61 +226,47 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -226,61 +226,47 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
return self._get_model(self._get_second_stage_box_predictor( return self._get_model(self._get_second_stage_box_predictor(
num_classes=num_classes, is_training=is_training), **common_kwargs) num_classes=num_classes, is_training=is_training), **common_kwargs)
  def test_predict_correct_shapes_in_inference_mode_both_stages(self):
    batch_size = 2
    image_size = 10
    input_shapes = [(batch_size, image_size, image_size, 3),
                    (None, image_size, image_size, 3),
                    (batch_size, None, None, 3),
                    (None, None, None, 3)]
    # Since MockFasterRCNN.extract_proposal_features returns a tensor with the
    # same shape as its input, the expected number of anchors is
    # image_size * image_size * the number of anchors per location (i.e. 3x3).
    expected_num_anchors = image_size * image_size * 3 * 3
    expected_shapes = {
        'rpn_box_predictor_features':
            (2, image_size, image_size, 512),
        'rpn_features_to_crop': (2, image_size, image_size, 3),
        'image_shape': (4,),
        'rpn_box_encodings': (2, expected_num_anchors, 4),
        'rpn_objectness_predictions_with_background':
            (2, expected_num_anchors, 2),
        'anchors': (expected_num_anchors, 4),
        'refined_box_encodings': (2 * 8, 2, 4),
        'class_predictions_with_background': (2 * 8, 2 + 1),
        'num_proposals': (2,),
        'proposal_boxes': (2, 8, 4),
    }
    for input_shape in input_shapes:
      test_graph = tf.Graph()
      with test_graph.as_default():
        model = self._build_model(
            is_training=False, first_stage_only=False,
            second_stage_batch_size=2)
        preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape)
        result_tensor_dict = model.predict(preprocessed_inputs)
        init_op = tf.global_variables_initializer()
      with self.test_session(graph=test_graph) as sess:
        sess.run(init_op)
        tensor_dict_out = sess.run(result_tensor_dict, feed_dict={
            preprocessed_inputs:
                np.zeros((batch_size, image_size, image_size, 3))})
      self.assertEqual(set(tensor_dict_out.keys()),
                       set(expected_shapes.keys()))
      for key in expected_shapes:
        self.assertAllEqual(tensor_dict_out[key].shape, expected_shapes[key])
  def test_predict_gives_valid_anchors_in_training_mode_first_stage_only(self):
    test_graph = tf.Graph()
@@ -535,35 +521,67 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
                        expected_num_proposals)
  def test_postprocess_second_stage_only_inference_mode(self):
    num_proposals_shapes = [(2), (None)]
    refined_box_encodings_shapes = [(16, 2, 4), (None, 2, 4)]
    class_predictions_with_background_shapes = [(16, 3), (None, 3)]
    proposal_boxes_shapes = [(2, 8, 4), (None, 8, 4)]
    batch_size = 2
    image_shape = np.array((2, 36, 48, 3), dtype=np.int32)
    for (num_proposals_shape, refined_box_encoding_shape,
         class_predictions_with_background_shape,
         proposal_boxes_shape) in zip(num_proposals_shapes,
                                      refined_box_encodings_shapes,
                                      class_predictions_with_background_shapes,
                                      proposal_boxes_shapes):
      tf_graph = tf.Graph()
      with tf_graph.as_default():
        model = self._build_model(
            is_training=False, first_stage_only=False,
            second_stage_batch_size=6)
        total_num_padded_proposals = batch_size * model.max_num_proposals
        proposal_boxes = np.array(
            [[[1, 1, 2, 3],
              [0, 0, 1, 1],
              [.5, .5, .6, .6],
              4*[0], 4*[0], 4*[0], 4*[0], 4*[0]],
             [[2, 3, 6, 8],
              [1, 2, 5, 3],
              4*[0], 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]]])
        num_proposals = np.array([3, 2], dtype=np.int32)
        refined_box_encodings = np.zeros(
            [total_num_padded_proposals, model.num_classes, 4])
        class_predictions_with_background = np.ones(
            [total_num_padded_proposals, model.num_classes+1])

        num_proposals_placeholder = tf.placeholder(tf.int32,
                                                   shape=num_proposals_shape)
        refined_box_encodings_placeholder = tf.placeholder(
            tf.float32, shape=refined_box_encoding_shape)
        class_predictions_with_background_placeholder = tf.placeholder(
            tf.float32, shape=class_predictions_with_background_shape)
        proposal_boxes_placeholder = tf.placeholder(
            tf.float32, shape=proposal_boxes_shape)
        image_shape_placeholder = tf.placeholder(tf.int32, shape=(4))

        detections = model.postprocess({
            'refined_box_encodings': refined_box_encodings_placeholder,
            'class_predictions_with_background':
                class_predictions_with_background_placeholder,
            'num_proposals': num_proposals_placeholder,
            'proposal_boxes': proposal_boxes_placeholder,
            'image_shape': image_shape_placeholder,
        })
      with self.test_session(graph=tf_graph) as sess:
        detections_out = sess.run(
            detections,
            feed_dict={
                refined_box_encodings_placeholder: refined_box_encodings,
                class_predictions_with_background_placeholder:
                    class_predictions_with_background,
                num_proposals_placeholder: num_proposals,
                proposal_boxes_placeholder: proposal_boxes,
                image_shape_placeholder: image_shape
            })
      self.assertAllEqual(detections_out['detection_boxes'].shape, [2, 5, 4])
      self.assertAllClose(detections_out['detection_scores'],
                          [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
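For reference, the static placeholder shapes in the test above follow from
batch_size = 2 and a test model whose max_num_proposals is 8 (implied by the
(2, 8, 4) proposal_boxes shape): total_num_padded_proposals = 2 * 8 = 16, so
refined_box_encodings is (16, num_classes, 4) = (16, 2, 4) and
class_predictions_with_background is (16, num_classes + 1) = (16, 3).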
@@ -571,6 +589,17 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
                        [[0, 0, 0, 1, 1], [0, 0, 1, 1, 0]])
    self.assertAllClose(detections_out['num_detections'], [5, 4])

  def test_preprocess_preserves_input_shapes(self):
    image_shapes = [(3, None, None, 3),
                    (None, 10, 10, 3),
                    (None, None, None, 3)]
    for image_shape in image_shapes:
      model = self._build_model(
          is_training=False, first_stage_only=False, second_stage_batch_size=6)
      image_placeholder = tf.placeholder(tf.float32, shape=image_shape)
      preprocessed_inputs = model.preprocess(image_placeholder)
      self.assertAllEqual(preprocessed_inputs.shape.as_list(), image_shape)
  def test_loss_first_stage_only_mode(self):
    model = self._build_model(
        is_training=True, first_stage_only=True, second_stage_batch_size=6)
@@ -957,7 +986,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
                        exp_loc_loss)
    self.assertAllClose(loss_dict_out['second_stage_classification_loss'], 0)
  def test_restore_map_for_classification_ckpt(self):
    # Define mock tensorflow classification graph and save variables.
    test_graph_classification = tf.Graph()
    with test_graph_classification.as_default():
@@ -986,12 +1015,17 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
      preprocessed_inputs = model.preprocess(inputs)
      prediction_dict = model.predict(preprocessed_inputs)
      model.postprocess(prediction_dict)
      var_map = model.restore_map(from_detection_checkpoint=False)
      self.assertIsInstance(var_map, dict)
      saver = tf.train.Saver(var_map)
      with self.test_session() as sess:
        saver.restore(sess, saved_model_path)
        for var in sess.run(tf.report_uninitialized_variables()):
          self.assertNotIn(model.first_stage_feature_extractor_scope, var.name)
          self.assertNotIn(model.second_stage_feature_extractor_scope,
                           var.name)
  def test_restore_map_for_detection_ckpt(self):
    # Define first detection graph and save variables.
    test_graph_detection1 = tf.Graph()
    with test_graph_detection1.as_default():
@@ -1022,10 +1056,11 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
      preprocessed_inputs2 = model2.preprocess(inputs2)
      prediction_dict2 = model2.predict(preprocessed_inputs2)
      model2.postprocess(prediction_dict2)
      var_map = model2.restore_map(from_detection_checkpoint=True)
      self.assertIsInstance(var_map, dict)
      saver = tf.train.Saver(var_map)
      with self.test_session() as sess:
        saver.restore(sess, saved_model_path)
        for var in sess.run(tf.report_uninitialized_variables()):
          self.assertNotIn(model2.first_stage_feature_extractor_scope,
                           var.name)
          self.assertNotIn(model2.second_stage_feature_extractor_scope,
...
@@ -23,13 +23,12 @@ from abc import abstractmethod
import re

import tensorflow as tf

from object_detection.core import box_list
from object_detection.core import box_predictor as bpredictor
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import shape_utils

slim = tf.contrib.slim
@@ -324,7 +323,8 @@ class SSDMetaArch(model.DetectionModel):
      a list of pairs (height, width) for each feature map in feature_maps
    """
    feature_map_shapes = [
        shape_utils.combined_static_and_dynamic_shape(
            feature_map) for feature_map in feature_maps
    ]
    return [(shape[1], shape[2]) for shape in feature_map_shapes]
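For context, a sketch of the shape_utils helper introduced here, assuming the
conventional implementation (the actual utility may differ in detail): it
returns a list that uses static dimensions where they are known and falls back
to dynamic tf.shape() slices where they are not.

    import tensorflow as tf

    def combined_static_and_dynamic_shape(tensor):
      """Returns a list mixing static dims (ints) and dynamic dims (tensors)."""
      static_shape = tensor.shape.as_list()
      dynamic_shape = tf.shape(tensor)
      return [static_dim if static_dim is not None else dynamic_shape[index]
              for index, static_dim in enumerate(static_shape)]

This is what lets the reshapes in this patch work for fully-defined and
partially-defined input shapes alike.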
@@ -365,8 +365,7 @@ class SSDMetaArch(model.DetectionModel):
    with tf.name_scope('Postprocessor'):
      box_encodings = prediction_dict['box_encodings']
      class_predictions = prediction_dict['class_predictions_with_background']
      detection_boxes = self._batch_decode(box_encodings)
      detection_boxes = tf.expand_dims(detection_boxes, axis=2)

      class_predictions_without_background = tf.slice(class_predictions,
@@ -375,10 +374,14 @@ class SSDMetaArch(model.DetectionModel):
      detection_scores = self._score_conversion_fn(
          class_predictions_without_background)
      clip_window = tf.constant([0, 0, 1, 1], tf.float32)
      (nmsed_boxes, nmsed_scores, nmsed_classes, _,
       num_detections) = self._non_max_suppression_fn(detection_boxes,
                                                      detection_scores,
                                                      clip_window=clip_window)
      return {'detection_boxes': nmsed_boxes,
              'detection_scores': nmsed_scores,
              'detection_classes': nmsed_classes,
              'num_detections': tf.to_float(num_detections)}
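A sketch of how a caller consumes the new dictionary-valued postprocess output
(keys come from the return statement above; the model construction and batch
dimensions are assumed):

    detections = model.postprocess(model.predict(preprocessed_inputs))
    boxes = detections['detection_boxes']      # [batch, max_detections, 4]
    scores = detections['detection_scores']    # [batch, max_detections]
    classes = detections['detection_classes']  # [batch, max_detections]
    num = detections['num_detections']         # [batch], float32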
  def loss(self, prediction_dict, scope=None):
    """Compute scalar loss tensors with respect to provided groundtruth.
@@ -546,8 +549,7 @@ class SSDMetaArch(model.DetectionModel):
          tf.slice(prediction_dict['class_predictions_with_background'],
                   [0, 0, 1], class_pred_shape), class_pred_shape)
      decoded_boxes = self._batch_decode(prediction_dict['box_encodings'])
      decoded_box_tensors_list = tf.unstack(decoded_boxes)
      class_prediction_list = tf.unstack(class_predictions)
      decoded_boxlist_list = []
@@ -562,33 +564,51 @@ class SSDMetaArch(model.DetectionModel):
          decoded_boxlist_list=decoded_boxlist_list,
          match_list=match_list)

  def _batch_decode(self, box_encodings):
    """Decodes a batch of box encodings with respect to the anchors.

    Args:
      box_encodings: A float32 tensor of shape
        [batch_size, num_anchors, box_code_size] containing box encodings.

    Returns:
      decoded_boxes: A float32 tensor of shape
        [batch_size, num_anchors, 4] containing the decoded boxes.
    """
    combined_shape = shape_utils.combined_static_and_dynamic_shape(
        box_encodings)
    batch_size = combined_shape[0]
    tiled_anchor_boxes = tf.tile(
        tf.expand_dims(self.anchors.get(), 0), [batch_size, 1, 1])
    tiled_anchors_boxlist = box_list.BoxList(
        tf.reshape(tiled_anchor_boxes, [-1, self._box_coder.code_size]))
    decoded_boxes = self._box_coder.decode(
        tf.reshape(box_encodings, [-1, self._box_coder.code_size]),
        tiled_anchors_boxlist)
    return tf.reshape(decoded_boxes.get(),
                      tf.stack([combined_shape[0], combined_shape[1], 4]))
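A standalone sketch of the tiling trick _batch_decode relies on (illustrative
sizes; the box coder call itself is elided): anchors shared by every image are
tiled along a new batch dimension, flattened so the coder sees one box per row,
then reshaped back to [batch, anchors, 4].

    import tensorflow as tf

    batch_size, num_anchors, code_size = 2, 3, 4
    box_encodings = tf.zeros([batch_size, num_anchors, code_size])
    anchors = tf.random_uniform([num_anchors, 4])
    tiled_anchors = tf.tile(tf.expand_dims(anchors, 0),
                            [batch_size, 1, 1])        # [2, 3, 4]
    flat_anchors = tf.reshape(tiled_anchors, [-1, 4])  # [6, 4]
    flat_encodings = tf.reshape(box_encodings, [-1, code_size])
    # ... a box coder would decode flat_encodings against flat_anchors here ...
    decoded = tf.reshape(flat_encodings, [batch_size, num_anchors, 4])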
  def restore_map(self, from_detection_checkpoint=True):
    """Returns a map of variables to load from a foreign checkpoint.

    See parent class for details.

    Args:
      from_detection_checkpoint: whether to restore from a full detection
        checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.
    """
    variables_to_restore = {}
    for variable in tf.all_variables():
      if variable.op.name.startswith(self._extract_features_scope):
        var_name = variable.op.name
        if not from_detection_checkpoint:
          var_name = (re.split('^' + self._extract_features_scope + '/',
                               var_name)[-1])
        variables_to_restore[var_name] = variable
    return variables_to_restore
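A quick illustration of the name rewriting above ('FeatureExtractor' is the
scope the SSD tests check against; the inner variable name is made up):

    import re

    scope = 'FeatureExtractor'
    var_name = 'FeatureExtractor/MobilenetV1/Conv2d_0/weights'
    print(re.split('^' + scope + '/', var_name)[-1])
    # -> 'MobilenetV1/Conv2d_0/weights'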
...
@@ -116,24 +116,46 @@ class SsdMetaArchTest(tf.test.TestCase):
        localization_loss_weight, normalize_loss_by_num_matches,
        hard_example_miner)

  def test_preprocess_preserves_input_shapes(self):
    image_shapes = [(3, None, None, 3),
                    (None, 10, 10, 3),
                    (None, None, None, 3)]
    for image_shape in image_shapes:
      image_placeholder = tf.placeholder(tf.float32, shape=image_shape)
      preprocessed_inputs = self._model.preprocess(image_placeholder)
      self.assertAllEqual(preprocessed_inputs.shape.as_list(), image_shape)
  def test_predict_results_have_correct_keys_and_shapes(self):
    batch_size = 3
    image_size = 2
    input_shapes = [(batch_size, image_size, image_size, 3),
                    (None, image_size, image_size, 3),
                    (batch_size, None, None, 3),
                    (None, None, None, 3)]
    expected_box_encodings_shape_out = (
        batch_size, self._num_anchors, self._code_size)
    expected_class_predictions_with_background_shape_out = (
        batch_size, self._num_anchors, self._num_classes+1)
    for input_shape in input_shapes:
      tf_graph = tf.Graph()
      with tf_graph.as_default():
        preprocessed_input_placeholder = tf.placeholder(tf.float32,
                                                        shape=input_shape)
        prediction_dict = self._model.predict(preprocessed_input_placeholder)

        self.assertTrue('box_encodings' in prediction_dict)
        self.assertTrue('class_predictions_with_background' in prediction_dict)
        self.assertTrue('feature_maps' in prediction_dict)

        init_op = tf.global_variables_initializer()
      with self.test_session(graph=tf_graph) as sess:
        sess.run(init_op)
        prediction_out = sess.run(prediction_dict,
                                  feed_dict={
                                      preprocessed_input_placeholder:
                                      np.random.uniform(
                                          size=(batch_size, 2, 2, 3))})
      self.assertAllEqual(prediction_out['box_encodings'].shape,
                          expected_box_encodings_shape_out)
      self.assertAllEqual(
@@ -142,10 +164,11 @@ class SsdMetaArchTest(tf.test.TestCase):
  def test_postprocess_results_are_correct(self):
    batch_size = 2
    image_size = 2
    input_shapes = [(batch_size, image_size, image_size, 3),
                    (None, image_size, image_size, 3),
                    (batch_size, None, None, 3),
                    (None, None, None, 3)]
    expected_boxes = np.array([[[0, 0, .5, .5],
                                [0, .5, .5, 1],
@@ -163,15 +186,25 @@ class SsdMetaArchTest(tf.test.TestCase):
                               [0, 0, 0, 0, 0]])
    expected_num_detections = np.array([4, 4])
    for input_shape in input_shapes:
      tf_graph = tf.Graph()
      with tf_graph.as_default():
        preprocessed_input_placeholder = tf.placeholder(tf.float32,
                                                        shape=input_shape)
        prediction_dict = self._model.predict(preprocessed_input_placeholder)
        detections = self._model.postprocess(prediction_dict)

        self.assertTrue('detection_boxes' in detections)
        self.assertTrue('detection_scores' in detections)
        self.assertTrue('detection_classes' in detections)
        self.assertTrue('num_detections' in detections)

        init_op = tf.global_variables_initializer()
      with self.test_session(graph=tf_graph) as sess:
        sess.run(init_op)
        detections_out = sess.run(detections,
                                  feed_dict={
                                      preprocessed_input_placeholder:
                                      np.random.uniform(
                                          size=(batch_size, 2, 2, 3))})
      self.assertAllClose(detections_out['detection_boxes'], expected_boxes)
      self.assertAllClose(detections_out['detection_scores'], expected_scores)
      self.assertAllClose(detections_out['detection_classes'],
                          expected_classes)
@@ -207,20 +240,21 @@ class SsdMetaArchTest(tf.test.TestCase):
      self.assertAllClose(losses_out['classification_loss'],
                          expected_classification_loss)
  def test_restore_map_for_detection_ckpt(self):
    init_op = tf.global_variables_initializer()
    saver = tf_saver.Saver()
    save_path = self.get_temp_dir()
    with self.test_session() as sess:
      sess.run(init_op)
      saved_model_path = saver.save(sess, save_path)
      var_map = self._model.restore_map(from_detection_checkpoint=True)
      self.assertIsInstance(var_map, dict)
      saver = tf.train.Saver(var_map)
      saver.restore(sess, saved_model_path)
      for var in sess.run(tf.report_uninitialized_variables()):
        self.assertNotIn('FeatureExtractor', var.name)
  def test_restore_map_for_classification_ckpt(self):
    # Define mock tensorflow classification graph and save variables.
    test_graph_classification = tf.Graph()
    with test_graph_classification.as_default():
@@ -246,10 +280,11 @@ class SsdMetaArchTest(tf.test.TestCase):
      preprocessed_inputs = self._model.preprocess(inputs)
      prediction_dict = self._model.predict(preprocessed_inputs)
      self._model.postprocess(prediction_dict)
      var_map = self._model.restore_map(from_detection_checkpoint=False)
      self.assertIsInstance(var_map, dict)
      saver = tf.train.Saver(var_map)
      with self.test_session() as sess:
        saver.restore(sess, saved_model_path)
        for var in sess.run(tf.report_uninitialized_variables()):
          self.assertNotIn('FeatureExtractor', var.name)
...
@@ -94,7 +94,6 @@ py_library(
    deps = [
        "//tensorflow",
        "//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
        "//tensorflow_models/slim:inception_resnet_v2",
    ],
)
...
@@ -25,7 +25,6 @@ Huang et al. (https://arxiv.org/abs/1611.10012)
import tensorflow as tf

from object_detection.meta_architectures import faster_rcnn_meta_arch
from nets import inception_resnet_v2

slim = tf.contrib.slim
@@ -168,30 +167,30 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
  def restore_from_classification_checkpoint_fn(
      self,
      first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint.

    Note that this overrides the default implementation in
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
    InceptionResnetV2 checkpoints.

    TODO: revisit whether it's possible to force the `Repeat` namescope as
    created in `_extract_box_classifier_features` to start counting at 2
    (e.g. `Repeat_2`) so that the default restore_fn can be used.

    Args:
      first_stage_feature_extractor_scope: A scope name for the first stage
        feature extractor.
      second_stage_feature_extractor_scope: A scope name for the second stage
        feature extractor.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.
    """
    variables_to_restore = {}
    for variable in tf.global_variables():
      if variable.op.name.startswith(
@@ -207,10 +206,4 @@ class FasterRCNNInceptionResnetV2FeatureExtractor(
          var_name = var_name.replace(
              second_stage_feature_extractor_scope + '/', '')
        variables_to_restore[var_name] = variable
    return variables_to_restore
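A quick illustration of the resulting name mapping (the layer name here is made
up; only the scope stripping comes from the code above):

    var_name = ('SecondStageFeatureExtractor/'
                'InceptionResnetV2/Conv2d_7b_1x1/weights')
    print(var_name.replace('SecondStageFeatureExtractor/', ''))
    # -> 'InceptionResnetV2/Conv2d_7b_1x1/weights'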
@@ -211,9 +211,15 @@ def train(create_tensor_dict_fn, create_model_fn, train_config, master, task,
    # Create ops required to initialize the model from a given checkpoint.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      var_map = detection_model.restore_map(
          from_detection_checkpoint=train_config.from_detection_checkpoint)
      available_var_map = (variables_helper.
                           get_variables_available_in_checkpoint(
                               var_map, train_config.fine_tune_checkpoint))
      init_saver = tf.train.Saver(available_var_map)
      def initializer_fn(sess):
        init_saver.restore(sess, train_config.fine_tune_checkpoint)
      init_fn = initializer_fn
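For reference, a sketch of what variables_helper.get_variables_available_in_checkpoint
plausibly does (hedged: the real helper may differ in detail): it keeps only the
entries of var_map whose names actually exist in the checkpoint, so fine-tuning
does not fail on variables that were added to the graph after the checkpoint
was written.

    import tensorflow as tf

    def get_variables_available_in_checkpoint(variables, checkpoint_path):
      """Filters a {name: variable} dict to names present in the checkpoint."""
      reader = tf.train.NewCheckpointReader(checkpoint_path)
      ckpt_vars = reader.get_variable_to_shape_map()
      return {name: var for name, var in variables.items()
              if name in ckpt_vars}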
    with tf.device(deploy_config.optimizer_device()):
      total_loss, grads_and_vars = model_deploy.optimize_clones(
...
@@ -139,21 +139,18 @@ class FakeDetectionModel(model.DetectionModel):
    }
    return loss_dict
  def restore_map(self, from_detection_checkpoint=True):
    """Returns a map of variables to load from a foreign checkpoint.

    Args:
      from_detection_checkpoint: whether to restore from a full detection
        checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.

    Returns:
      A dict mapping variable names to variables.
    """
    return {var.op.name: var for var in tf.global_variables()}


class TrainerTest(tf.test.TestCase):
...
@@ -120,6 +120,7 @@ py_library(
        "//tensorflow_models/object_detection/core:box_list",
        "//tensorflow_models/object_detection/core:box_predictor",
        "//tensorflow_models/object_detection/core:matcher",
        "//tensorflow_models/object_detection/utils:shape_utils",
    ],
)
...