Commit a4944a57 authored by derekjchow, committed by Sergio Guadarrama

Add Tensorflow Object Detection API. (#1561)

For details see our paper:
"Speed/accuracy trade-offs for modern convolutional object detectors."
Huang J, Rathod V, Sun C, Zhu M, Korattikara A, Fathi A, Fischer I,
Wojna Z, Song Y, Guadarrama S, Murphy K, CVPR 2017
https://arxiv.org/abs/1611.10012
parent 60c3ed2e
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.core.bipartite_matcher."""
import tensorflow as tf
from object_detection.matchers import bipartite_matcher
class GreedyBipartiteMatcherTest(tf.test.TestCase):
def test_get_expected_matches_when_all_rows_are_valid(self):
similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]])
num_valid_rows = 2
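    # With both rows valid, greedy matching pairs row 0 with column 2
    # (similarity 0.8) and then row 1 with column 1 (0.2, since column 2 is
    # taken); column 0 is left unmatched and receives -1.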
expected_match_results = [-1, 1, 0]
matcher = bipartite_matcher.GreedyBipartiteMatcher()
match = matcher.match(similarity_matrix, num_valid_rows=num_valid_rows)
with self.test_session() as sess:
match_results_out = sess.run(match._match_results)
self.assertAllEqual(match_results_out, expected_match_results)
def test_get_expected_matches_with_valid_rows_set_to_minus_one(self):
similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]])
num_valid_rows = -1
expected_match_results = [-1, 1, 0]
matcher = bipartite_matcher.GreedyBipartiteMatcher()
match = matcher.match(similarity_matrix, num_valid_rows=num_valid_rows)
with self.test_session() as sess:
match_results_out = sess.run(match._match_results)
self.assertAllEqual(match_results_out, expected_match_results)
def test_get_no_matches_with_zero_valid_rows(self):
similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]])
num_valid_rows = 0
expected_match_results = [-1, -1, -1]
matcher = bipartite_matcher.GreedyBipartiteMatcher()
match = matcher.match(similarity_matrix, num_valid_rows=num_valid_rows)
with self.test_session() as sess:
match_results_out = sess.run(match._match_results)
self.assertAllEqual(match_results_out, expected_match_results)
def test_get_expected_matches_with_only_one_valid_row(self):
similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]])
num_valid_rows = 1
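    # With only row 0 valid, column 2 is matched to row 0 (similarity 0.8)
    # and the remaining columns receive -1.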
expected_match_results = [-1, -1, 0]
matcher = bipartite_matcher.GreedyBipartiteMatcher()
match = matcher.match(similarity_matrix, num_valid_rows=num_valid_rows)
with self.test_session() as sess:
match_results_out = sess.run(match._match_results)
self.assertAllEqual(match_results_out, expected_match_results)
if __name__ == '__main__':
tf.test.main()
# TensorFlow Object Detection API: Meta-architectures.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])  # Apache 2.0
py_library(
name = "ssd_meta_arch",
srcs = ["ssd_meta_arch.py"],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/core:box_coder",
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:box_predictor",
"//tensorflow_models/object_detection/core:model",
"//tensorflow_models/object_detection/core:target_assigner",
"//tensorflow_models/object_detection/utils:variables_helper",
],
)
py_test(
name = "ssd_meta_arch_test",
srcs = ["ssd_meta_arch_test.py"],
deps = [
":ssd_meta_arch",
"//tensorflow",
"//tensorflow/python:training",
"//tensorflow_models/object_detection/core:anchor_generator",
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:losses",
"//tensorflow_models/object_detection/core:post_processing",
"//tensorflow_models/object_detection/core:region_similarity_calculator",
"//tensorflow_models/object_detection/utils:test_utils",
],
)
py_library(
name = "faster_rcnn_meta_arch",
srcs = [
"faster_rcnn_meta_arch.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/anchor_generators:grid_anchor_generator",
"//tensorflow_models/object_detection/core:balanced_positive_negative_sampler",
"//tensorflow_models/object_detection/core:box_list",
"//tensorflow_models/object_detection/core:box_list_ops",
"//tensorflow_models/object_detection/core:box_predictor",
"//tensorflow_models/object_detection/core:losses",
"//tensorflow_models/object_detection/core:model",
"//tensorflow_models/object_detection/core:post_processing",
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/core:target_assigner",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:variables_helper",
],
)
py_library(
name = "faster_rcnn_meta_arch_test_lib",
srcs = [
"faster_rcnn_meta_arch_test_lib.py",
],
deps = [
":faster_rcnn_meta_arch",
"//tensorflow",
"//tensorflow_models/object_detection/anchor_generators:grid_anchor_generator",
"//tensorflow_models/object_detection/builders:box_predictor_builder",
"//tensorflow_models/object_detection/builders:hyperparams_builder",
"//tensorflow_models/object_detection/builders:post_processing_builder",
"//tensorflow_models/object_detection/core:losses",
"//tensorflow_models/object_detection/protos:box_predictor_py_pb2",
"//tensorflow_models/object_detection/protos:hyperparams_py_pb2",
"//tensorflow_models/object_detection/protos:post_processing_py_pb2",
],
)
py_test(
name = "faster_rcnn_meta_arch_test",
srcs = ["faster_rcnn_meta_arch_test.py"],
deps = [
":faster_rcnn_meta_arch_test_lib",
],
)
py_library(
name = "rfcn_meta_arch",
srcs = ["rfcn_meta_arch.py"],
deps = [
":faster_rcnn_meta_arch",
"//tensorflow",
"//tensorflow_models/object_detection/core:box_predictor",
"//tensorflow_models/object_detection/utils:ops",
],
)
py_test(
name = "rfcn_meta_arch_test",
srcs = ["rfcn_meta_arch_test.py"],
deps = [
":faster_rcnn_meta_arch_test_lib",
":rfcn_meta_arch",
"//tensorflow",
],
)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Faster R-CNN meta-architecture definition.
General TensorFlow implementation of Faster R-CNN detection models.
See Faster R-CNN: Ren, Shaoqing, et al.
"Faster R-CNN: Towards real-time object detection with region proposal
networks." Advances in neural information processing systems. 2015.
We allow for two modes: first_stage_only=True and first_stage_only=False. In
the former setting, all of the user-facing methods (e.g., predict, postprocess,
loss) can be used as if the model consisted only of the RPN, returning class
agnostic proposals (these can be thought of as approximate detections with no
associated class information). In the latter setting, proposals are computed,
then passed through a second stage "box classifier" to yield (multi-class)
detections.
Implementations of Faster R-CNN models must define a new
FasterRCNNFeatureExtractor and override three methods: `preprocess`,
`_extract_proposal_features` (the first stage of the model), and
`_extract_box_classifier_features` (the second stage of the model). Optionally,
the `restore_fn` method can be overridden. See tests for an example.
A few important notes:
+ Batching conventions: We support batched inference and training where
all images within a batch have the same resolution. Batch sizes are determined
dynamically via the shape of the input tensors (rather than being specified
directly as, e.g., a model constructor argument).
A complication is that due to non-max suppression, we are not guaranteed to get
the same number of proposals from the first stage RPN (region proposal network)
for each image (though in practice, we should often get the same number of
proposals). For this reason we pad to a maximum number of proposals per image
within a batch. This maximum is exposed as the `self.max_num_proposals`
property, which is set to the `first_stage_max_proposals` parameter at
inference time and to `second_stage_batch_size` at training time, since we
subsample the batch to be sent through the box classifier during training.
For the second stage of the pipeline, we arrange the proposals for all images
within the batch along a single batch dimension. For example, the input to
_extract_box_classifier_features is a tensor of shape
`[total_num_proposals, crop_height, crop_width, depth]` where
total_num_proposals is batch_size * self.max_num_proposals. (And note that per
the above comment, a subset of these entries correspond to zero paddings.)
+ Coordinate representations:
Following the API (see model.DetectionModel definition), our outputs after
postprocessing operations are always normalized boxes; internally, however, we
sometimes convert to absolute coordinates --- e.g., for loss computation. In
particular, anchors and proposal_boxes are both represented in absolute
coordinates.
TODO: Support TPU implementations and sigmoid loss.
"""
from abc import abstractmethod
from functools import partial
import tensorflow as tf
from object_detection.anchor_generators import grid_anchor_generator
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import box_predictor
from object_detection.core import losses
from object_detection.core import model
from object_detection.core import post_processing
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import ops
from object_detection.utils import variables_helper
slim = tf.contrib.slim
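# A minimal illustration (hypothetical shapes, not part of this module) of the
# proposal batching convention described in the module docstring: per-image
# proposals are padded to `max_num_proposals` and then flattened so that the
# second stage sees a single batch dimension.
#
#   batch_size, max_num_proposals = 2, 300
#   crop_height, crop_width, depth = 14, 14, 1024
#   total_num_proposals = batch_size * max_num_proposals  # 600
#   # per-image cropped proposal features:
#   #   [batch_size, max_num_proposals, crop_height, crop_width, depth]
#   # flattened input to _extract_box_classifier_features:
#   #   [total_num_proposals, crop_height, crop_width, depth]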
class FasterRCNNFeatureExtractor(object):
"""Faster R-CNN Feature Extractor definition."""
def __init__(self,
is_training,
first_stage_features_stride,
reuse_weights=None,
weight_decay=0.0):
"""Constructor.
Args:
is_training: A boolean indicating whether the training version of the
computation graph should be constructed.
first_stage_features_stride: Output stride of extracted RPN feature map.
reuse_weights: Whether to reuse variables. Default is None.
weight_decay: float weight decay for feature extractor (default: 0.0).
"""
self._is_training = is_training
self._first_stage_features_stride = first_stage_features_stride
self._reuse_weights = reuse_weights
self._weight_decay = weight_decay
@abstractmethod
def preprocess(self, resized_inputs):
"""Feature-extractor specific preprocessing (minus image resizing)."""
pass
def extract_proposal_features(self, preprocessed_inputs, scope):
"""Extracts first stage RPN features.
This function is responsible for extracting feature maps from preprocessed
images. These features are used by the region proposal network (RPN) to
predict proposals.
Args:
preprocessed_inputs: A [batch, height, width, channels] float tensor
representing a batch of images.
scope: A scope name.
Returns:
rpn_feature_map: A tensor with shape [batch, height, width, depth]
"""
with tf.variable_scope(scope, values=[preprocessed_inputs]):
return self._extract_proposal_features(preprocessed_inputs, scope)
@abstractmethod
def _extract_proposal_features(self, preprocessed_inputs, scope):
"""Extracts first stage RPN features, to be overridden."""
pass
def extract_box_classifier_features(self, proposal_feature_maps, scope):
"""Extracts second stage box classifier features.
Args:
proposal_feature_maps: A 4-D float tensor with shape
[batch_size * self.max_num_proposals, crop_height, crop_width, depth]
representing the feature map cropped to each proposal.
scope: A scope name.
Returns:
proposal_classifier_features: A 4-D float tensor with shape
[batch_size * self.max_num_proposals, height, width, depth]
representing box classifier features for each proposal.
"""
with tf.variable_scope(scope, values=[proposal_feature_maps]):
return self._extract_box_classifier_features(proposal_feature_maps, scope)
@abstractmethod
def _extract_box_classifier_features(self, proposal_feature_maps, scope):
"""Extracts second stage box classifier features, to be overridden."""
pass
def restore_from_classification_checkpoint_fn(
self,
checkpoint_path,
first_stage_feature_extractor_scope,
second_stage_feature_extractor_scope):
"""Returns callable for loading a checkpoint into the tensorflow graph.
Args:
checkpoint_path: path to checkpoint to restore.
first_stage_feature_extractor_scope: A scope name for the first stage
feature extractor.
second_stage_feature_extractor_scope: A scope name for the second stage
feature extractor.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
"""
variables_to_restore = {}
for variable in tf.global_variables():
for scope_name in [first_stage_feature_extractor_scope,
second_stage_feature_extractor_scope]:
if variable.op.name.startswith(scope_name):
var_name = variable.op.name.replace(scope_name + '/', '')
variables_to_restore[var_name] = variable
variables_to_restore = (
variables_helper.get_variables_available_in_checkpoint(
variables_to_restore, checkpoint_path))
saver = tf.train.Saver(variables_to_restore)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
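# A minimal usage sketch of the checkpoint-restore callable above; the
# `feature_extractor` instance and checkpoint path are hypothetical:
#
#   restore_fn = feature_extractor.restore_from_classification_checkpoint_fn(
#       checkpoint_path='/path/to/classification.ckpt',
#       first_stage_feature_extractor_scope='FirstStageFeatureExtractor',
#       second_stage_feature_extractor_scope='SecondStageFeatureExtractor')
#   with tf.Session() as sess:
#     restore_fn(sess)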
class FasterRCNNMetaArch(model.DetectionModel):
"""Faster R-CNN Meta-architecture definition."""
def __init__(self,
is_training,
num_classes,
image_resizer_fn,
feature_extractor,
first_stage_only,
first_stage_anchor_generator,
first_stage_atrous_rate,
first_stage_box_predictor_arg_scope,
first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth,
first_stage_minibatch_size,
first_stage_positive_balance_fraction,
first_stage_nms_score_threshold,
first_stage_nms_iou_threshold,
first_stage_max_proposals,
first_stage_localization_loss_weight,
first_stage_objectness_loss_weight,
initial_crop_size,
maxpool_kernel_size,
maxpool_stride,
second_stage_mask_rcnn_box_predictor,
second_stage_batch_size,
second_stage_balance_fraction,
second_stage_non_max_suppression_fn,
second_stage_score_conversion_fn,
second_stage_localization_loss_weight,
second_stage_classification_loss_weight,
hard_example_miner,
parallel_iterations=16):
"""FasterRCNNMetaArch Constructor.
Args:
is_training: A boolean indicating whether the training version of the
computation graph should be constructed.
      num_classes: Number of classes. Note that num_classes *does not*
        include the background category, so if groundtruth labels take values
        in {0, 1, ..., K-1}, num_classes=K (and not K+1, even though the
        assigned classification targets can range over {0, ..., K}).
image_resizer_fn: A callable for image resizing. This callable always
takes a rank-3 image tensor (corresponding to a single image) and
returns a rank-3 image tensor, possibly with new spatial dimensions.
See builders/image_resizer_builder.py.
feature_extractor: A FasterRCNNFeatureExtractor object.
first_stage_only: Whether to construct only the Region Proposal Network
(RPN) part of the model.
first_stage_anchor_generator: An anchor_generator.AnchorGenerator object
(note that currently we only support
grid_anchor_generator.GridAnchorGenerator objects)
first_stage_atrous_rate: A single integer indicating the atrous rate for
the single convolution op which is applied to the `rpn_features_to_crop`
tensor to obtain a tensor to be used for box prediction. Some feature
extractors optionally allow for producing feature maps computed at
denser resolutions. The atrous rate is used to compensate for the
denser feature maps by using an effectively larger receptive field.
(This should typically be set to 1).
first_stage_box_predictor_arg_scope: Slim arg_scope for conv2d,
separable_conv2d and fully_connected ops for the RPN box predictor.
first_stage_box_predictor_kernel_size: Kernel size to use for the
convolution op just prior to RPN box predictions.
first_stage_box_predictor_depth: Output depth for the convolution op
just prior to RPN box predictions.
first_stage_minibatch_size: The "batch size" to use for computing the
objectness and location loss of the region proposal network. This
"batch size" refers to the number of anchors selected as contributing
to the loss function for any given image within the image batch and is
only called "batch_size" due to terminology from the Faster R-CNN paper.
first_stage_positive_balance_fraction: Fraction of positive examples
        per image for the RPN. The recommended value for Faster R-CNN is 0.5.
first_stage_nms_score_threshold: Score threshold for non max suppression
for the Region Proposal Network (RPN). This value is expected to be in
[0, 1] as it is applied directly after a softmax transformation. The
recommended value for Faster R-CNN is 0.
first_stage_nms_iou_threshold: The Intersection Over Union (IOU) threshold
for performing Non-Max Suppression (NMS) on the boxes predicted by the
Region Proposal Network (RPN).
first_stage_max_proposals: Maximum number of boxes to retain after
performing Non-Max Suppression (NMS) on the boxes predicted by the
Region Proposal Network (RPN).
first_stage_localization_loss_weight: A float
first_stage_objectness_loss_weight: A float
initial_crop_size: A single integer indicating the output size
(width and height are set to be the same) of the initial bilinear
interpolation based cropping during ROI pooling.
maxpool_kernel_size: A single integer indicating the kernel size of the
max pool op on the cropped feature map during ROI pooling.
maxpool_stride: A single integer indicating the stride of the max pool
op on the cropped feature map during ROI pooling.
second_stage_mask_rcnn_box_predictor: Mask R-CNN box predictor to use for
the second stage.
second_stage_batch_size: The batch size used for computing the
classification and refined location loss of the box classifier. This
"batch size" refers to the number of proposals selected as contributing
to the loss function for any given image within the image batch and is
only called "batch_size" due to terminology from the Faster R-CNN paper.
second_stage_balance_fraction: Fraction of positive examples to use
        per image for the box classifier. The recommended value for
        Faster R-CNN is 0.25.
second_stage_non_max_suppression_fn: batch_multiclass_non_max_suppression
callable that takes `boxes`, `scores`, optional `clip_window` and
optional (kwarg) `mask` inputs (with all other inputs already set)
and returns a dictionary containing tensors with keys:
`detection_boxes`, `detection_scores`, `detection_classes`,
`num_detections`, and (optionally) `detection_masks`. See
`post_processing.batch_multiclass_non_max_suppression` for the type and
shape of these tensors.
second_stage_score_conversion_fn: Callable elementwise nonlinearity
(that takes tensors as inputs and returns tensors). This is usually
used to convert logits to probabilities.
second_stage_localization_loss_weight: A float
second_stage_classification_loss_weight: A float
hard_example_miner: A losses.HardExampleMiner object (can be None).
parallel_iterations: (Optional) The number of iterations allowed to run
in parallel for calls to tf.map_fn.
Raises:
ValueError: If `second_stage_batch_size` > `first_stage_max_proposals`
ValueError: If first_stage_anchor_generator is not of type
grid_anchor_generator.GridAnchorGenerator.
"""
super(FasterRCNNMetaArch, self).__init__(num_classes=num_classes)
if second_stage_batch_size > first_stage_max_proposals:
raise ValueError('second_stage_batch_size should be no greater than '
'first_stage_max_proposals.')
if not isinstance(first_stage_anchor_generator,
grid_anchor_generator.GridAnchorGenerator):
raise ValueError('first_stage_anchor_generator must be of type '
'grid_anchor_generator.GridAnchorGenerator.')
self._is_training = is_training
self._image_resizer_fn = image_resizer_fn
self._feature_extractor = feature_extractor
self._first_stage_only = first_stage_only
# The first class is reserved as background.
unmatched_cls_target = tf.constant(
[1] + self._num_classes * [0], dtype=tf.float32)
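    # For example, with num_classes=3 the unmatched target is [1, 0, 0, 0]:
    # a one-hot vector selecting the background class at index 0.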
self._proposal_target_assigner = target_assigner.create_target_assigner(
'FasterRCNN', 'proposal')
self._detector_target_assigner = target_assigner.create_target_assigner(
'FasterRCNN', 'detection', unmatched_cls_target=unmatched_cls_target)
# Both proposal and detector target assigners use the same box coder
self._box_coder = self._proposal_target_assigner.box_coder
# (First stage) Region proposal network parameters
self._first_stage_anchor_generator = first_stage_anchor_generator
self._first_stage_atrous_rate = first_stage_atrous_rate
self._first_stage_box_predictor_arg_scope = (
first_stage_box_predictor_arg_scope)
self._first_stage_box_predictor_kernel_size = (
first_stage_box_predictor_kernel_size)
self._first_stage_box_predictor_depth = first_stage_box_predictor_depth
self._first_stage_minibatch_size = first_stage_minibatch_size
self._first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
positive_fraction=first_stage_positive_balance_fraction)
self._first_stage_box_predictor = box_predictor.ConvolutionalBoxPredictor(
self._is_training, num_classes=1,
conv_hyperparams=self._first_stage_box_predictor_arg_scope,
min_depth=0, max_depth=0, num_layers_before_predictor=0,
use_dropout=False, dropout_keep_prob=1.0, kernel_size=1,
box_code_size=self._box_coder.code_size)
self._first_stage_nms_score_threshold = first_stage_nms_score_threshold
self._first_stage_nms_iou_threshold = first_stage_nms_iou_threshold
self._first_stage_max_proposals = first_stage_max_proposals
self._first_stage_localization_loss = (
losses.WeightedSmoothL1LocalizationLoss(anchorwise_output=True))
self._first_stage_objectness_loss = (
losses.WeightedSoftmaxClassificationLoss(anchorwise_output=True))
self._first_stage_loc_loss_weight = first_stage_localization_loss_weight
self._first_stage_obj_loss_weight = first_stage_objectness_loss_weight
# Per-region cropping parameters
self._initial_crop_size = initial_crop_size
self._maxpool_kernel_size = maxpool_kernel_size
self._maxpool_stride = maxpool_stride
self._mask_rcnn_box_predictor = second_stage_mask_rcnn_box_predictor
self._second_stage_batch_size = second_stage_batch_size
self._second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
positive_fraction=second_stage_balance_fraction)
self._second_stage_nms_fn = second_stage_non_max_suppression_fn
self._second_stage_score_conversion_fn = second_stage_score_conversion_fn
self._second_stage_localization_loss = (
losses.WeightedSmoothL1LocalizationLoss(anchorwise_output=True))
self._second_stage_classification_loss = (
losses.WeightedSoftmaxClassificationLoss(anchorwise_output=True))
self._second_stage_loc_loss_weight = second_stage_localization_loss_weight
self._second_stage_cls_loss_weight = second_stage_classification_loss_weight
self._hard_example_miner = hard_example_miner
self._parallel_iterations = parallel_iterations
@property
def first_stage_feature_extractor_scope(self):
return 'FirstStageFeatureExtractor'
@property
def second_stage_feature_extractor_scope(self):
return 'SecondStageFeatureExtractor'
@property
def first_stage_box_predictor_scope(self):
return 'FirstStageBoxPredictor'
@property
def second_stage_box_predictor_scope(self):
return 'SecondStageBoxPredictor'
@property
def max_num_proposals(self):
"""Max number of proposals (to pad to) for each image in the input batch.
    At training time, this is set to `second_stage_batch_size` if a hard
    example miner is not configured; otherwise it is set to
    `first_stage_max_proposals`. At inference time, this is always set to
    `first_stage_max_proposals`.
Returns:
A positive integer.
"""
if self._is_training and not self._hard_example_miner:
return self._second_stage_batch_size
return self._first_stage_max_proposals
def preprocess(self, inputs):
"""Feature-extractor specific preprocessing.
See base class.
For Faster R-CNN, we perform image resizing in the base class --- each
class subclassing FasterRCNNMetaArch is responsible for any additional
preprocessing (e.g., scaling pixel values to be in [-1, 1]).
Args:
inputs: a [batch, height_in, width_in, channels] float tensor representing
a batch of images with values between 0 and 255.0.
Returns:
preprocessed_inputs: a [batch, height_out, width_out, channels] float
tensor representing a batch of images.
Raises:
ValueError: if inputs tensor does not have type tf.float32
"""
if inputs.dtype is not tf.float32:
raise ValueError('`preprocess` expects a tf.float32 tensor')
with tf.name_scope('Preprocessor'):
resized_inputs = tf.map_fn(self._image_resizer_fn,
elems=inputs,
dtype=tf.float32,
parallel_iterations=self._parallel_iterations)
return self._feature_extractor.preprocess(resized_inputs)
def predict(self, preprocessed_inputs):
"""Predicts unpostprocessed tensors from input tensor.
This function takes an input batch of images and runs it through the
forward pass of the network to yield "raw" un-postprocessed predictions.
If `first_stage_only` is True, this function only returns first stage
RPN predictions (un-postprocessed). Otherwise it returns both
first stage RPN predictions as well as second stage box classifier
predictions.
Other remarks:
+ Anchor pruning vs. clipping: following the recommendation of the Faster
R-CNN paper, we prune anchors that venture outside the image window at
training time and clip anchors to the image window at inference time.
+ Proposal padding: as described at the top of the file, proposals are
    padded to self.max_num_proposals and flattened so that proposals from all
images within the input batch are arranged along the same batch dimension.
Args:
preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
Returns:
prediction_dict: a dictionary holding "raw" prediction tensors:
1) rpn_box_predictor_features: A 4-D float32 tensor with shape
[batch_size, height, width, depth] to be used for predicting proposal
boxes and corresponding objectness scores.
2) rpn_features_to_crop: A 4-D float32 tensor with shape
[batch_size, height, width, depth] representing image features to crop
using the proposal boxes predicted by the RPN.
3) image_shape: a 1-D tensor of shape [4] representing the input
image shape.
4) rpn_box_encodings: 3-D float tensor of shape
[batch_size, num_anchors, self._box_coder.code_size] containing
predicted boxes.
5) rpn_objectness_predictions_with_background: 3-D float tensor of shape
[batch_size, num_anchors, 2] containing class
predictions (logits) for each of the anchors. Note that this
tensor *includes* background class predictions (at class index 0).
6) anchors: A 2-D tensor of shape [num_anchors, 4] representing anchors
for the first stage RPN (in absolute coordinates). Note that
`num_anchors` can differ depending on whether the model is created in
training or inference mode.
(and if first_stage_only=False):
7) refined_box_encodings: a 3-D tensor with shape
[total_num_proposals, num_classes, 4] representing predicted
(final) refined box encodings, where
          total_num_proposals=batch_size*self.max_num_proposals
8) class_predictions_with_background: a 3-D tensor with shape
[total_num_proposals, num_classes + 1] containing class
predictions (logits) for each of the anchors, where
          total_num_proposals=batch_size*self.max_num_proposals.
Note that this tensor *includes* background class predictions
(at class index 0).
9) num_proposals: An int32 tensor of shape [batch_size] representing the
number of proposals generated by the RPN. `num_proposals` allows us
to keep track of which entries are to be treated as zero paddings and
which are not since we always pad the number of proposals to be
`self.max_num_proposals` for each image.
10) proposal_boxes: A float32 tensor of shape
[batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes (in absolute coordinates).
11) mask_predictions: (optional) a 4-D tensor with shape
[total_num_padded_proposals, num_classes, mask_height, mask_width]
containing instance mask predictions.
"""
(rpn_box_predictor_features, rpn_features_to_crop, anchors_boxlist,
image_shape) = self._extract_rpn_feature_maps(preprocessed_inputs)
(rpn_box_encodings, rpn_objectness_predictions_with_background
) = self._predict_rpn_proposals(rpn_box_predictor_features)
# The Faster R-CNN paper recommends pruning anchors that venture outside
# the image window at training time and clipping at inference time.
clip_window = tf.to_float(tf.stack([0, 0, image_shape[1], image_shape[2]]))
if self._is_training:
(rpn_box_encodings, rpn_objectness_predictions_with_background,
anchors_boxlist) = self._remove_invalid_anchors_and_predictions(
rpn_box_encodings, rpn_objectness_predictions_with_background,
anchors_boxlist, clip_window)
else:
anchors_boxlist = box_list_ops.clip_to_window(
anchors_boxlist, clip_window)
anchors = anchors_boxlist.get()
prediction_dict = {
'rpn_box_predictor_features': rpn_box_predictor_features,
'rpn_features_to_crop': rpn_features_to_crop,
'image_shape': image_shape,
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
rpn_objectness_predictions_with_background,
'anchors': anchors
}
if not self._first_stage_only:
prediction_dict.update(self._predict_second_stage(
rpn_box_encodings,
rpn_objectness_predictions_with_background,
rpn_features_to_crop,
anchors, image_shape))
return prediction_dict
def _predict_second_stage(self, rpn_box_encodings,
rpn_objectness_predictions_with_background,
rpn_features_to_crop,
anchors,
image_shape):
"""Predicts the output tensors from second stage of Faster R-CNN.
Args:
      rpn_box_encodings: 3-D float tensor of shape
[batch_size, num_valid_anchors, self._box_coder.code_size] containing
predicted boxes.
      rpn_objectness_predictions_with_background: 3-D float tensor of shape
[batch_size, num_valid_anchors, 2] containing class
predictions (logits) for each of the anchors. Note that this
tensor *includes* background class predictions (at class index 0).
rpn_features_to_crop: A 4-D float32 tensor with shape
[batch_size, height, width, depth] representing image features to crop
using the proposal boxes predicted by the RPN.
anchors: 2-D float tensor of shape
[num_anchors, self._box_coder.code_size].
      image_shape: A 1-D int32 tensor of size [4] containing the image shape.
Returns:
prediction_dict: a dictionary holding "raw" prediction tensors:
1) refined_box_encodings: a 3-D tensor with shape
[total_num_proposals, num_classes, 4] representing predicted
(final) refined box encodings, where
          total_num_proposals=batch_size*self.max_num_proposals
2) class_predictions_with_background: a 3-D tensor with shape
[total_num_proposals, num_classes + 1] containing class
predictions (logits) for each of the anchors, where
          total_num_proposals=batch_size*self.max_num_proposals.
Note that this tensor *includes* background class predictions
(at class index 0).
3) num_proposals: An int32 tensor of shape [batch_size] representing the
number of proposals generated by the RPN. `num_proposals` allows us
to keep track of which entries are to be treated as zero paddings and
which are not since we always pad the number of proposals to be
`self.max_num_proposals` for each image.
4) proposal_boxes: A float32 tensor of shape
[batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes (in absolute coordinates).
5) mask_predictions: (optional) a 4-D tensor with shape
[total_num_padded_proposals, num_classes, mask_height, mask_width]
containing instance mask predictions.
"""
proposal_boxes_normalized, _, num_proposals = self._postprocess_rpn(
rpn_box_encodings, rpn_objectness_predictions_with_background,
anchors, image_shape)
flattened_proposal_feature_maps = (
self._compute_second_stage_input_feature_maps(
rpn_features_to_crop, proposal_boxes_normalized))
box_classifier_features = (
self._feature_extractor.extract_box_classifier_features(
flattened_proposal_feature_maps,
scope=self.second_stage_feature_extractor_scope))
box_predictions = self._mask_rcnn_box_predictor.predict(
box_classifier_features,
num_predictions_per_location=1,
scope=self.second_stage_box_predictor_scope)
refined_box_encodings = tf.squeeze(
box_predictions[box_predictor.BOX_ENCODINGS], axis=1)
class_predictions_with_background = tf.squeeze(box_predictions[
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND], axis=1)
absolute_proposal_boxes = ops.normalized_to_image_coordinates(
proposal_boxes_normalized, image_shape, self._parallel_iterations)
prediction_dict = {
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background':
class_predictions_with_background,
'num_proposals': num_proposals,
'proposal_boxes': absolute_proposal_boxes,
}
return prediction_dict
def _extract_rpn_feature_maps(self, preprocessed_inputs):
"""Extracts RPN features.
This function extracts two feature maps: a feature map to be directly
fed to a box predictor (to predict location and objectness scores for
proposals) and a feature map from which to crop regions which will then
be sent to the second stage box classifier.
Args:
preprocessed_inputs: a [batch, height, width, channels] image tensor.
Returns:
rpn_box_predictor_features: A 4-D float32 tensor with shape
[batch, height, width, depth] to be used for predicting proposal boxes
and corresponding objectness scores.
rpn_features_to_crop: A 4-D float32 tensor with shape
[batch, height, width, depth] representing image features to crop using
the proposals boxes.
anchors: A BoxList representing anchors (for the RPN) in
absolute coordinates.
image_shape: A 1-D tensor representing the input image shape.
"""
image_shape = tf.shape(preprocessed_inputs)
rpn_features_to_crop = self._feature_extractor.extract_proposal_features(
preprocessed_inputs, scope=self.first_stage_feature_extractor_scope)
feature_map_shape = tf.shape(rpn_features_to_crop)
anchors = self._first_stage_anchor_generator.generate(
[(feature_map_shape[1], feature_map_shape[2])])
with slim.arg_scope(self._first_stage_box_predictor_arg_scope):
kernel_size = self._first_stage_box_predictor_kernel_size
rpn_box_predictor_features = slim.conv2d(
rpn_features_to_crop,
self._first_stage_box_predictor_depth,
kernel_size=[kernel_size, kernel_size],
rate=self._first_stage_atrous_rate,
activation_fn=tf.nn.relu6)
return (rpn_box_predictor_features, rpn_features_to_crop,
anchors, image_shape)
def _predict_rpn_proposals(self, rpn_box_predictor_features):
"""Adds box predictors to RPN feature map to predict proposals.
Note resulting tensors will not have been postprocessed.
Args:
rpn_box_predictor_features: A 4-D float32 tensor with shape
[batch, height, width, depth] to be used for predicting proposal boxes
and corresponding objectness scores.
Returns:
box_encodings: 3-D float tensor of shape
[batch_size, num_anchors, self._box_coder.code_size] containing
predicted boxes.
objectness_predictions_with_background: 3-D float tensor of shape
[batch_size, num_anchors, 2] containing class
predictions (logits) for each of the anchors. Note that this
tensor *includes* background class predictions (at class index 0).
Raises:
RuntimeError: if the anchor generator generates anchors corresponding to
multiple feature maps. We currently assume that a single feature map
is generated for the RPN.
"""
num_anchors_per_location = (
self._first_stage_anchor_generator.num_anchors_per_location())
if len(num_anchors_per_location) != 1:
raise RuntimeError('anchor_generator is expected to generate anchors '
'corresponding to a single feature map.')
box_predictions = self._first_stage_box_predictor.predict(
rpn_box_predictor_features,
num_anchors_per_location[0],
scope=self.first_stage_box_predictor_scope)
box_encodings = box_predictions[box_predictor.BOX_ENCODINGS]
objectness_predictions_with_background = box_predictions[
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND]
return (tf.squeeze(box_encodings, axis=2),
objectness_predictions_with_background)
def _remove_invalid_anchors_and_predictions(
self,
box_encodings,
objectness_predictions_with_background,
anchors_boxlist,
clip_window):
"""Removes anchors that (partially) fall outside an image.
Also removes associated box encodings and objectness predictions.
Args:
box_encodings: 3-D float tensor of shape
[batch_size, num_anchors, self._box_coder.code_size] containing
predicted boxes.
objectness_predictions_with_background: 3-D float tensor of shape
[batch_size, num_anchors, 2] containing class
predictions (logits) for each of the anchors. Note that this
tensor *includes* background class predictions (at class index 0).
anchors_boxlist: A BoxList representing num_anchors anchors (for the RPN)
in absolute coordinates.
clip_window: a 1-D tensor representing the [ymin, xmin, ymax, xmax]
extent of the window to clip/prune to.
Returns:
      box_encodings: 3-D float tensor of shape
[batch_size, num_valid_anchors, self._box_coder.code_size] containing
predicted boxes, where num_valid_anchors <= num_anchors
      objectness_predictions_with_background: 3-D float tensor of shape
[batch_size, num_valid_anchors, 2] containing class
predictions (logits) for each of the anchors, where
num_valid_anchors <= num_anchors. Note that this
tensor *includes* background class predictions (at class index 0).
anchors: A BoxList representing num_valid_anchors anchors (for the RPN) in
absolute coordinates.
"""
pruned_anchors_boxlist, keep_indices = box_list_ops.prune_outside_window(
anchors_boxlist, clip_window)
def _batch_gather_kept_indices(predictions_tensor):
return tf.map_fn(
partial(tf.gather, indices=keep_indices),
elems=predictions_tensor,
dtype=tf.float32,
parallel_iterations=self._parallel_iterations,
back_prop=True)
return (_batch_gather_kept_indices(box_encodings),
_batch_gather_kept_indices(objectness_predictions_with_background),
pruned_anchors_boxlist)
def _flatten_first_two_dimensions(self, inputs):
"""Flattens `K-d` tensor along batch dimension to be a `(K-1)-d` tensor.
Converts `inputs` with shape [A, B, ..., depth] into a tensor of shape
[A * B, ..., depth].
Args:
inputs: A float tensor with shape [A, B, ..., depth]. Note that the first
two and last dimensions must be statically defined.
Returns:
      A float tensor with shape [A * B, ..., depth] (where the first and last
      dimensions are statically defined).
"""
inputs_shape = inputs.get_shape().as_list()
flattened_shape = tf.concat([
[inputs_shape[0]*inputs_shape[1]], tf.shape(inputs)[2:-1],
[inputs_shape[-1]]], 0)
return tf.reshape(inputs, flattened_shape)
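  # For example (hypothetical shapes): an input of shape [2, 300, 7, 7, 1024]
  # (batch_size=2, max_num_proposals=300) is reshaped to [600, 7, 7, 1024].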
def postprocess(self, prediction_dict):
"""Convert prediction tensors to final detections.
This function converts raw predictions tensors to final detection results.
See base class for output format conventions. Note also that by default,
scores are to be interpreted as logits, but if a score_converter is used,
then scores are remapped (and may thus have a different interpretation).
If first_stage_only=True, the returned results represent proposals from the
first stage RPN and are padded to have self.max_num_proposals for each
image; otherwise, the results can be interpreted as multiclass detections
    from the full two-stage model and are padded to a maximum number of
    detections.
Args:
prediction_dict: a dictionary holding prediction tensors (see the
        documentation for the predict method). If first_stage_only=True, we
expect prediction_dict to contain `rpn_box_encodings`,
`rpn_objectness_predictions_with_background`, `rpn_features_to_crop`,
`image_shape`, and `anchors` fields. Otherwise we expect
prediction_dict to additionally contain `refined_box_encodings`,
`class_predictions_with_background`, `num_proposals`,
`proposal_boxes` and, optionally, `mask_predictions` fields.
Returns:
detections: a dictionary containing the following fields
        detection_boxes: [batch, max_detections, 4]
detection_scores: [batch, max_detections]
detection_classes: [batch, max_detections]
          (this entry is only created if first_stage_only=False)
num_detections: [batch]
"""
with tf.name_scope('FirstStagePostprocessor'):
image_shape = prediction_dict['image_shape']
if self._first_stage_only:
proposal_boxes, proposal_scores, num_proposals = self._postprocess_rpn(
prediction_dict['rpn_box_encodings'],
prediction_dict['rpn_objectness_predictions_with_background'],
prediction_dict['anchors'],
image_shape)
return {
'detection_boxes': proposal_boxes,
'detection_scores': proposal_scores,
'num_detections': num_proposals
}
with tf.name_scope('SecondStagePostprocessor'):
mask_predictions = prediction_dict.get(box_predictor.MASK_PREDICTIONS)
detections_dict = self._postprocess_box_classifier(
prediction_dict['refined_box_encodings'],
prediction_dict['class_predictions_with_background'],
prediction_dict['proposal_boxes'],
prediction_dict['num_proposals'],
image_shape,
mask_predictions=mask_predictions)
return detections_dict
def _postprocess_rpn(self,
rpn_box_encodings_batch,
rpn_objectness_predictions_with_background_batch,
anchors,
image_shape):
"""Converts first stage prediction tensors from the RPN to proposals.
    This function decodes the raw RPN predictions and runs non-max suppression
    on the result.
Note that the behavior of this function is slightly modified during
training --- specifically, we stop the gradient from passing through the
proposal boxes and we only return a balanced sampled subset of proposals
with size `second_stage_batch_size`.
Args:
rpn_box_encodings_batch: A 3-D float32 tensor of shape
[batch_size, num_anchors, self._box_coder.code_size] containing
predicted proposal box encodings.
rpn_objectness_predictions_with_background_batch: A 3-D float tensor of
shape [batch_size, num_anchors, 2] containing objectness predictions
(logits) for each of the anchors with 0 corresponding to background
and 1 corresponding to object.
anchors: A 2-D tensor of shape [num_anchors, 4] representing anchors
for the first stage RPN. Note that `num_anchors` can differ depending
on whether the model is created in training or inference mode.
image_shape: A 1-D tensor representing the input image shape.
Returns:
proposal_boxes: A float tensor with shape
[batch_size, max_num_proposals, 4] representing the (potentially zero
padded) proposal boxes for all images in the batch. These boxes are
represented as normalized coordinates.
proposal_scores: A float tensor with shape
[batch_size, max_num_proposals] representing the (potentially zero
padded) proposal objectness scores for all images in the batch.
num_proposals: A Tensor of type `int32`. A 1-D tensor of shape [batch]
representing the number of proposals predicted for each image in
the batch.
"""
clip_window = tf.to_float(tf.stack([0, 0, image_shape[1], image_shape[2]]))
if self._is_training:
(groundtruth_boxlists, groundtruth_classes_with_background_list
) = self._format_groundtruth_data(image_shape)
proposal_boxes_list = []
proposal_scores_list = []
num_proposals_list = []
for (batch_index,
(rpn_box_encodings,
rpn_objectness_predictions_with_background)) in enumerate(zip(
tf.unstack(rpn_box_encodings_batch),
tf.unstack(rpn_objectness_predictions_with_background_batch))):
decoded_boxes = self._box_coder.decode(
rpn_box_encodings, box_list.BoxList(anchors))
objectness_scores = tf.unstack(
tf.nn.softmax(rpn_objectness_predictions_with_background), axis=1)[1]
proposal_boxlist = post_processing.multiclass_non_max_suppression(
tf.expand_dims(decoded_boxes.get(), 1),
tf.expand_dims(objectness_scores, 1),
self._first_stage_nms_score_threshold,
self._first_stage_nms_iou_threshold, self._first_stage_max_proposals,
clip_window=clip_window)
if self._is_training:
proposal_boxlist.set(tf.stop_gradient(proposal_boxlist.get()))
if not self._hard_example_miner:
proposal_boxlist = self._sample_box_classifier_minibatch(
proposal_boxlist, groundtruth_boxlists[batch_index],
groundtruth_classes_with_background_list[batch_index])
normalized_proposals = box_list_ops.to_normalized_coordinates(
proposal_boxlist, image_shape[1], image_shape[2],
check_range=False)
# pad proposals to max_num_proposals
padded_proposals = box_list_ops.pad_or_clip_box_list(
normalized_proposals, num_boxes=self.max_num_proposals)
proposal_boxes_list.append(padded_proposals.get())
proposal_scores_list.append(
padded_proposals.get_field(fields.BoxListFields.scores))
num_proposals_list.append(tf.minimum(normalized_proposals.num_boxes(),
self.max_num_proposals))
return (tf.stack(proposal_boxes_list), tf.stack(proposal_scores_list),
tf.stack(num_proposals_list))
def _format_groundtruth_data(self, image_shape):
"""Helper function for preparing groundtruth data for target assignment.
In order to be consistent with the model.DetectionModel interface,
groundtruth boxes are specified in normalized coordinates and classes are
specified as label indices with no assumed background category. To prepare
for target assignment, we:
1) convert boxes to absolute coordinates,
2) add a background class at class index 0
Args:
image_shape: A 1-D int32 tensor of shape [4] representing the shape of the
input image batch.
Returns:
groundtruth_boxlists: A list of BoxLists containing (absolute) coordinates
of the groundtruth boxes.
groundtruth_classes_with_background_list: A list of 2-D one-hot
(or k-hot) tensors of shape [num_boxes, num_classes+1] containing the
class targets with the 0th index assumed to map to the background class.
"""
groundtruth_boxlists = [
box_list_ops.to_absolute_coordinates(
box_list.BoxList(boxes), image_shape[1], image_shape[2])
for boxes in self.groundtruth_lists(fields.BoxListFields.boxes)]
groundtruth_classes_with_background_list = [
tf.to_float(
tf.pad(one_hot_encoding, [[0, 0], [1, 0]], mode='CONSTANT'))
for one_hot_encoding in self.groundtruth_lists(
fields.BoxListFields.classes)]
return groundtruth_boxlists, groundtruth_classes_with_background_list
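  # For example, a one-hot class row [0, 1] (the second of two foreground
  # classes) is padded on the left to [0, 0, 1], where the new 0th column is
  # the (never-set) background class.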
def _sample_box_classifier_minibatch(self,
proposal_boxlist,
groundtruth_boxlist,
groundtruth_classes_with_background):
"""Samples a mini-batch of proposals to be sent to the box classifier.
Helper function for self._postprocess_rpn.
Args:
proposal_boxlist: A BoxList containing K proposal boxes in absolute
coordinates.
groundtruth_boxlist: A Boxlist containing N groundtruth object boxes in
absolute coordinates.
groundtruth_classes_with_background: A tensor with shape
`[N, self.num_classes + 1]` representing groundtruth classes. The
classes are assumed to be k-hot encoded, and include background as the
zero-th class.
Returns:
      a BoxList containing the sampled proposals.
"""
(cls_targets, cls_weights, _, _, _) = self._detector_target_assigner.assign(
proposal_boxlist, groundtruth_boxlist,
groundtruth_classes_with_background)
# Selects all boxes as candidates if none of them is selected according
# to cls_weights. This could happen as boxes within certain IOU ranges
# are ignored. If triggered, the selected boxes will still be ignored
# during loss computation.
cls_weights += tf.to_float(tf.equal(tf.reduce_sum(cls_weights), 0))
positive_indicator = tf.greater(tf.argmax(cls_targets, axis=1), 0)
sampled_indices = self._second_stage_sampler.subsample(
tf.cast(cls_weights, tf.bool),
self._second_stage_batch_size,
positive_indicator)
return box_list_ops.boolean_mask(proposal_boxlist, sampled_indices)
def _compute_second_stage_input_feature_maps(self, features_to_crop,
proposal_boxes_normalized):
"""Crops to a set of proposals from the feature map for a batch of images.
Helper function for self._postprocess_rpn. This function calls
`tf.image.crop_and_resize` to create the feature map to be passed to the
second stage box classifier for each proposal.
Args:
features_to_crop: A float32 tensor with shape
[batch_size, height, width, depth]
proposal_boxes_normalized: A float32 tensor with shape [batch_size,
num_proposals, box_code_size] containing proposal boxes in
normalized coordinates.
Returns:
A float32 tensor with shape [K, new_height, new_width, depth].
"""
def get_box_inds(proposals):
proposals_shape = proposals.get_shape().as_list()
if any(dim is None for dim in proposals_shape):
proposals_shape = tf.shape(proposals)
ones_mat = tf.ones(proposals_shape[:2], dtype=tf.int32)
multiplier = tf.expand_dims(
tf.range(start=0, limit=proposals_shape[0]), 1)
return tf.reshape(ones_mat * multiplier, [-1])
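    # For example (hypothetical sizes): for proposals of shape [2, 3, 4]
    # (batch_size=2, 3 proposals per image), get_box_inds returns
    # [0, 0, 0, 1, 1, 1], mapping each flattened proposal to its image index.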
cropped_regions = tf.image.crop_and_resize(
features_to_crop,
self._flatten_first_two_dimensions(proposal_boxes_normalized),
get_box_inds(proposal_boxes_normalized),
(self._initial_crop_size, self._initial_crop_size))
return slim.max_pool2d(
cropped_regions,
[self._maxpool_kernel_size, self._maxpool_kernel_size],
stride=self._maxpool_stride)
def _postprocess_box_classifier(self,
refined_box_encodings,
class_predictions_with_background,
proposal_boxes,
num_proposals,
image_shape,
mask_predictions=None,
mask_threshold=0.5):
"""Converts predictions from the second stage box classifier to detections.
Args:
refined_box_encodings: a 3-D tensor with shape
[total_num_padded_proposals, num_classes, 4] representing predicted
(final) refined box encodings.
class_predictions_with_background: a 3-D tensor with shape
[total_num_padded_proposals, num_classes + 1] containing class
predictions (logits) for each of the proposals. Note that this tensor
*includes* background class predictions (at class index 0).
proposal_boxes: [batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes.
num_proposals: A Tensor of type `int32`. A 1-D tensor of shape [batch]
representing the number of proposals predicted for each image in
the batch.
image_shape: a 1-D tensor representing the input image shape.
mask_predictions: (optional) a 4-D tensor with shape
[total_num_padded_proposals, num_classes, mask_height, mask_width]
containing instance mask predictions.
      mask_threshold: a scalar threshold; mask values greater than or equal to
        it are set to 1, and values below it are set to 0.
Returns:
A dictionary containing:
        `detection_boxes`: [batch, max_detections, 4]
`detection_scores`: [batch, max_detections]
`detection_classes`: [batch, max_detections]
`num_detections`: [batch]
`detection_masks`:
(optional) [batch, max_detections, mask_height, mask_width]
"""
refined_box_encodings_batch = tf.reshape(refined_box_encodings,
[-1, self.max_num_proposals,
self.num_classes,
self._box_coder.code_size])
class_predictions_with_background_batch = tf.reshape(
class_predictions_with_background,
[-1, self.max_num_proposals, self.num_classes + 1]
)
refined_decoded_boxes_batch = self._batch_decode_refined_boxes(
refined_box_encodings_batch, proposal_boxes)
class_predictions_with_background_batch = (
self._second_stage_score_conversion_fn(
class_predictions_with_background_batch))
class_predictions_batch = tf.reshape(
tf.slice(class_predictions_with_background_batch,
[0, 0, 1], [-1, -1, -1]),
[-1, self.max_num_proposals, self.num_classes])
clip_window = tf.to_float(tf.stack([0, 0, image_shape[1], image_shape[2]]))
mask_predictions_batch = None
if mask_predictions is not None:
mask_height = mask_predictions.shape[2].value
mask_width = mask_predictions.shape[3].value
mask_predictions_batch = tf.reshape(
mask_predictions, [-1, self.max_num_proposals,
self.num_classes, mask_height, mask_width])
detections = self._second_stage_nms_fn(
refined_decoded_boxes_batch,
class_predictions_batch,
clip_window=clip_window,
change_coordinate_frame=True,
num_valid_boxes=num_proposals,
masks=mask_predictions_batch)
if mask_predictions is not None:
detections['detection_masks'] = tf.to_float(
tf.greater_equal(detections['detection_masks'], mask_threshold))
return detections
def _batch_decode_refined_boxes(self, refined_box_encodings, proposal_boxes):
"""Decode tensor of refined box encodings.
Args:
refined_box_encodings: a 3-D tensor with shape
[batch_size, max_num_proposals, num_classes, self._box_coder.code_size]
representing predicted (final) refined box encodings.
proposal_boxes: [batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes.
Returns:
refined_box_predictions: a [batch_size, max_num_proposals, num_classes, 4]
float tensor representing (padded) refined bounding box predictions
(for each image in batch, proposal and class).
"""
tiled_proposal_boxes = tf.tile(
tf.expand_dims(proposal_boxes, 2), [1, 1, self.num_classes, 1])
tiled_proposals_boxlist = box_list.BoxList(
tf.reshape(tiled_proposal_boxes, [-1, 4]))
decoded_boxes = self._box_coder.decode(
tf.reshape(refined_box_encodings, [-1, self._box_coder.code_size]),
tiled_proposals_boxlist)
return tf.reshape(decoded_boxes.get(),
[-1, self.max_num_proposals, self.num_classes, 4])
def loss(self, prediction_dict, scope=None):
"""Compute scalar loss tensors given prediction tensors.
If first_stage_only=True, only RPN related losses are computed (i.e.,
`rpn_localization_loss` and `rpn_objectness_loss`). Otherwise all
losses are computed.
Args:
prediction_dict: a dictionary holding prediction tensors (see the
        documentation for the predict method). If first_stage_only=True, we
expect prediction_dict to contain `rpn_box_encodings`,
`rpn_objectness_predictions_with_background`, `rpn_features_to_crop`,
`image_shape`, and `anchors` fields. Otherwise we expect
prediction_dict to additionally contain `refined_box_encodings`,
`class_predictions_with_background`, `num_proposals`, and
`proposal_boxes` fields.
scope: Optional scope name.
Returns:
a dictionary mapping loss keys (`first_stage_localization_loss`,
      `first_stage_objectness_loss`, `second_stage_localization_loss`,
      `second_stage_classification_loss`) to scalar tensors representing
corresponding loss values.
"""
with tf.name_scope(scope, 'Loss', prediction_dict.values()):
(groundtruth_boxlists, groundtruth_classes_with_background_list
) = self._format_groundtruth_data(prediction_dict['image_shape'])
loss_dict = self._loss_rpn(
prediction_dict['rpn_box_encodings'],
prediction_dict['rpn_objectness_predictions_with_background'],
prediction_dict['anchors'],
groundtruth_boxlists,
groundtruth_classes_with_background_list)
if not self._first_stage_only:
loss_dict.update(
self._loss_box_classifier(
prediction_dict['refined_box_encodings'],
prediction_dict['class_predictions_with_background'],
prediction_dict['proposal_boxes'],
prediction_dict['num_proposals'],
groundtruth_boxlists,
groundtruth_classes_with_background_list))
return loss_dict
def _loss_rpn(self,
rpn_box_encodings,
rpn_objectness_predictions_with_background,
anchors,
groundtruth_boxlists,
groundtruth_classes_with_background_list):
"""Computes scalar RPN loss tensors.
Uses self._proposal_target_assigner to obtain regression and classification
targets for the first stage RPN, samples a "minibatch" of anchors to
participate in the loss computation, and returns the RPN losses.
Args:
      rpn_box_encodings: A 3-D float tensor of shape
[batch_size, num_anchors, self._box_coder.code_size] containing
predicted proposal box encodings.
      rpn_objectness_predictions_with_background: A 3-D float tensor of shape
[batch_size, num_anchors, 2] containing objectness predictions
(logits) for each of the anchors with 0 corresponding to background
and 1 corresponding to object.
anchors: A 2-D tensor of shape [num_anchors, 4] representing anchors
for the first stage RPN. Note that `num_anchors` can differ depending
on whether the model is created in training or inference mode.
groundtruth_boxlists: A list of BoxLists containing coordinates of the
groundtruth boxes.
groundtruth_classes_with_background_list: A list of 2-D one-hot
(or k-hot) tensors of shape [num_boxes, num_classes+1] containing the
class targets with the 0th index assumed to map to the background class.
Returns:
a dictionary mapping loss keys (`first_stage_localization_loss`,
`first_stage_objectness_loss`) to scalar tensors representing
corresponding loss values.
"""
with tf.name_scope('RPNLoss'):
(batch_cls_targets, batch_cls_weights, batch_reg_targets,
batch_reg_weights, _) = target_assigner.batch_assign_targets(
self._proposal_target_assigner, box_list.BoxList(anchors),
groundtruth_boxlists, len(groundtruth_boxlists)*[None])
batch_cls_targets = tf.squeeze(batch_cls_targets, axis=2)
def _minibatch_subsample_fn(inputs):
cls_targets, cls_weights = inputs
return self._first_stage_sampler.subsample(
tf.cast(cls_weights, tf.bool),
self._first_stage_minibatch_size, tf.cast(cls_targets, tf.bool))
batch_sampled_indices = tf.to_float(tf.map_fn(
_minibatch_subsample_fn,
[batch_cls_targets, batch_cls_weights],
dtype=tf.bool,
parallel_iterations=self._parallel_iterations,
back_prop=True))
# Normalize by number of examples in sampled minibatch
normalizer = tf.reduce_sum(batch_sampled_indices, axis=1)
batch_one_hot_targets = tf.one_hot(
tf.to_int32(batch_cls_targets), depth=2)
sampled_reg_indices = tf.multiply(batch_sampled_indices,
batch_reg_weights)
localization_losses = self._first_stage_localization_loss(
rpn_box_encodings, batch_reg_targets, weights=sampled_reg_indices)
objectness_losses = self._first_stage_objectness_loss(
rpn_objectness_predictions_with_background,
batch_one_hot_targets, weights=batch_sampled_indices)
localization_loss = tf.reduce_mean(
tf.reduce_sum(localization_losses, axis=1) / normalizer)
objectness_loss = tf.reduce_mean(
tf.reduce_sum(objectness_losses, axis=1) / normalizer)
loss_dict = {
'first_stage_localization_loss':
self._first_stage_loc_loss_weight * localization_loss,
'first_stage_objectness_loss':
self._first_stage_obj_loss_weight * objectness_loss,
}
return loss_dict
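# Illustrative sketch (not part of the model): the normalization above sums
# each image's per-anchor losses over the sampled anchors and divides by the
# number of sampled anchors before averaging over the batch. For example:
#
#   losses = tf.constant([[0.4, 0.9, 0.2, 0.7]])   # per-anchor losses
#   sampled = tf.constant([[1., 0., 1., 0.]])      # sampled minibatch mask
#   normalizer = tf.reduce_sum(sampled, axis=1)    # [2.]
#   loss = tf.reduce_mean(tf.reduce_sum(losses * sampled, axis=1) / normalizer)
#   # loss == (0.4 + 0.2) / 2 == 0.3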
def _loss_box_classifier(self,
refined_box_encodings,
class_predictions_with_background,
proposal_boxes,
num_proposals,
groundtruth_boxlists,
groundtruth_classes_with_background_list):
"""Computes scalar box classifier loss tensors.
Uses self._detector_target_assigner to obtain regression and classification
targets for the second stage box classifier, optionally performs
hard mining, and returns losses. All losses are computed independently
for each image and then averaged across the batch.
This function assumes that the proposal boxes in the "padded" regions are
actually zero (and thus should not be matched to).
Args:
refined_box_encodings: a 3-D tensor with shape
[total_num_proposals, num_classes, box_coder.code_size] representing
predicted (final) refined box encodings.
class_predictions_with_background: a 2-D tensor with shape
[total_num_proposals, num_classes + 1] containing class
predictions (logits) for each of the proposals. Note that this tensor
*includes* background class predictions (at class index 0).
proposal_boxes: [batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes.
num_proposals: A Tensor of type `int32`. A 1-D tensor of shape [batch]
representing the number of proposals predicted for each image in
the batch.
groundtruth_boxlists: a list of BoxLists containing coordinates of the
groundtruth boxes.
groundtruth_classes_with_background_list: a list of 2-D one-hot
(or k-hot) tensors of shape [num_boxes, num_classes + 1] containing the
class targets with the 0th index assumed to map to the background class.
Returns:
a dictionary mapping loss keys ('second_stage_localization_loss',
'second_stage_classification_loss') to scalar tensors representing
corresponding loss values.
"""
with tf.name_scope('BoxClassifierLoss'):
paddings_indicator = self._padded_batched_proposals_indicator(
num_proposals, self.max_num_proposals)
proposal_boxlists = [
box_list.BoxList(proposal_boxes_single_image)
for proposal_boxes_single_image in tf.unstack(proposal_boxes)]
batch_size = len(proposal_boxlists)
num_proposals_or_one = tf.to_float(tf.expand_dims(
tf.maximum(num_proposals, tf.ones_like(num_proposals)), 1))
normalizer = tf.tile(num_proposals_or_one,
[1, self.max_num_proposals]) * batch_size
(batch_cls_targets_with_background, batch_cls_weights, batch_reg_targets,
batch_reg_weights, _) = target_assigner.batch_assign_targets(
self._detector_target_assigner, proposal_boxlists,
groundtruth_boxlists, groundtruth_classes_with_background_list)
# We only predict refined location encodings for the non-background
# classes, but we pad these here to make them compatible with the class
# predictions.
flat_cls_targets_with_background = tf.reshape(
batch_cls_targets_with_background,
[batch_size * self.max_num_proposals, -1])
refined_box_encodings_with_background = tf.pad(
refined_box_encodings, [[0, 0], [1, 0], [0, 0]])
refined_box_encodings_masked_by_class_targets = tf.boolean_mask(
refined_box_encodings_with_background,
tf.greater(flat_cls_targets_with_background, 0))
reshaped_refined_box_encodings = tf.reshape(
refined_box_encodings_masked_by_class_targets,
[batch_size, -1, 4])
second_stage_loc_losses = self._second_stage_localization_loss(
reshaped_refined_box_encodings,
batch_reg_targets, weights=batch_reg_weights) / normalizer
second_stage_cls_losses = self._second_stage_classification_loss(
class_predictions_with_background,
batch_cls_targets_with_background,
weights=batch_cls_weights) / normalizer
second_stage_loc_loss = tf.reduce_sum(
tf.boolean_mask(second_stage_loc_losses, paddings_indicator))
second_stage_cls_loss = tf.reduce_sum(
tf.boolean_mask(second_stage_cls_losses, paddings_indicator))
if self._hard_example_miner:
(second_stage_loc_loss, second_stage_cls_loss
) = self._unpad_proposals_and_apply_hard_mining(
proposal_boxlists, second_stage_loc_losses,
second_stage_cls_losses, num_proposals)
loss_dict = {
'second_stage_localization_loss':
(self._second_stage_loc_loss_weight * second_stage_loc_loss),
'second_stage_classification_loss':
(self._second_stage_cls_loss_weight * second_stage_cls_loss),
}
return loss_dict
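# Illustrative sketch (not part of the model): the pad/boolean_mask pair
# above selects, for each proposal, the box encoding of its target class.
# With num_classes=2 and code_size=4:
#
#   encodings = tf.zeros([1, 2, 4])                       # [N, classes, 4]
#   padded = tf.pad(encodings, [[0, 0], [1, 0], [0, 0]])  # [N, classes+1, 4]
#   targets = tf.constant([[0., 1., 0.]])                 # class 1 matched
#   selected = tf.boolean_mask(padded, tf.greater(targets, 0))  # shape [1, 4]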
def _padded_batched_proposals_indicator(self,
num_proposals,
max_num_proposals):
"""Creates indicator matrix of non-pad elements of padded batch proposals.
Args:
num_proposals: Tensor of type tf.int32 with shape [batch_size].
max_num_proposals: Maximum number of proposals per image (integer).
Returns:
A Tensor of type tf.bool with shape [batch_size, max_num_proposals].
"""
batch_size = tf.size(num_proposals)
tiled_num_proposals = tf.tile(
tf.expand_dims(num_proposals, 1), [1, max_num_proposals])
tiled_proposal_index = tf.tile(
tf.expand_dims(tf.range(max_num_proposals), 0), [batch_size, 1])
return tf.greater(tiled_num_proposals, tiled_proposal_index)
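# Illustrative sketch: with num_proposals = [3, 2] and max_num_proposals = 4,
# the returned indicator marks the non-padded proposal slots per image:
#
#   [[True, True, True, False],
#    [True, True, False, False]]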
def _unpad_proposals_and_apply_hard_mining(self,
proposal_boxlists,
second_stage_loc_losses,
second_stage_cls_losses,
num_proposals):
"""Unpads proposals and applies hard mining.
Args:
proposal_boxlists: A list of `batch_size` BoxLists, each containing
`self.max_num_proposals` decoded proposal bounding boxes for one
image.
second_stage_loc_losses: A Tensor of type `float32`. A tensor of shape
`[batch_size, self.max_num_proposals]` representing per-proposal
second stage localization loss values.
second_stage_cls_losses: A Tensor of type `float32`. A tensor of shape
`[batch_size, self.max_num_proposals]` representing per-proposal
second stage classification loss values.
num_proposals: A Tensor of type `int32`. A 1-D tensor of shape [batch]
representing the number of proposals predicted for each image in
the batch.
Returns:
second_stage_loc_loss: A scalar float32 tensor representing the second
stage localization loss.
second_stage_cls_loss: A scalar float32 tensor representing the second
stage classification loss.
"""
for (proposal_boxlist, single_image_loc_loss, single_image_cls_loss,
single_image_num_proposals) in zip(
proposal_boxlists,
tf.unstack(second_stage_loc_losses),
tf.unstack(second_stage_cls_losses),
tf.unstack(num_proposals)):
proposal_boxlist = box_list.BoxList(
tf.slice(proposal_boxlist.get(),
[0, 0], [single_image_num_proposals, -1]))
single_image_loc_loss = tf.slice(single_image_loc_loss,
[0], [single_image_num_proposals])
single_image_cls_loss = tf.slice(single_image_cls_loss,
[0], [single_image_num_proposals])
return self._hard_example_miner(
location_losses=tf.expand_dims(single_image_loc_loss, 0),
cls_losses=tf.expand_dims(single_image_cls_loss, 0),
decoded_boxlist_list=[proposal_boxlist])
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
"""Returns callable for loading a checkpoint into the tensorflow graph.
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a detection checkpoint
(with compatible variable names) or to restore from a classification
checkpoint for initialization prior to training. Note that when
from_detection_checkpoint=True, the current implementation only
supports restoration from an (exactly) identical model (with exception
of the num_classes parameter).
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
"""
if not from_detection_checkpoint:
return self._feature_extractor.restore_from_classification_checkpoint_fn(
checkpoint_path,
self.first_stage_feature_extractor_scope,
self.second_stage_feature_extractor_scope)
variables_to_restore = tf.global_variables()
variables_to_restore.append(slim.get_or_create_global_step())
# Only load feature extractor variables to be consistent with loading from
# a classification checkpoint.
first_stage_variables = tf.contrib.framework.filter_variables(
variables_to_restore,
include_patterns=[self.first_stage_feature_extractor_scope,
self.second_stage_feature_extractor_scope])
saver = tf.train.Saver(first_stage_variables)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
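# Example usage (sketch; the checkpoint path is hypothetical): the returned
# callable is typically invoked once per session after variable
# initialization, e.g.:
#
#   init_fn = model.restore_fn('/tmp/classification.ckpt',
#                              from_detection_checkpoint=False)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     init_fn(sess)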
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.meta_architectures.faster_rcnn_meta_arch."""
import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib
class FasterRCNNMetaArchTest(
faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase):
def test_postprocess_second_stage_only_inference_mode_with_masks(self):
model = self._build_model(
is_training=False, first_stage_only=False, second_stage_batch_size=6)
batch_size = 2
total_num_padded_proposals = batch_size * model.max_num_proposals
proposal_boxes = tf.constant(
[[[1, 1, 2, 3],
[0, 0, 1, 1],
[.5, .5, .6, .6],
4*[0], 4*[0], 4*[0], 4*[0], 4*[0]],
[[2, 3, 6, 8],
[1, 2, 5, 3],
4*[0], 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]]], dtype=tf.float32)
num_proposals = tf.constant([3, 2], dtype=tf.int32)
refined_box_encodings = tf.zeros(
[total_num_padded_proposals, model.num_classes, 4], dtype=tf.float32)
class_predictions_with_background = tf.ones(
[total_num_padded_proposals, model.num_classes+1], dtype=tf.float32)
image_shape = tf.constant([batch_size, 36, 48, 3], dtype=tf.int32)
mask_height = 2
mask_width = 2
mask_predictions = .6 * tf.ones(
[total_num_padded_proposals, model.num_classes,
mask_height, mask_width], dtype=tf.float32)
exp_detection_masks = [[[[1, 1], [1, 1]],
[[1, 1], [1, 1]],
[[1, 1], [1, 1]],
[[1, 1], [1, 1]],
[[1, 1], [1, 1]]],
[[[1, 1], [1, 1]],
[[1, 1], [1, 1]],
[[1, 1], [1, 1]],
[[1, 1], [1, 1]],
[[0, 0], [0, 0]]]]
detections = model.postprocess({
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background': class_predictions_with_background,
'num_proposals': num_proposals,
'proposal_boxes': proposal_boxes,
'image_shape': image_shape,
'mask_predictions': mask_predictions
})
with self.test_session() as sess:
detections_out = sess.run(detections)
self.assertAllEqual(detections_out['detection_boxes'].shape, [2, 5, 4])
self.assertAllClose(detections_out['detection_scores'],
[[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
self.assertAllClose(detections_out['detection_classes'],
[[0, 0, 0, 1, 1], [0, 0, 1, 1, 0]])
self.assertAllClose(detections_out['num_detections'], [5, 4])
self.assertAllClose(detections_out['detection_masks'],
exp_detection_masks)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.meta_architectures.faster_rcnn_meta_arch."""
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from object_detection.anchor_generators import grid_anchor_generator
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import post_processing_builder
from object_detection.core import losses
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.protos import box_predictor_pb2
from object_detection.protos import hyperparams_pb2
from object_detection.protos import post_processing_pb2
slim = tf.contrib.slim
BOX_CODE_SIZE = 4
class FakeFasterRCNNFeatureExtractor(
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
"""Fake feature extracture to use in tests."""
def __init__(self):
super(FakeFasterRCNNFeatureExtractor, self).__init__(
is_training=False,
first_stage_features_stride=32,
reuse_weights=None,
weight_decay=0.0)
def preprocess(self, resized_inputs):
return tf.identity(resized_inputs)
def _extract_proposal_features(self, preprocessed_inputs, scope):
with tf.variable_scope('mock_model'):
return 0 * slim.conv2d(preprocessed_inputs,
num_outputs=3, kernel_size=1, scope='layer1')
def _extract_box_classifier_features(self, proposal_feature_maps, scope):
with tf.variable_scope('mock_model'):
return 0 * slim.conv2d(proposal_feature_maps,
num_outputs=3, kernel_size=1, scope='layer2')
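# Note: multiplying the conv outputs by zero above keeps the layer variables
# in the graph (so the checkpoint save/restore tests have something to
# exercise) while guaranteeing deterministic, all-zero features.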
class FasterRCNNMetaArchTestBase(tf.test.TestCase):
"""Base class to test Faster R-CNN and R-FCN meta architectures."""
def _build_arg_scope_with_hyperparams(self,
hyperparams_text_proto,
is_training):
hyperparams = hyperparams_pb2.Hyperparams()
text_format.Merge(hyperparams_text_proto, hyperparams)
return hyperparams_builder.build(hyperparams, is_training=is_training)
def _get_second_stage_box_predictor_text_proto(self):
box_predictor_text_proto = """
mask_rcnn_box_predictor {
fc_hyperparams {
op: FC
activation: NONE
regularizer {
l2_regularizer {
weight: 0.0005
}
}
initializer {
variance_scaling_initializer {
factor: 1.0
uniform: true
mode: FAN_AVG
}
}
}
}
"""
return box_predictor_text_proto
def _get_second_stage_box_predictor(self, num_classes, is_training):
box_predictor_proto = box_predictor_pb2.BoxPredictor()
text_format.Merge(self._get_second_stage_box_predictor_text_proto(),
box_predictor_proto)
return box_predictor_builder.build(
hyperparams_builder.build,
box_predictor_proto,
num_classes=num_classes,
is_training=is_training)
def _get_model(self, box_predictor, **common_kwargs):
return faster_rcnn_meta_arch.FasterRCNNMetaArch(
initial_crop_size=3,
maxpool_kernel_size=1,
maxpool_stride=1,
second_stage_mask_rcnn_box_predictor=box_predictor,
**common_kwargs)
def _build_model(self,
is_training,
first_stage_only,
second_stage_batch_size,
first_stage_max_proposals=8,
num_classes=2,
hard_mining=False):
def image_resizer_fn(image):
return tf.identity(image)
# anchors in this test are designed so that a subset of anchors are inside
# the image and a subset of anchors are outside.
first_stage_anchor_scales = (0.001, 0.005, 0.1)
first_stage_anchor_aspect_ratios = (0.5, 1.0, 2.0)
first_stage_anchor_strides = (1, 1)
first_stage_anchor_generator = grid_anchor_generator.GridAnchorGenerator(
first_stage_anchor_scales,
first_stage_anchor_aspect_ratios,
anchor_stride=first_stage_anchor_strides)
fake_feature_extractor = FakeFasterRCNNFeatureExtractor()
first_stage_box_predictor_hyperparams_text_proto = """
op: CONV
activation: RELU
regularizer {
l2_regularizer {
weight: 0.00004
}
}
initializer {
truncated_normal_initializer {
stddev: 0.03
}
}
"""
first_stage_box_predictor_arg_scope = (
self._build_arg_scope_with_hyperparams(
first_stage_box_predictor_hyperparams_text_proto, is_training))
first_stage_box_predictor_kernel_size = 3
first_stage_atrous_rate = 1
first_stage_box_predictor_depth = 512
first_stage_minibatch_size = 3
first_stage_positive_balance_fraction = .5
first_stage_nms_score_threshold = -1.0
first_stage_nms_iou_threshold = 1.0
first_stage_max_proposals = first_stage_max_proposals
first_stage_localization_loss_weight = 1.0
first_stage_objectness_loss_weight = 1.0
post_processing_text_proto = """
batch_non_max_suppression {
score_threshold: -20.0
iou_threshold: 1.0
max_detections_per_class: 5
max_total_detections: 5
}
"""
post_processing_config = post_processing_pb2.PostProcessing()
text_format.Merge(post_processing_text_proto, post_processing_config)
second_stage_non_max_suppression_fn, _ = post_processing_builder.build(
post_processing_config)
second_stage_balance_fraction = 1.0
second_stage_score_conversion_fn = tf.identity
second_stage_localization_loss_weight = 1.0
second_stage_classification_loss_weight = 1.0
hard_example_miner = None
if hard_mining:
hard_example_miner = losses.HardExampleMiner(
num_hard_examples=1,
iou_threshold=0.99,
loss_type='both',
cls_loss_weight=second_stage_classification_loss_weight,
loc_loss_weight=second_stage_localization_loss_weight,
max_negatives_per_positive=None)
common_kwargs = {
'is_training': is_training,
'num_classes': num_classes,
'image_resizer_fn': image_resizer_fn,
'feature_extractor': fake_feature_extractor,
'first_stage_only': first_stage_only,
'first_stage_anchor_generator': first_stage_anchor_generator,
'first_stage_atrous_rate': first_stage_atrous_rate,
'first_stage_box_predictor_arg_scope':
first_stage_box_predictor_arg_scope,
'first_stage_box_predictor_kernel_size':
first_stage_box_predictor_kernel_size,
'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
'first_stage_minibatch_size': first_stage_minibatch_size,
'first_stage_positive_balance_fraction':
first_stage_positive_balance_fraction,
'first_stage_nms_score_threshold': first_stage_nms_score_threshold,
'first_stage_nms_iou_threshold': first_stage_nms_iou_threshold,
'first_stage_max_proposals': first_stage_max_proposals,
'first_stage_localization_loss_weight':
first_stage_localization_loss_weight,
'first_stage_objectness_loss_weight':
first_stage_objectness_loss_weight,
'second_stage_batch_size': second_stage_batch_size,
'second_stage_balance_fraction': second_stage_balance_fraction,
'second_stage_non_max_suppression_fn':
second_stage_non_max_suppression_fn,
'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
'second_stage_localization_loss_weight':
second_stage_localization_loss_weight,
'second_stage_classification_loss_weight':
second_stage_classification_loss_weight,
'hard_example_miner': hard_example_miner}
return self._get_model(self._get_second_stage_box_predictor(
num_classes=num_classes, is_training=is_training), **common_kwargs)
def test_predict_gives_correct_shapes_in_inference_mode_first_stage_only(
self):
test_graph = tf.Graph()
with test_graph.as_default():
model = self._build_model(
is_training=False, first_stage_only=True, second_stage_batch_size=2)
batch_size = 2
height = 10
width = 12
input_image_shape = (batch_size, height, width, 3)
preprocessed_inputs = tf.placeholder(dtype=tf.float32,
shape=(batch_size, None, None, 3))
prediction_dict = model.predict(preprocessed_inputs)
# In inference mode, anchors are clipped to the image window, but not
# pruned. Since FakeFasterRCNNFeatureExtractor._extract_proposal_features
# returns a tensor with the same shape as its input, the expected number
# of anchors is height * width * the number of anchors per location
# (3 scales x 3 aspect ratios = 9).
expected_num_anchors = height * width * 3 * 3
expected_output_keys = set([
'rpn_box_predictor_features', 'rpn_features_to_crop', 'image_shape',
'rpn_box_encodings', 'rpn_objectness_predictions_with_background',
'anchors'])
expected_output_shapes = {
'rpn_box_predictor_features': (batch_size, height, width, 512),
'rpn_features_to_crop': (batch_size, height, width, 3),
'rpn_box_encodings': (batch_size, expected_num_anchors, 4),
'rpn_objectness_predictions_with_background':
(batch_size, expected_num_anchors, 2),
'anchors': (expected_num_anchors, 4)
}
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
prediction_out = sess.run(prediction_dict,
feed_dict={
preprocessed_inputs:
np.zeros(input_image_shape)
})
self.assertEqual(set(prediction_out.keys()), expected_output_keys)
self.assertAllEqual(prediction_out['image_shape'], input_image_shape)
for output_key, expected_shape in expected_output_shapes.items():
self.assertAllEqual(prediction_out[output_key].shape, expected_shape)
# Check that anchors are clipped to window.
anchors = prediction_out['anchors']
self.assertTrue(np.all(np.greater_equal(anchors, 0)))
self.assertTrue(np.all(np.less_equal(anchors[:, 0], height)))
self.assertTrue(np.all(np.less_equal(anchors[:, 1], width)))
self.assertTrue(np.all(np.less_equal(anchors[:, 2], height)))
self.assertTrue(np.all(np.less_equal(anchors[:, 3], width)))
def test_predict_gives_valid_anchors_in_training_mode_first_stage_only(self):
test_graph = tf.Graph()
with test_graph.as_default():
model = self._build_model(
is_training=True, first_stage_only=True, second_stage_batch_size=2)
batch_size = 2
height = 10
width = 12
input_image_shape = (batch_size, height, width, 3)
preprocessed_inputs = tf.placeholder(dtype=tf.float32,
shape=(batch_size, None, None, 3))
prediction_dict = model.predict(preprocessed_inputs)
expected_output_keys = set([
'rpn_box_predictor_features', 'rpn_features_to_crop', 'image_shape',
'rpn_box_encodings', 'rpn_objectness_predictions_with_background',
'anchors'])
# At training time, anchors that exceed image bounds are pruned. Thus
# the `expected_num_anchors` in the above inference mode test is now
# a strict upper bound on the number of anchors.
num_anchors_strict_upper_bound = height * width * 3 * 3
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
prediction_out = sess.run(prediction_dict,
feed_dict={
preprocessed_inputs:
np.zeros(input_image_shape)
})
self.assertEqual(set(prediction_out.keys()), expected_output_keys)
self.assertAllEqual(prediction_out['image_shape'], input_image_shape)
# Check that anchors have less than the upper bound and
# are clipped to window.
anchors = prediction_out['anchors']
self.assertTrue(len(anchors.shape) == 2 and anchors.shape[1] == 4)
num_anchors_out = anchors.shape[0]
self.assertTrue(num_anchors_out < num_anchors_strict_upper_bound)
self.assertTrue(np.all(np.greater_equal(anchors, 0)))
self.assertTrue(np.all(np.less_equal(anchors[:, 0], height)))
self.assertTrue(np.all(np.less_equal(anchors[:, 1], width)))
self.assertTrue(np.all(np.less_equal(anchors[:, 2], height)))
self.assertTrue(np.all(np.less_equal(anchors[:, 3], width)))
self.assertAllEqual(prediction_out['rpn_box_encodings'].shape,
(batch_size, num_anchors_out, 4))
self.assertAllEqual(
prediction_out['rpn_objectness_predictions_with_background'].shape,
(batch_size, num_anchors_out, 2))
def test_predict_gives_correct_shapes_in_inference_mode_both_stages(self):
test_graph = tf.Graph()
with test_graph.as_default():
model = self._build_model(
is_training=False, first_stage_only=False, second_stage_batch_size=2)
batch_size = 2
image_size = 10
image_shape = (batch_size, image_size, image_size, 3)
preprocessed_inputs = tf.zeros(image_shape, dtype=tf.float32)
result_tensor_dict = model.predict(preprocessed_inputs)
expected_num_anchors = image_size * image_size * 3 * 3
expected_shapes = {
'rpn_box_predictor_features':
(2, image_size, image_size, 512),
'rpn_features_to_crop': (2, image_size, image_size, 3),
'image_shape': (4,),
'rpn_box_encodings': (2, expected_num_anchors, 4),
'rpn_objectness_predictions_with_background':
(2, expected_num_anchors, 2),
'anchors': (expected_num_anchors, 4),
'refined_box_encodings': (2 * 8, 2, 4),
'class_predictions_with_background': (2 * 8, 2 + 1),
'num_proposals': (2,),
'proposal_boxes': (2, 8, 4),
}
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
tensor_dict_out = sess.run(result_tensor_dict)
self.assertEqual(set(tensor_dict_out.keys()),
set(expected_shapes.keys()))
for key in expected_shapes:
self.assertAllEqual(tensor_dict_out[key].shape, expected_shapes[key])
def test_predict_gives_correct_shapes_in_train_mode_both_stages(self):
test_graph = tf.Graph()
with test_graph.as_default():
model = self._build_model(
is_training=True, first_stage_only=False, second_stage_batch_size=7)
batch_size = 2
image_size = 10
image_shape = (batch_size, image_size, image_size, 3)
preprocessed_inputs = tf.zeros(image_shape, dtype=tf.float32)
groundtruth_boxes_list = [
tf.constant([[0, 0, .5, .5], [.5, .5, 1, 1]], dtype=tf.float32),
tf.constant([[0, .5, .5, 1], [.5, 0, 1, .5]], dtype=tf.float32)]
groundtruth_classes_list = [
tf.constant([[1, 0], [0, 1]], dtype=tf.float32),
tf.constant([[1, 0], [1, 0]], dtype=tf.float32)]
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
result_tensor_dict = model.predict(preprocessed_inputs)
expected_shapes = {
'rpn_box_predictor_features':
(2, image_size, image_size, 512),
'rpn_features_to_crop': (2, image_size, image_size, 3),
'image_shape': (4,),
'refined_box_encodings': (2 * 7, 2, 4),
'class_predictions_with_background': (2 * 7, 2 + 1),
'num_proposals': (2,),
'proposal_boxes': (2, 7, 4),
}
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
tensor_dict_out = sess.run(result_tensor_dict)
self.assertEqual(set(tensor_dict_out.keys()),
set(expected_shapes.keys()).union(set([
'rpn_box_encodings',
'rpn_objectness_predictions_with_background',
'anchors'])))
for key in expected_shapes:
self.assertAllEqual(tensor_dict_out[key].shape, expected_shapes[key])
anchors_shape_out = tensor_dict_out['anchors'].shape
self.assertEqual(2, len(anchors_shape_out))
self.assertEqual(4, anchors_shape_out[1])
num_anchors_out = anchors_shape_out[0]
self.assertAllEqual(tensor_dict_out['rpn_box_encodings'].shape,
(2, num_anchors_out, 4))
self.assertAllEqual(
tensor_dict_out['rpn_objectness_predictions_with_background'].shape,
(2, num_anchors_out, 2))
def test_postprocess_first_stage_only_inference_mode(self):
model = self._build_model(
is_training=False, first_stage_only=True, second_stage_batch_size=6)
batch_size = 2
anchors = tf.constant(
[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[16, 16, 32, 32]], dtype=tf.float32)
rpn_box_encodings = tf.zeros(
[batch_size, anchors.get_shape().as_list()[0],
BOX_CODE_SIZE], dtype=tf.float32)
# use different numbers for the objectness category to break ties in
# order of boxes returned by NMS
rpn_objectness_predictions_with_background = tf.constant([
[[-10, 13],
[10, -10],
[10, -11],
[-10, 12]],
[[10, -10],
[-10, 13],
[-10, 12],
[10, -11]]], dtype=tf.float32)
rpn_features_to_crop = tf.ones((batch_size, 8, 8, 10), dtype=tf.float32)
image_shape = tf.constant([batch_size, 32, 32, 3], dtype=tf.int32)
proposals = model.postprocess({
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
rpn_objectness_predictions_with_background,
'rpn_features_to_crop': rpn_features_to_crop,
'anchors': anchors,
'image_shape': image_shape})
expected_proposal_boxes = [
[[0, 0, .5, .5], [.5, .5, 1, 1], [0, .5, .5, 1], [.5, 0, 1.0, .5]]
+ 4 * [4 * [0]],
[[0, .5, .5, 1], [.5, 0, 1.0, .5], [0, 0, .5, .5], [.5, .5, 1, 1]]
+ 4 * [4 * [0]]]
expected_proposal_scores = [[1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0]]
expected_num_proposals = [4, 4]
expected_output_keys = set(['detection_boxes', 'detection_scores',
'num_detections'])
self.assertEqual(set(proposals.keys()), expected_output_keys)
with self.test_session() as sess:
proposals_out = sess.run(proposals)
self.assertAllClose(proposals_out['detection_boxes'],
expected_proposal_boxes)
self.assertAllClose(proposals_out['detection_scores'],
expected_proposal_scores)
self.assertAllEqual(proposals_out['num_detections'],
expected_num_proposals)
def test_postprocess_first_stage_only_train_mode(self):
model = self._build_model(
is_training=True, first_stage_only=True, second_stage_batch_size=2)
batch_size = 2
anchors = tf.constant(
[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[16, 16, 32, 32]], dtype=tf.float32)
rpn_box_encodings = tf.zeros(
[batch_size, anchors.get_shape().as_list()[0],
BOX_CODE_SIZE], dtype=tf.float32)
# use different numbers for the objectness category to break ties in
# order of boxes returned by NMS
rpn_objectness_predictions_with_background = tf.constant([
[[-10, 13],
[-10, 12],
[-10, 11],
[-10, 10]],
[[-10, 13],
[-10, 12],
[-10, 11],
[-10, 10]]], dtype=tf.float32)
rpn_features_to_crop = tf.ones((batch_size, 8, 8, 10), dtype=tf.float32)
image_shape = tf.constant([batch_size, 32, 32, 3], dtype=tf.int32)
groundtruth_boxes_list = [
tf.constant([[0, 0, .5, .5], [.5, .5, 1, 1]], dtype=tf.float32),
tf.constant([[0, .5, .5, 1], [.5, 0, 1, .5]], dtype=tf.float32)]
groundtruth_classes_list = [tf.constant([[1, 0], [0, 1]], dtype=tf.float32),
tf.constant([[1, 0], [1, 0]], dtype=tf.float32)]
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
proposals = model.postprocess({
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
rpn_objectness_predictions_with_background,
'rpn_features_to_crop': rpn_features_to_crop,
'anchors': anchors,
'image_shape': image_shape})
expected_proposal_boxes = [
[[0, 0, .5, .5], [.5, .5, 1, 1]], [[0, .5, .5, 1], [.5, 0, 1, .5]]]
expected_proposal_scores = [[1, 1],
[1, 1]]
expected_num_proposals = [2, 2]
expected_output_keys = set(['detection_boxes', 'detection_scores',
'num_detections'])
self.assertEqual(set(proposals.keys()), expected_output_keys)
with self.test_session() as sess:
proposals_out = sess.run(proposals)
self.assertAllClose(proposals_out['detection_boxes'],
expected_proposal_boxes)
self.assertAllClose(proposals_out['detection_scores'],
expected_proposal_scores)
self.assertAllEqual(proposals_out['num_detections'],
expected_num_proposals)
def test_postprocess_second_stage_only_inference_mode(self):
model = self._build_model(
is_training=False, first_stage_only=False, second_stage_batch_size=6)
batch_size = 2
total_num_padded_proposals = batch_size * model.max_num_proposals
proposal_boxes = tf.constant(
[[[1, 1, 2, 3],
[0, 0, 1, 1],
[.5, .5, .6, .6],
4*[0], 4*[0], 4*[0], 4*[0], 4*[0]],
[[2, 3, 6, 8],
[1, 2, 5, 3],
4*[0], 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]]], dtype=tf.float32)
num_proposals = tf.constant([3, 2], dtype=tf.int32)
refined_box_encodings = tf.zeros(
[total_num_padded_proposals, model.num_classes, 4], dtype=tf.float32)
class_predictions_with_background = tf.ones(
[total_num_padded_proposals, model.num_classes+1], dtype=tf.float32)
image_shape = tf.constant([batch_size, 36, 48, 3], dtype=tf.int32)
detections = model.postprocess({
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background': class_predictions_with_background,
'num_proposals': num_proposals,
'proposal_boxes': proposal_boxes,
'image_shape': image_shape
})
with self.test_session() as sess:
detections_out = sess.run(detections)
self.assertAllEqual(detections_out['detection_boxes'].shape, [2, 5, 4])
self.assertAllClose(detections_out['detection_scores'],
[[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
self.assertAllClose(detections_out['detection_classes'],
[[0, 0, 0, 1, 1], [0, 0, 1, 1, 0]])
self.assertAllClose(detections_out['num_detections'], [5, 4])
def test_loss_first_stage_only_mode(self):
model = self._build_model(
is_training=True, first_stage_only=True, second_stage_batch_size=6)
batch_size = 2
anchors = tf.constant(
[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[16, 16, 32, 32]], dtype=tf.float32)
rpn_box_encodings = tf.zeros(
[batch_size,
anchors.get_shape().as_list()[0],
BOX_CODE_SIZE], dtype=tf.float32)
# use different numbers for the objectness category to break ties in
# order of boxes returned by NMS
rpn_objectness_predictions_with_background = tf.constant([
[[-10, 13],
[10, -10],
[10, -11],
[-10, 12]],
[[10, -10],
[-10, 13],
[-10, 12],
[10, -11]]], dtype=tf.float32)
image_shape = tf.constant([batch_size, 32, 32, 3], dtype=tf.int32)
groundtruth_boxes_list = [
tf.constant([[0, 0, .5, .5], [.5, .5, 1, 1]], dtype=tf.float32),
tf.constant([[0, .5, .5, 1], [.5, 0, 1, .5]], dtype=tf.float32)]
groundtruth_classes_list = [tf.constant([[1, 0], [0, 1]], dtype=tf.float32),
tf.constant([[1, 0], [1, 0]], dtype=tf.float32)]
prediction_dict = {
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
rpn_objectness_predictions_with_background,
'image_shape': image_shape,
'anchors': anchors
}
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
loss_dict = model.loss(prediction_dict)
with self.test_session() as sess:
loss_dict_out = sess.run(loss_dict)
self.assertAllClose(loss_dict_out['first_stage_localization_loss'], 0)
self.assertAllClose(loss_dict_out['first_stage_objectness_loss'], 0)
self.assertTrue('second_stage_localization_loss' not in loss_dict_out)
self.assertTrue('second_stage_classification_loss' not in loss_dict_out)
def test_loss_full(self):
model = self._build_model(
is_training=True, first_stage_only=False, second_stage_batch_size=6)
batch_size = 2
anchors = tf.constant(
[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[16, 16, 32, 32]], dtype=tf.float32)
rpn_box_encodings = tf.zeros(
[batch_size,
anchors.get_shape().as_list()[0],
BOX_CODE_SIZE], dtype=tf.float32)
# use different numbers for the objectness category to break ties in
# order of boxes returned by NMS
rpn_objectness_predictions_with_background = tf.constant([
[[-10, 13],
[10, -10],
[10, -11],
[-10, 12]],
[[10, -10],
[-10, 13],
[-10, 12],
[10, -11]]], dtype=tf.float32)
image_shape = tf.constant([batch_size, 32, 32, 3], dtype=tf.int32)
num_proposals = tf.constant([6, 6], dtype=tf.int32)
proposal_boxes = tf.constant(
2 * [[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[16, 16, 32, 32],
[0, 0, 16, 16],
[0, 16, 16, 32]]], dtype=tf.float32)
refined_box_encodings = tf.zeros(
(batch_size * model.max_num_proposals,
model.num_classes,
BOX_CODE_SIZE), dtype=tf.float32)
class_predictions_with_background = tf.constant(
[[-10, 10, -10], # first image
[10, -10, -10],
[10, -10, -10],
[-10, -10, 10],
[-10, 10, -10],
[10, -10, -10],
[10, -10, -10], # second image
[-10, 10, -10],
[-10, 10, -10],
[10, -10, -10],
[10, -10, -10],
[-10, 10, -10]], dtype=tf.float32)
groundtruth_boxes_list = [
tf.constant([[0, 0, .5, .5], [.5, .5, 1, 1]], dtype=tf.float32),
tf.constant([[0, .5, .5, 1], [.5, 0, 1, .5]], dtype=tf.float32)]
groundtruth_classes_list = [tf.constant([[1, 0], [0, 1]], dtype=tf.float32),
tf.constant([[1, 0], [1, 0]], dtype=tf.float32)]
prediction_dict = {
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
rpn_objectness_predictions_with_background,
'image_shape': image_shape,
'anchors': anchors,
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background': class_predictions_with_background,
'proposal_boxes': proposal_boxes,
'num_proposals': num_proposals
}
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
loss_dict = model.loss(prediction_dict)
with self.test_session() as sess:
loss_dict_out = sess.run(loss_dict)
self.assertAllClose(loss_dict_out['first_stage_localization_loss'], 0)
self.assertAllClose(loss_dict_out['first_stage_objectness_loss'], 0)
self.assertAllClose(loss_dict_out['second_stage_localization_loss'], 0)
self.assertAllClose(loss_dict_out['second_stage_classification_loss'], 0)
def test_loss_full_zero_padded_proposals(self):
model = self._build_model(
is_training=True, first_stage_only=False, second_stage_batch_size=6)
batch_size = 1
anchors = tf.constant(
[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[16, 16, 32, 32]], dtype=tf.float32)
rpn_box_encodings = tf.zeros(
[batch_size,
anchors.get_shape().as_list()[0],
BOX_CODE_SIZE], dtype=tf.float32)
# use different numbers for the objectness category to break ties in
# order of boxes returned by NMS
rpn_objectness_predictions_with_background = tf.constant([
[[-10, 13],
[10, -10],
[10, -11],
[10, -12]],], dtype=tf.float32)
image_shape = tf.constant([batch_size, 32, 32, 3], dtype=tf.int32)
# second_stage_batch_size is 6, but here we assume that the number of
# actual proposals (not counting zero padding) is smaller (3).
num_proposals = tf.constant([3], dtype=tf.int32)
proposal_boxes = tf.constant(
[[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[0, 0, 0, 0], # begin paddings
[0, 0, 0, 0],
[0, 0, 0, 0]]], dtype=tf.float32)
refined_box_encodings = tf.zeros(
(batch_size * model.max_num_proposals,
model.num_classes,
BOX_CODE_SIZE), dtype=tf.float32)
class_predictions_with_background = tf.constant(
[[-10, 10, -10],
[10, -10, -10],
[10, -10, -10],
[0, 0, 0], # begin paddings
[0, 0, 0],
[0, 0, 0]], dtype=tf.float32)
groundtruth_boxes_list = [
tf.constant([[0, 0, .5, .5]], dtype=tf.float32)]
groundtruth_classes_list = [tf.constant([[1, 0]], dtype=tf.float32)]
prediction_dict = {
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
rpn_objectness_predictions_with_background,
'image_shape': image_shape,
'anchors': anchors,
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background': class_predictions_with_background,
'proposal_boxes': proposal_boxes,
'num_proposals': num_proposals
}
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
loss_dict = model.loss(prediction_dict)
with self.test_session() as sess:
loss_dict_out = sess.run(loss_dict)
self.assertAllClose(loss_dict_out['first_stage_localization_loss'], 0)
self.assertAllClose(loss_dict_out['first_stage_objectness_loss'], 0)
self.assertAllClose(loss_dict_out['second_stage_localization_loss'], 0)
self.assertAllClose(loss_dict_out['second_stage_classification_loss'], 0)
def test_loss_full_zero_padded_proposals_nonzero_loss_with_two_images(self):
model = self._build_model(
is_training=True, first_stage_only=False, second_stage_batch_size=6)
batch_size = 2
anchors = tf.constant(
[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[16, 16, 32, 32]], dtype=tf.float32)
rpn_box_encodings = tf.zeros(
[batch_size,
anchors.get_shape().as_list()[0],
BOX_CODE_SIZE], dtype=tf.float32)
# use different numbers for the objectness category to break ties in
# order of boxes returned by NMS
rpn_objectness_predictions_with_background = tf.constant(
[[[-10, 13],
[10, -10],
[10, -11],
[10, -12]],
[[-10, 13],
[10, -10],
[10, -11],
[10, -12]]], dtype=tf.float32)
image_shape = tf.constant([batch_size, 32, 32, 3], dtype=tf.int32)
# second_stage_batch_size is 6, but here we assume that the number of
# actual proposals (not counting zero padding) is smaller: 3 for the
# first image and 2 for the second.
num_proposals = tf.constant([3, 2], dtype=tf.int32)
proposal_boxes = tf.constant(
[[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[0, 0, 0, 0], # begin paddings
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 16, 16],
[0, 16, 16, 32],
[0, 0, 0, 0],
[0, 0, 0, 0], # begin paddings
[0, 0, 0, 0],
[0, 0, 0, 0]]], dtype=tf.float32)
refined_box_encodings = tf.zeros(
(batch_size * model.max_num_proposals,
model.num_classes,
BOX_CODE_SIZE), dtype=tf.float32)
class_predictions_with_background = tf.constant(
[[-10, 10, -10], # first image
[10, -10, -10],
[10, -10, -10],
[0, 0, 0], # begin paddings
[0, 0, 0],
[0, 0, 0],
[-10, -10, 10], # second image
[10, -10, -10],
[0, 0, 0], # begin paddings
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],], dtype=tf.float32)
# The first groundtruth box is 4/5 of the anchor size in both directions,
# incurring a localization loss of:
# 2 * SmoothL1(5 * log(4/5)) / num_proposals
#   = 2 * (abs(5 * log(4/5)) - .5) / 3
# (the linear branch of SmoothL1 applies, since abs(5 * log(4/5)) > 1).
# The second groundtruth box is identical to the prediction and thus
# experiences zero loss.
# Total average loss is (abs(5 * log(4/5)) - .5) / 3 ~= 0.2052.
groundtruth_boxes_list = [
tf.constant([[0.05, 0.05, 0.45, 0.45]], dtype=tf.float32),
tf.constant([[0.0, 0.0, 0.5, 0.5]], dtype=tf.float32)]
groundtruth_classes_list = [tf.constant([[1, 0]], dtype=tf.float32),
tf.constant([[0, 1]], dtype=tf.float32)]
exp_loc_loss = (-5 * np.log(.8) - 0.5) / 3.0
prediction_dict = {
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
rpn_objectness_predictions_with_background,
'image_shape': image_shape,
'anchors': anchors,
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background': class_predictions_with_background,
'proposal_boxes': proposal_boxes,
'num_proposals': num_proposals
}
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
loss_dict = model.loss(prediction_dict)
with self.test_session() as sess:
loss_dict_out = sess.run(loss_dict)
self.assertAllClose(loss_dict_out['first_stage_localization_loss'],
exp_loc_loss)
self.assertAllClose(loss_dict_out['first_stage_objectness_loss'], 0)
self.assertAllClose(loss_dict_out['second_stage_localization_loss'],
exp_loc_loss)
self.assertAllClose(loss_dict_out['second_stage_classification_loss'], 0)
def test_loss_with_hard_mining(self):
model = self._build_model(is_training=True,
first_stage_only=False,
second_stage_batch_size=None,
first_stage_max_proposals=6,
hard_mining=True)
batch_size = 1
anchors = tf.constant(
[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[16, 16, 32, 32]], dtype=tf.float32)
rpn_box_encodings = tf.zeros(
[batch_size,
anchors.get_shape().as_list()[0],
BOX_CODE_SIZE], dtype=tf.float32)
# use different numbers for the objectness category to break ties in
# order of boxes returned by NMS
rpn_objectness_predictions_with_background = tf.constant(
[[[-10, 13],
[-10, 12],
[10, -11],
[10, -12]]], dtype=tf.float32)
image_shape = tf.constant([batch_size, 32, 32, 3], dtype=tf.int32)
# first_stage_max_proposals is 6, but here we assume that the number of
# actual proposals (not counting zero padding) is smaller (3).
num_proposals = tf.constant([3], dtype=tf.int32)
proposal_boxes = tf.constant(
[[[0, 0, 16, 16],
[0, 16, 16, 32],
[16, 0, 32, 16],
[0, 0, 0, 0], # begin paddings
[0, 0, 0, 0],
[0, 0, 0, 0]]], dtype=tf.float32)
refined_box_encodings = tf.zeros(
(batch_size * model.max_num_proposals,
model.num_classes,
BOX_CODE_SIZE), dtype=tf.float32)
class_predictions_with_background = tf.constant(
[[-10, 10, -10], # first image
[-10, -10, 10],
[10, -10, -10],
[0, 0, 0], # begin paddings
[0, 0, 0],
[0, 0, 0]], dtype=tf.float32)
# The first groundtruth box is 4/5 of the anchor size in both directions,
# incurring a loss of:
# 2 * SmoothL1(5 * log(4/5)) / num_proposals
#   = 2 * (abs(5 * log(4/5)) - .5) / 3 ~= 0.4105
# (linear branch of SmoothL1, since abs(5 * log(4/5)) > 1).
# The second groundtruth box is 46/50 of the anchor size in both directions,
# incurring a loss of:
# 2 * SmoothL1(5 * log(46/50)) / num_proposals
#   = 2 * (.5 * (5 * log(.92))^2) / 3 ~= 0.0579
# (quadratic branch, since abs(5 * log(.92)) < 1).
# Since the first groundtruth box experiences greater loss, and we have
# set num_hard_examples=1 in the HardMiner, the final localization loss
# corresponds to that of the first groundtruth box.
groundtruth_boxes_list = [
tf.constant([[0.05, 0.05, 0.45, 0.45],
[0.02, 0.52, 0.48, 0.98],], dtype=tf.float32)]
groundtruth_classes_list = [tf.constant([[1, 0], [0, 1]], dtype=tf.float32)]
exp_loc_loss = 2 * (-5 * np.log(.8) - 0.5) / 3.0
prediction_dict = {
'rpn_box_encodings': rpn_box_encodings,
'rpn_objectness_predictions_with_background':
rpn_objectness_predictions_with_background,
'image_shape': image_shape,
'anchors': anchors,
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background': class_predictions_with_background,
'proposal_boxes': proposal_boxes,
'num_proposals': num_proposals
}
model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
loss_dict = model.loss(prediction_dict)
with self.test_session() as sess:
loss_dict_out = sess.run(loss_dict)
self.assertAllClose(loss_dict_out['second_stage_localization_loss'],
exp_loc_loss)
self.assertAllClose(loss_dict_out['second_stage_classification_loss'], 0)
def test_restore_fn_classification(self):
# Define mock tensorflow classification graph and save variables.
test_graph_classification = tf.Graph()
with test_graph_classification.as_default():
image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
with tf.variable_scope('mock_model'):
net = slim.conv2d(image, num_outputs=3, kernel_size=1, scope='layer1')
slim.conv2d(net, num_outputs=3, kernel_size=1, scope='layer2')
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
save_path = self.get_temp_dir()
with self.test_session() as sess:
sess.run(init_op)
saved_model_path = saver.save(sess, save_path)
# Create tensorflow detection graph and load variables from
# classification checkpoint.
test_graph_detection = tf.Graph()
with test_graph_detection.as_default():
model = self._build_model(
is_training=False, first_stage_only=False, second_stage_batch_size=6)
inputs_shape = (2, 20, 20, 3)
inputs = tf.to_float(tf.random_uniform(
inputs_shape, minval=0, maxval=255, dtype=tf.int32))
preprocessed_inputs = model.preprocess(inputs)
prediction_dict = model.predict(preprocessed_inputs)
model.postprocess(prediction_dict)
restore_fn = model.restore_fn(saved_model_path,
from_detection_checkpoint=False)
with self.test_session() as sess:
restore_fn(sess)
def test_restore_fn_detection(self):
# Define first detection graph and save variables.
test_graph_detection1 = tf.Graph()
with test_graph_detection1.as_default():
model = self._build_model(
is_training=False, first_stage_only=False, second_stage_batch_size=6)
inputs_shape = (2, 20, 20, 3)
inputs = tf.to_float(tf.random_uniform(
inputs_shape, minval=0, maxval=255, dtype=tf.int32))
preprocessed_inputs = model.preprocess(inputs)
prediction_dict = model.predict(preprocessed_inputs)
model.postprocess(prediction_dict)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
save_path = self.get_temp_dir()
with self.test_session() as sess:
sess.run(init_op)
saved_model_path = saver.save(sess, save_path)
# Define second detection graph and restore variables.
test_graph_detection2 = tf.Graph()
with test_graph_detection2.as_default():
model2 = self._build_model(is_training=False, first_stage_only=False,
second_stage_batch_size=6, num_classes=42)
inputs_shape2 = (2, 20, 20, 3)
inputs2 = tf.to_float(tf.random_uniform(
inputs_shape2, minval=0, maxval=255, dtype=tf.int32))
preprocessed_inputs2 = model2.preprocess(inputs2)
prediction_dict2 = model2.predict(preprocessed_inputs2)
model2.postprocess(prediction_dict2)
restore_fn = model2.restore_fn(saved_model_path,
from_detection_checkpoint=True)
with self.test_session() as sess:
restore_fn(sess)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn(model2.first_stage_feature_extractor_scope, var.name)
self.assertNotIn(model2.second_stage_feature_extractor_scope,
var.name)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""R-FCN meta-architecture definition.
R-FCN: Dai, Jifeng, et al. "R-FCN: Object Detection via Region-based
Fully Convolutional Networks." arXiv preprint arXiv:1605.06409 (2016).
The R-FCN meta architecture is similar to Faster R-CNN and only differs in the
second stage. Hence this class inherits FasterRCNNMetaArch and overrides only
the `_predict_second_stage` method.
Similar to Faster R-CNN we allow for two modes: first_stage_only=True and
first_stage_only=False. In the former setting, all of the user facing methods
(e.g., predict, postprocess, loss) can be used as if the model consisted
only of the RPN, returning class agnostic proposals (these can be thought of as
approximate detections with no associated class information). In the latter
setting, proposals are computed, then passed through a second stage
"box classifier" to yield (multi-class) detections.
Implementations of R-FCN models must define a new FasterRCNNFeatureExtractor and
override three methods: `preprocess`, `_extract_proposal_features` (the first
stage of the model), and `_extract_box_classifier_features` (the second stage of
the model). Optionally, the `restore_fn` method can be overridden. See tests
for an example.
See notes in the documentation of Faster R-CNN meta-architecture as they all
apply here.
"""
import tensorflow as tf
from object_detection.core import box_predictor
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import ops
class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
"""R-FCN Meta-architecture definition."""
def __init__(self,
is_training,
num_classes,
image_resizer_fn,
feature_extractor,
first_stage_only,
first_stage_anchor_generator,
first_stage_atrous_rate,
first_stage_box_predictor_arg_scope,
first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth,
first_stage_minibatch_size,
first_stage_positive_balance_fraction,
first_stage_nms_score_threshold,
first_stage_nms_iou_threshold,
first_stage_max_proposals,
first_stage_localization_loss_weight,
first_stage_objectness_loss_weight,
second_stage_rfcn_box_predictor,
second_stage_batch_size,
second_stage_balance_fraction,
second_stage_non_max_suppression_fn,
second_stage_score_conversion_fn,
second_stage_localization_loss_weight,
second_stage_classification_loss_weight,
hard_example_miner,
parallel_iterations=16):
"""RFCNMetaArch Constructor.
Args:
is_training: A boolean indicating whether the training version of the
computation graph should be constructed.
num_classes: Number of classes. Note that num_classes *does not*
include the background category, so if groundtruth labels take values
in {0, 1, ..., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range over {0, ..., K}).
image_resizer_fn: A callable for image resizing. This callable always
takes a rank-3 image tensor (corresponding to a single image) and
returns a rank-3 image tensor, possibly with new spatial dimensions.
See builders/image_resizer_builder.py.
feature_extractor: A FasterRCNNFeatureExtractor object.
first_stage_only: Whether to construct only the Region Proposal Network
(RPN) part of the model.
first_stage_anchor_generator: An anchor_generator.AnchorGenerator object
(note that currently we only support
grid_anchor_generator.GridAnchorGenerator objects)
first_stage_atrous_rate: A single integer indicating the atrous rate for
the single convolution op which is applied to the `rpn_features_to_crop`
tensor to obtain a tensor to be used for box prediction. Some feature
extractors optionally allow for producing feature maps computed at
denser resolutions. The atrous rate is used to compensate for the
denser feature maps by using an effectively larger receptive field.
(This should typically be set to 1).
first_stage_box_predictor_arg_scope: Slim arg_scope for conv2d,
separable_conv2d and fully_connected ops for the RPN box predictor.
first_stage_box_predictor_kernel_size: Kernel size to use for the
convolution op just prior to RPN box predictions.
first_stage_box_predictor_depth: Output depth for the convolution op
just prior to RPN box predictions.
first_stage_minibatch_size: The "batch size" to use for computing the
objectness and location loss of the region proposal network. This
"batch size" refers to the number of anchors selected as contributing
to the loss function for any given image within the image batch and is
only called "batch_size" due to terminology from the Faster R-CNN paper.
first_stage_positive_balance_fraction: Fraction of positive examples
per image for the RPN. The recommended value for Faster RCNN is 0.5.
first_stage_nms_score_threshold: Score threshold for non max suppression
for the Region Proposal Network (RPN). This value is expected to be in
[0, 1] as it is applied directly after a softmax transformation. The
recommended value for Faster R-CNN is 0.
first_stage_nms_iou_threshold: The Intersection Over Union (IOU) threshold
for performing Non-Max Suppression (NMS) on the boxes predicted by the
Region Proposal Network (RPN).
first_stage_max_proposals: Maximum number of boxes to retain after
performing Non-Max Suppression (NMS) on the boxes predicted by the
Region Proposal Network (RPN).
first_stage_localization_loss_weight: A float
first_stage_objectness_loss_weight: A float
second_stage_rfcn_box_predictor: RFCN box predictor to use for
second stage.
second_stage_batch_size: The batch size used for computing the
classification and refined location loss of the box classifier. This
"batch size" refers to the number of proposals selected as contributing
to the loss function for any given image within the image batch and is
only called "batch_size" due to terminology from the Faster R-CNN paper.
second_stage_balance_fraction: Fraction of positive examples to use
per image for the box classifier. The recommended value for Faster RCNN
is 0.25.
second_stage_non_max_suppression_fn: batch_multiclass_non_max_suppression
callable that takes `boxes`, `scores`, optional `clip_window` and
optional (kwarg) `mask` inputs (with all other inputs already set)
and returns a dictionary containing tensors with keys:
`detection_boxes`, `detection_scores`, `detection_classes`,
`num_detections`, and (optionally) `detection_masks`. See
`post_processing.batch_multiclass_non_max_suppression` for the type and
shape of these tensors.
second_stage_score_conversion_fn: Callable elementwise nonlinearity
(that takes tensors as inputs and returns tensors). This is usually
used to convert logits to probabilities.
second_stage_localization_loss_weight: A float
second_stage_classification_loss_weight: A float
hard_example_miner: A losses.HardExampleMiner object (can be None).
parallel_iterations: (Optional) The number of iterations allowed to run
in parallel for calls to tf.map_fn.
Raises:
ValueError: If `second_stage_batch_size` > `first_stage_max_proposals`
ValueError: If first_stage_anchor_generator is not of type
grid_anchor_generator.GridAnchorGenerator.
"""
super(RFCNMetaArch, self).__init__(
is_training,
num_classes,
image_resizer_fn,
feature_extractor,
first_stage_only,
first_stage_anchor_generator,
first_stage_atrous_rate,
first_stage_box_predictor_arg_scope,
first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth,
first_stage_minibatch_size,
first_stage_positive_balance_fraction,
first_stage_nms_score_threshold,
first_stage_nms_iou_threshold,
first_stage_max_proposals,
first_stage_localization_loss_weight,
first_stage_objectness_loss_weight,
None, # initial_crop_size is not used in R-FCN
None,  # maxpool_kernel_size is not used in R-FCN.
None,  # maxpool_stride is not used in R-FCN.
None, # fully_connected_box_predictor is not used in R-FCN.
second_stage_batch_size,
second_stage_balance_fraction,
second_stage_non_max_suppression_fn,
second_stage_score_conversion_fn,
second_stage_localization_loss_weight,
second_stage_classification_loss_weight,
hard_example_miner,
parallel_iterations)
self._rfcn_box_predictor = second_stage_rfcn_box_predictor
def _predict_second_stage(self, rpn_box_encodings,
rpn_objectness_predictions_with_background,
rpn_features,
anchors,
image_shape):
"""Predicts the output tensors from 2nd stage of FasterRCNN.
Args:
rpn_box_encodings: 3-D float tensor of shape
[batch_size, num_valid_anchors, self._box_coder.code_size] containing
predicted boxes.
rpn_objectness_predictions_with_background: 3-D float tensor of shape
[batch_size, num_valid_anchors, 2] containing class
predictions (logits) for each of the anchors. Note that this
tensor *includes* background class predictions (at class index 0).
rpn_features: A 4-D float32 tensor with shape
[batch_size, height, width, depth] representing image features from the
RPN.
anchors: 2-D float tensor of shape
[num_anchors, self._box_coder.code_size].
image_shape: A 1-D int32 tensor of size [4] containing the image shape.
Returns:
prediction_dict: a dictionary holding "raw" prediction tensors:
1) refined_box_encodings: a 3-D tensor with shape
[total_num_proposals, num_classes, 4] representing predicted
(final) refined box encodings, where
total_num_proposals=batch_size*self._max_num_proposals
2) class_predictions_with_background: a 3-D tensor with shape
[total_num_proposals, num_classes + 1] containing class
predictions (logits) for each of the anchors, where
total_num_proposals=batch_size*self._max_num_proposals.
Note that this tensor *includes* background class predictions
(at class index 0).
3) num_proposals: An int32 tensor of shape [batch_size] representing the
number of proposals generated by the RPN. `num_proposals` allows us
to keep track of which entries are to be treated as zero paddings and
which are not since we always pad the number of proposals to be
`self.max_num_proposals` for each image.
4) proposal_boxes: A float32 tensor of shape
[batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes (in absolute coordinates).
"""
proposal_boxes_normalized, _, num_proposals = self._postprocess_rpn(
rpn_box_encodings, rpn_objectness_predictions_with_background,
anchors, image_shape)
box_classifier_features = (
self._feature_extractor.extract_box_classifier_features(
rpn_features,
scope=self.second_stage_feature_extractor_scope))
box_predictions = self._rfcn_box_predictor.predict(
box_classifier_features,
num_predictions_per_location=1,
scope=self.second_stage_box_predictor_scope,
proposal_boxes=proposal_boxes_normalized)
refined_box_encodings = tf.squeeze(
box_predictions[box_predictor.BOX_ENCODINGS], axis=1)
class_predictions_with_background = tf.squeeze(
box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND],
axis=1)
absolute_proposal_boxes = ops.normalized_to_image_coordinates(
proposal_boxes_normalized, image_shape,
parallel_iterations=self._parallel_iterations)
prediction_dict = {
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background':
class_predictions_with_background,
'num_proposals': num_proposals,
'proposal_boxes': absolute_proposal_boxes,
}
return prediction_dict
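# A minimal sketch (illustration only, not part of the API) of the
# normalized-to-absolute conversion performed above by
# `ops.normalized_to_image_coordinates`, assuming the
# [y_min, x_min, y_max, x_max] box convention used throughout this codebase.
# The helper name below is hypothetical:
def _example_normalized_to_absolute(box_normalized, image_height, image_width):
  """Scales a single normalized box to absolute pixel coordinates."""
  y_min, x_min, y_max, x_max = box_normalized
  return [y_min * image_height, x_min * image_width,
          y_max * image_height, x_max * image_width]
# e.g. _example_normalized_to_absolute([0., 0., .5, .5], 600, 800)
# returns [0., 0., 300., 400.].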
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.meta_architectures.rfcn_meta_arch."""
import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib
from object_detection.meta_architectures import rfcn_meta_arch
class RFCNMetaArchTest(
faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase):
def _get_second_stage_box_predictor_text_proto(self):
box_predictor_text_proto = """
rfcn_box_predictor {
conv_hyperparams {
op: CONV
activation: NONE
regularizer {
l2_regularizer {
weight: 0.0005
}
}
initializer {
variance_scaling_initializer {
factor: 1.0
uniform: true
mode: FAN_AVG
}
}
}
}
"""
return box_predictor_text_proto
def _get_model(self, box_predictor, **common_kwargs):
return rfcn_meta_arch.RFCNMetaArch(
second_stage_rfcn_box_predictor=box_predictor, **common_kwargs)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SSD Meta-architecture definition.
General tensorflow implementation of convolutional Multibox/SSD detection
models.
"""
from abc import abstractmethod
import re
import tensorflow as tf
from object_detection.core import box_coder as bcoder
from object_detection.core import box_list
from object_detection.core import box_predictor as bpredictor
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import variables_helper
slim = tf.contrib.slim
class SSDFeatureExtractor(object):
"""SSD Feature Extractor definition."""
def __init__(self,
depth_multiplier,
min_depth,
conv_hyperparams,
reuse_weights=None):
self._depth_multiplier = depth_multiplier
self._min_depth = min_depth
self._conv_hyperparams = conv_hyperparams
self._reuse_weights = reuse_weights
@abstractmethod
def preprocess(self, resized_inputs):
"""Preprocesses images for feature extraction (minus image resizing).
Args:
resized_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
Returns:
preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
"""
pass
@abstractmethod
def extract_features(self, preprocessed_inputs):
"""Extracts features from preprocessed inputs.
This function is responsible for extracting feature maps from preprocessed
images.
Args:
preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
Returns:
feature_maps: a list of tensors where the ith tensor has shape
[batch, height_i, width_i, depth_i]
"""
pass
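# A minimal sketch of a concrete SSDFeatureExtractor, assuming a hypothetical
# single-conv "backbone" for illustration; real extractors (see the models/
# directory) typically return one feature map per spatial scale:
class _ExampleSSDFeatureExtractor(SSDFeatureExtractor):
  """Illustration only: simple preprocessing and a single feature map."""

  def preprocess(self, resized_inputs):
    # Map pixel values from [0, 255] onto [-1, 1].
    return (2.0 / 255.0) * resized_inputs - 1.0

  def extract_features(self, preprocessed_inputs):
    # A single feature map; real extractors return several scales.
    return [slim.conv2d(preprocessed_inputs, 32, [3, 3], scope='ExampleMap')]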
class SSDMetaArch(model.DetectionModel):
"""SSD Meta-architecture definition."""
def __init__(self,
is_training,
anchor_generator,
box_predictor,
box_coder,
feature_extractor,
matcher,
region_similarity_calculator,
image_resizer_fn,
non_max_suppression_fn,
score_conversion_fn,
classification_loss,
localization_loss,
classification_loss_weight,
localization_loss_weight,
normalize_loss_by_num_matches,
hard_example_miner,
add_summaries=True):
"""SSDMetaArch Constructor.
TODO: group NMS parameters + score converter into
a class and loss parameters into a class and write config protos for
postprocessing and losses.
Args:
is_training: A boolean indicating whether the training version of the
computation graph should be constructed.
anchor_generator: an anchor_generator.AnchorGenerator object.
box_predictor: a box_predictor.BoxPredictor object.
box_coder: a box_coder.BoxCoder object.
feature_extractor: a SSDFeatureExtractor object.
matcher: a matcher.Matcher object.
region_similarity_calculator: a
region_similarity_calculator.RegionSimilarityCalculator object.
image_resizer_fn: a callable for image resizing. This callable always
takes a rank-3 image tensor (corresponding to a single image) and
returns a rank-3 image tensor, possibly with new spatial dimensions.
See builders/image_resizer_builder.py.
non_max_suppression_fn: batch_multiclass_non_max_suppression
callable that takes `boxes`, `scores` and optional `clip_window`
inputs (with all other inputs already set) and returns a dictionary
holding tensors with keys: `detection_boxes`, `detection_scores`,
`detection_classes` and `num_detections`. See `post_processing.
batch_multiclass_non_max_suppression` for the type and shape of these
tensors.
score_conversion_fn: callable elementwise nonlinearity (that takes tensors
as inputs and returns tensors). This is usually used to convert logits
to probabilities.
classification_loss: an object_detection.core.losses.Loss object.
localization_loss: an object_detection.core.losses.Loss object.
classification_loss_weight: float
localization_loss_weight: float
normalize_loss_by_num_matches: boolean
hard_example_miner: a losses.HardExampleMiner object (can be None)
add_summaries: boolean (default: True) controlling whether summary ops
should be added to tensorflow graph.
"""
super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes)
self._is_training = is_training
# Needed for fine-tuning from classification checkpoints whose
# variables do not have the feature extractor scope.
self._extract_features_scope = 'FeatureExtractor'
self._anchor_generator = anchor_generator
self._box_predictor = box_predictor
self._box_coder = box_coder
self._feature_extractor = feature_extractor
self._matcher = matcher
self._region_similarity_calculator = region_similarity_calculator
# TODO: handle agnostic mode and positive/negative class weights
unmatched_cls_target = tf.constant([1] + self.num_classes * [0], tf.float32)
self._target_assigner = target_assigner.TargetAssigner(
self._region_similarity_calculator,
self._matcher,
self._box_coder,
positive_class_weight=1.0,
negative_class_weight=1.0,
unmatched_cls_target=unmatched_cls_target)
self._classification_loss = classification_loss
self._localization_loss = localization_loss
self._classification_loss_weight = classification_loss_weight
self._localization_loss_weight = localization_loss_weight
self._normalize_loss_by_num_matches = normalize_loss_by_num_matches
self._hard_example_miner = hard_example_miner
self._image_resizer_fn = image_resizer_fn
self._non_max_suppression_fn = non_max_suppression_fn
self._score_conversion_fn = score_conversion_fn
self._anchors = None
self._add_summaries = add_summaries
@property
def anchors(self):
if not self._anchors:
raise RuntimeError('anchors have not been constructed yet!')
if not isinstance(self._anchors, box_list.BoxList):
raise RuntimeError('anchors should be a BoxList object, but is not.')
return self._anchors
def preprocess(self, inputs):
"""Feature-extractor specific preprocessing.
See base class.
Args:
inputs: a [batch, height_in, width_in, channels] float tensor representing
a batch of images with values between 0 and 255.0.
Returns:
preprocessed_inputs: a [batch, height_out, width_out, channels] float
tensor representing a batch of images.
Raises:
ValueError: if inputs tensor does not have type tf.float32
"""
if inputs.dtype is not tf.float32:
raise ValueError('`preprocess` expects a tf.float32 tensor')
with tf.name_scope('Preprocessor'):
# TODO: revisit whether to always use batch size as the number of
# parallel iterations vs allow for dynamic batching.
resized_inputs = tf.map_fn(self._image_resizer_fn,
elems=inputs,
dtype=tf.float32)
return self._feature_extractor.preprocess(resized_inputs)
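# Note (illustration): `tf.map_fn` applies the rank-3 `image_resizer_fn` to
# each image in the batch independently and stacks the per-image results
# back into a single [batch, height_out, width_out, channels] tensor.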
def predict(self, preprocessed_inputs):
"""Predicts unpostprocessed tensors from input tensor.
This function takes an input batch of images and runs it through the forward
pass of the network to yield unpostprocessed predictions.
A side effect of calling the predict method is that self._anchors is
populated with a box_list.BoxList of anchors. These anchors must be
constructed before the postprocess or loss functions can be called.
Args:
preprocessed_inputs: a [batch, height, width, channels] image tensor.
Returns:
prediction_dict: a dictionary holding "raw" prediction tensors:
1) box_encodings: 4-D float tensor of shape [batch_size, num_anchors,
box_code_dimension] containing predicted boxes.
2) class_predictions_with_background: 3-D float tensor of shape
[batch_size, num_anchors, num_classes+1] containing class predictions
(logits) for each of the anchors. Note that this tensor *includes*
background class predictions (at class index 0).
3) feature_maps: a list of tensors where the ith tensor has shape
[batch, height_i, width_i, depth_i].
"""
with tf.variable_scope(None, self._extract_features_scope,
[preprocessed_inputs]):
feature_maps = self._feature_extractor.extract_features(
preprocessed_inputs)
feature_map_spatial_dims = self._get_feature_map_spatial_dims(feature_maps)
self._anchors = self._anchor_generator.generate(feature_map_spatial_dims)
(box_encodings, class_predictions_with_background
) = self._add_box_predictions_to_feature_maps(feature_maps)
predictions_dict = {
'box_encodings': box_encodings,
'class_predictions_with_background': class_predictions_with_background,
'feature_maps': feature_maps
}
return predictions_dict
def _add_box_predictions_to_feature_maps(self, feature_maps):
"""Adds box predictors to each feature map and returns concatenated results.
Args:
feature_maps: a list of tensors where the ith tensor has shape
[batch, height_i, width_i, depth_i]
Returns:
box_encodings: 4-D float tensor of shape [batch_size, num_anchors,
box_code_dimension] containing predicted boxes.
class_predictions_with_background: 3-D float tensor of shape
[batch_size, num_anchors, num_classes+1] containing class predictions
(logits) for each of the anchors. Note that this tensor *includes*
background class predictions (at class index 0).
Raises:
RuntimeError: if the number of feature maps extracted via the
extract_features method does not match the length of the
num_anchors_per_location list produced by the anchor generator
passed to the constructor.
RuntimeError: if box_encodings from the box_predictor does not have
shape of the form [batch_size, num_anchors, 1, code_size].
"""
num_anchors_per_location_list = (
self._anchor_generator.num_anchors_per_location())
if len(feature_maps) != len(num_anchors_per_location_list):
raise RuntimeError('the number of feature maps must match the length of '
'the anchor generator num_anchors_per_location() list.')
box_encodings_list = []
cls_predictions_with_background_list = []
for idx, (feature_map, num_anchors_per_location
) in enumerate(zip(feature_maps, num_anchors_per_location_list)):
box_predictor_scope = 'BoxPredictor_{}'.format(idx)
box_predictions = self._box_predictor.predict(feature_map,
num_anchors_per_location,
box_predictor_scope)
box_encodings = box_predictions[bpredictor.BOX_ENCODINGS]
cls_predictions_with_background = box_predictions[
bpredictor.CLASS_PREDICTIONS_WITH_BACKGROUND]
box_encodings_shape = box_encodings.get_shape().as_list()
if len(box_encodings_shape) != 4 or box_encodings_shape[2] != 1:
raise RuntimeError('box_encodings from the box_predictor must be of '
'shape `[batch_size, num_anchors, 1, code_size]`; '
'actual shape', box_encodings_shape)
box_encodings = tf.squeeze(box_encodings, axis=2)
box_encodings_list.append(box_encodings)
cls_predictions_with_background_list.append(
cls_predictions_with_background)
num_predictions = sum(
[tf.shape(box_encodings)[1] for box_encodings in box_encodings_list])
num_anchors = self.anchors.num_boxes()
anchors_assert = tf.assert_equal(num_anchors, num_predictions, [
'Mismatch: number of anchors vs number of predictions', num_anchors,
num_predictions
])
with tf.control_dependencies([anchors_assert]):
box_encodings = tf.concat(box_encodings_list, 1)
class_predictions_with_background = tf.concat(
cls_predictions_with_background_list, 1)
return box_encodings, class_predictions_with_background
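# Illustration with made-up sizes: if two feature maps contribute encodings
# of shape [batch, 1200, 1, 4] and [batch, 300, 1, 4], the squeeze and
# concat above yield box_encodings of shape [batch, 1500, 4], and the
# assert checks that 1500 equals self.anchors.num_boxes().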
def _get_feature_map_spatial_dims(self, feature_maps):
"""Return list of spatial dimensions for each feature map in a list.
Args:
feature_maps: a list of tensors where the ith tensor has shape
[batch, height_i, width_i, depth_i].
Returns:
a list of pairs (height, width) for each feature map in feature_maps
"""
feature_map_shapes = [
feature_map.get_shape().as_list() for feature_map in feature_maps
]
return [(shape[1], shape[2]) for shape in feature_map_shapes]
def postprocess(self, prediction_dict):
"""Converts prediction tensors to final detections.
This function converts raw prediction tensors to final detection results by
slicing off the background class, decoding box predictions and applying
non-max suppression and clipping to the image window.
See base class for output format conventions. Note also that by default,
scores are to be interpreted as logits, but if a score_conversion_fn is
used, then scores are remapped (and may thus have a different
interpretation).
Args:
prediction_dict: a dictionary holding prediction tensors with
1) box_encodings: 4-D float tensor of shape [batch_size, num_anchors,
box_code_dimension] containing predicted boxes.
2) class_predictions_with_background: 3-D float tensor of shape
[batch_size, num_anchors, num_classes+1] containing class predictions
(logits) for each of the anchors. Note that this tensor *includes*
background class predictions.
Returns:
detections: a dictionary containing the following fields
detection_boxes: [batch, max_detections, 4]
detection_scores: [batch, max_detections]
detection_classes: [batch, max_detections]
num_detections: [batch]
Raises:
ValueError: if prediction_dict does not contain `box_encodings` or
`class_predictions_with_background` fields.
"""
if ('box_encodings' not in prediction_dict or
'class_predictions_with_background' not in prediction_dict):
raise ValueError('prediction_dict does not contain expected entries.')
with tf.name_scope('Postprocessor'):
box_encodings = prediction_dict['box_encodings']
class_predictions = prediction_dict['class_predictions_with_background']
detection_boxes = bcoder.batch_decode(box_encodings, self._box_coder,
self.anchors)
detection_boxes = tf.expand_dims(detection_boxes, axis=2)
class_predictions_without_background = tf.slice(class_predictions,
[0, 0, 1],
[-1, -1, -1])
detection_scores = self._score_conversion_fn(
class_predictions_without_background)
clip_window = tf.constant([0, 0, 1, 1], tf.float32)
detections = self._non_max_suppression_fn(detection_boxes,
detection_scores,
clip_window=clip_window)
return detections
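# Illustration (not executed): the `tf.slice` above drops the background
# column (class index 0). For class_predictions of shape
# [batch=1, anchors=1, classes+1] holding [[[2.0, 0.3, 0.7]]], the slice
# with begin [0, 0, 1] and size [-1, -1, -1] keeps [[[0.3, 0.7]]], i.e. only
# the num_classes foreground logits, which `self._score_conversion_fn` then
# maps to scores.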
def loss(self, prediction_dict, scope=None):
"""Compute scalar loss tensors with respect to provided groundtruth.
Calling this function requires that groundtruth tensors have been
provided via the provide_groundtruth function.
Args:
prediction_dict: a dictionary holding prediction tensors with
1) box_encodings: 4-D float tensor of shape [batch_size, num_anchors,
box_code_dimension] containing predicted boxes.
2) class_predictions_with_background: 3-D float tensor of shape
[batch_size, num_anchors, num_classes+1] containing class predictions
(logits) for each of the anchors. Note that this tensor *includes*
background class predictions.
scope: Optional scope name.
Returns:
a dictionary mapping loss keys (`localization_loss` and
`classification_loss`) to scalar tensors representing corresponding loss
values.
"""
with tf.name_scope(scope, 'Loss', prediction_dict.values()):
(batch_cls_targets, batch_cls_weights, batch_reg_targets,
batch_reg_weights, match_list) = self._assign_targets(
self.groundtruth_lists(fields.BoxListFields.boxes),
self.groundtruth_lists(fields.BoxListFields.classes))
if self._add_summaries:
self._summarize_input(
self.groundtruth_lists(fields.BoxListFields.boxes), match_list)
num_matches = tf.stack(
[match.num_matched_columns() for match in match_list])
location_losses = self._localization_loss(
prediction_dict['box_encodings'],
batch_reg_targets,
weights=batch_reg_weights)
cls_losses = self._classification_loss(
prediction_dict['class_predictions_with_background'],
batch_cls_targets,
weights=batch_cls_weights)
# Optionally apply hard mining on top of loss values
localization_loss = tf.reduce_sum(location_losses)
classification_loss = tf.reduce_sum(cls_losses)
if self._hard_example_miner:
(localization_loss, classification_loss) = self._apply_hard_mining(
location_losses, cls_losses, prediction_dict, match_list)
if self._add_summaries:
self._hard_example_miner.summarize()
# Optionally normalize by number of positive matches
normalizer = tf.constant(1.0, dtype=tf.float32)
if self._normalize_loss_by_num_matches:
normalizer = tf.maximum(tf.to_float(tf.reduce_sum(num_matches)), 1.0)
loss_dict = {
'localization_loss': (self._localization_loss_weight / normalizer) *
localization_loss,
'classification_loss': (self._classification_loss_weight /
normalizer) * classification_loss
}
return loss_dict
def _assign_targets(self, groundtruth_boxes_list, groundtruth_classes_list):
"""Assign groundtruth targets.
Adds a background class to each one-hot encoding of groundtruth classes
and uses target assigner to obtain regression and classification targets.
Args:
groundtruth_boxes_list: a list of 2-D tensors of shape [num_boxes, 4]
containing coordinates of the groundtruth boxes.
Groundtruth boxes are provided in [y_min, x_min, y_max, x_max]
format and assumed to be normalized and clipped
relative to the image window with y_min <= y_max and x_min <= x_max.
groundtruth_classes_list: a list of 2-D one-hot (or k-hot) tensors of
shape [num_boxes, num_classes] containing the class targets with the 0th
index assumed to map to the first non-background class.
Returns:
batch_cls_targets: a tensor with shape [batch_size, num_anchors,
num_classes],
batch_cls_weights: a tensor with shape [batch_size, num_anchors],
batch_reg_targets: a tensor with shape [batch_size, num_anchors,
box_code_dimension]
batch_reg_weights: a tensor with shape [batch_size, num_anchors],
match_list: a list of matcher.Match objects encoding the match between
anchors and groundtruth boxes for each image of the batch,
with rows of the Match objects corresponding to groundtruth boxes
and columns corresponding to anchors.
"""
groundtruth_boxlists = [
box_list.BoxList(boxes) for boxes in groundtruth_boxes_list
]
groundtruth_classes_with_background_list = [
tf.pad(one_hot_encoding, [[0, 0], [1, 0]], mode='CONSTANT')
for one_hot_encoding in groundtruth_classes_list
]
return target_assigner.batch_assign_targets(
self._target_assigner, self.anchors, groundtruth_boxlists,
groundtruth_classes_with_background_list)
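# Illustration (not executed): `tf.pad` with paddings [[0, 0], [1, 0]]
# prepends a zero background column to each one-hot row, e.g. a [1, 3]
# encoding [[0, 1, 0]] becomes the [1, 4] encoding [[0, 0, 1, 0]], so that
# class index 0 is reserved for background as the target assigner expects.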
def _summarize_input(self, groundtruth_boxes_list, match_list):
"""Creates tensorflow summaries for the input boxes and anchors.
This function creates four summaries corresponding to the average
number (over images in a batch) of (1) groundtruth boxes, (2) anchors
marked as positive, (3) anchors marked as negative, and (4) anchors marked
as ignored.
Args:
groundtruth_boxes_list: a list of 2-D tensors of shape [num_boxes, 4]
containing corners of the groundtruth boxes.
match_list: a list of matcher.Match objects encoding the match between
anchors and groundtruth boxes for each image of the batch,
with rows of the Match objects corresponding to groundtruth boxes
and columns corresponding to anchors.
"""
num_boxes_per_image = tf.stack(
[tf.shape(x)[0] for x in groundtruth_boxes_list])
pos_anchors_per_image = tf.stack(
[match.num_matched_columns() for match in match_list])
neg_anchors_per_image = tf.stack(
[match.num_unmatched_columns() for match in match_list])
ignored_anchors_per_image = tf.stack(
[match.num_ignored_columns() for match in match_list])
tf.summary.scalar('Input/AvgNumGroundtruthBoxesPerImage',
tf.reduce_mean(tf.to_float(num_boxes_per_image)))
tf.summary.scalar('Input/AvgNumPositiveAnchorsPerImage',
tf.reduce_mean(tf.to_float(pos_anchors_per_image)))
tf.summary.scalar('Input/AvgNumNegativeAnchorsPerImage',
tf.reduce_mean(tf.to_float(neg_anchors_per_image)))
tf.summary.scalar('Input/AvgNumIgnoredAnchorsPerImage',
tf.reduce_mean(tf.to_float(ignored_anchors_per_image)))
def _apply_hard_mining(self, location_losses, cls_losses, prediction_dict,
match_list):
"""Applies hard mining to anchorwise losses.
Args:
location_losses: Float tensor of shape [batch_size, num_anchors]
representing anchorwise location losses.
cls_losses: Float tensor of shape [batch_size, num_anchors]
representing anchorwise classification losses.
prediction_dict: a dictionary holding prediction tensors with
1) box_encodings: 4-D float tensor of shape [batch_size, num_anchors,
box_code_dimension] containing predicted boxes.
2) class_predictions_with_background: 3-D float tensor of shape
[batch_size, num_anchors, num_classes+1] containing class predictions
(logits) for each of the anchors. Note that this tensor *includes*
background class predictions.
match_list: a list of matcher.Match objects encoding the match between
anchors and groundtruth boxes for each image of the batch,
with rows of the Match objects corresponding to groundtruth boxes
and columns corresponding to anchors.
Returns:
mined_location_loss: a float scalar with sum of localization losses from
selected hard examples.
mined_cls_loss: a float scalar with sum of classification losses from
selected hard examples.
"""
class_pred_shape = [-1, self.anchors.num_boxes_static(), self.num_classes]
class_predictions = tf.reshape(
tf.slice(prediction_dict['class_predictions_with_background'],
[0, 0, 1], class_pred_shape), class_pred_shape)
decoded_boxes = bcoder.batch_decode(prediction_dict['box_encodings'],
self._box_coder, self.anchors)
decoded_box_tensors_list = tf.unstack(decoded_boxes)
class_prediction_list = tf.unstack(class_predictions)
decoded_boxlist_list = []
for box_location, box_score in zip(decoded_box_tensors_list,
class_prediction_list):
decoded_boxlist = box_list.BoxList(box_location)
decoded_boxlist.add_field('scores', box_score)
decoded_boxlist_list.append(decoded_boxlist)
return self._hard_example_miner(
location_losses=location_losses,
cls_losses=cls_losses,
decoded_boxlist_list=decoded_boxlist_list,
match_list=match_list)
def restore_fn(self, checkpoint_path, from_detection_checkpoint=True):
"""Return callable for loading a checkpoint into the tensorflow graph.
Args:
checkpoint_path: path to checkpoint to restore.
from_detection_checkpoint: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
"""
variables_to_restore = {}
for variable in tf.all_variables():
if variable.op.name.startswith(self._extract_features_scope):
var_name = variable.op.name
if not from_detection_checkpoint:
var_name = (
re.split('^' + self._extract_features_scope + '/', var_name)[-1])
variables_to_restore[var_name] = variable
# TODO: Load variables selectively using scopes.
variables_to_restore = (
variables_helper.get_variables_available_in_checkpoint(
variables_to_restore, checkpoint_path))
saver = tf.train.Saver(variables_to_restore)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
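# A minimal usage sketch (assuming `model` is a constructed SSDMetaArch and
# `checkpoint_path` names a compatible detection checkpoint); the returned
# callable simply wraps `saver.restore`:
#
#   restore = model.restore_fn(checkpoint_path,
#                              from_detection_checkpoint=True)
#   with tf.Session() as sess:
#     restore(sess)  # feature extractor variables are now restored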
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.meta_architectures.ssd_meta_arch."""
import functools
import numpy as np
import tensorflow as tf
from tensorflow.python.training import saver as tf_saver
from object_detection.core import anchor_generator
from object_detection.core import box_list
from object_detection.core import losses
from object_detection.core import post_processing
from object_detection.core import region_similarity_calculator as sim_calc
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.utils import test_utils
slim = tf.contrib.slim
class FakeSSDFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
def __init__(self):
super(FakeSSDFeatureExtractor, self).__init__(
depth_multiplier=0, min_depth=0, conv_hyperparams=None)
def preprocess(self, resized_inputs):
return tf.identity(resized_inputs)
def extract_features(self, preprocessed_inputs):
with tf.variable_scope('mock_model'):
features = slim.conv2d(inputs=preprocessed_inputs, num_outputs=32,
kernel_size=[1, 1], scope='layer1')
return [features]
class MockAnchorGenerator2x2(anchor_generator.AnchorGenerator):
"""Sets up a simple 2x2 anchor grid on the unit square."""
def name_scope(self):
return 'MockAnchorGenerator'
def num_anchors_per_location(self):
return [1]
def _generate(self, feature_map_shape_list):
return box_list.BoxList(
tf.constant([[0, 0, .5, .5],
[0, .5, .5, 1],
[.5, 0, 1, .5],
[.5, .5, 1, 1]], tf.float32))
class SsdMetaArchTest(tf.test.TestCase):
def setUp(self):
"""Set up mock SSD model.
Here we set up a simple mock SSD model that will always predict 4
detections that happen to always be exactly the anchors that are set up
in the above MockAnchorGenerator. Because we let max_total_size=5,
we will also always end up with an extra padded row in the detection
results.
"""
is_training = False
self._num_classes = 1
mock_anchor_generator = MockAnchorGenerator2x2()
mock_box_predictor = test_utils.MockBoxPredictor(
is_training, self._num_classes)
mock_box_coder = test_utils.MockBoxCoder()
fake_feature_extractor = FakeSSDFeatureExtractor()
mock_matcher = test_utils.MockMatcher()
region_similarity_calculator = sim_calc.IouSimilarity()
def image_resizer_fn(image):
return tf.identity(image)
classification_loss = losses.WeightedSigmoidClassificationLoss(
anchorwise_output=True)
localization_loss = losses.WeightedSmoothL1LocalizationLoss(
anchorwise_output=True)
non_max_suppression_fn = functools.partial(
post_processing.batch_multiclass_non_max_suppression,
score_thresh=-20.0,
iou_thresh=1.0,
max_size_per_class=5,
max_total_size=5)
classification_loss_weight = 1.0
localization_loss_weight = 1.0
normalize_loss_by_num_matches = False
# This hard example miner is expected to be a no-op.
hard_example_miner = losses.HardExampleMiner(
num_hard_examples=None,
iou_threshold=1.0)
self._num_anchors = 4
self._code_size = 4
self._model = ssd_meta_arch.SSDMetaArch(
is_training, mock_anchor_generator, mock_box_predictor, mock_box_coder,
fake_feature_extractor, mock_matcher, region_similarity_calculator,
image_resizer_fn, non_max_suppression_fn, tf.identity,
classification_loss, localization_loss, classification_loss_weight,
localization_loss_weight, normalize_loss_by_num_matches,
hard_example_miner)
def test_predict_results_have_correct_keys_and_shapes(self):
batch_size = 3
preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
dtype=tf.float32)
prediction_dict = self._model.predict(preprocessed_input)
self.assertTrue('box_encodings' in prediction_dict)
self.assertTrue('class_predictions_with_background' in prediction_dict)
self.assertTrue('feature_maps' in prediction_dict)
expected_box_encodings_shape_out = (
batch_size, self._num_anchors, self._code_size)
expected_class_predictions_with_background_shape_out = (
batch_size, self._num_anchors, self._num_classes+1)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
prediction_out = sess.run(prediction_dict)
self.assertAllEqual(prediction_out['box_encodings'].shape,
expected_box_encodings_shape_out)
self.assertAllEqual(
prediction_out['class_predictions_with_background'].shape,
expected_class_predictions_with_background_shape_out)
def test_postprocess_results_are_correct(self):
batch_size = 2
preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
dtype=tf.float32)
prediction_dict = self._model.predict(preprocessed_input)
detections = self._model.postprocess(prediction_dict)
expected_boxes = np.array([[[0, 0, .5, .5],
[0, .5, .5, 1],
[.5, 0, 1, .5],
[.5, .5, 1, 1],
[0, 0, 0, 0]],
[[0, 0, .5, .5],
[0, .5, .5, 1],
[.5, 0, 1, .5],
[.5, .5, 1, 1],
[0, 0, 0, 0]]])
expected_scores = np.array([[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
expected_classes = np.array([[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
expected_num_detections = np.array([4, 4])
self.assertTrue('detection_boxes' in detections)
self.assertTrue('detection_scores' in detections)
self.assertTrue('detection_classes' in detections)
self.assertTrue('num_detections' in detections)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
detections_out = sess.run(detections)
self.assertAllClose(detections_out['detection_boxes'], expected_boxes)
self.assertAllClose(detections_out['detection_scores'], expected_scores)
self.assertAllClose(detections_out['detection_classes'], expected_classes)
self.assertAllClose(detections_out['num_detections'],
expected_num_detections)
def test_loss_results_are_correct(self):
batch_size = 2
preprocessed_input = tf.random_uniform((batch_size, 2, 2, 3),
dtype=tf.float32)
groundtruth_boxes_list = [tf.constant([[0, 0, .5, .5]], dtype=tf.float32),
tf.constant([[0, 0, .5, .5]], dtype=tf.float32)]
groundtruth_classes_list = [tf.constant([[1]], dtype=tf.float32),
tf.constant([[1]], dtype=tf.float32)]
self._model.provide_groundtruth(groundtruth_boxes_list,
groundtruth_classes_list)
prediction_dict = self._model.predict(preprocessed_input)
loss_dict = self._model.loss(prediction_dict)
self.assertTrue('localization_loss' in loss_dict)
self.assertTrue('classification_loss' in loss_dict)
expected_localization_loss = 0.0
expected_classification_loss = (batch_size * self._num_anchors
* (self._num_classes+1) * np.log(2.0))
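# With the mock predictor emitting all-zero logits, each sigmoid
# cross-entropy term equals -log(0.5) = log(2), summed over every anchor
# and all num_classes + 1 class columns (including background).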
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
losses_out = sess.run(loss_dict)
self.assertAllClose(losses_out['localization_loss'],
expected_localization_loss)
self.assertAllClose(losses_out['classification_loss'],
expected_classification_loss)
def test_restore_fn_detection(self):
init_op = tf.global_variables_initializer()
saver = tf_saver.Saver()
save_path = self.get_temp_dir()
with self.test_session() as sess:
sess.run(init_op)
saved_model_path = saver.save(sess, save_path)
restore_fn = self._model.restore_fn(saved_model_path,
from_detection_checkpoint=True)
restore_fn(sess)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn('FeatureExtractor', var.name)
def test_restore_fn_classification(self):
# Define mock tensorflow classification graph and save variables.
test_graph_classification = tf.Graph()
with test_graph_classification.as_default():
image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
with tf.variable_scope('mock_model'):
net = slim.conv2d(image, num_outputs=32, kernel_size=1, scope='layer1')
slim.conv2d(net, num_outputs=3, kernel_size=1, scope='layer2')
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
save_path = self.get_temp_dir()
with self.test_session() as sess:
sess.run(init_op)
saved_model_path = saver.save(sess, save_path)
# Create tensorflow detection graph and load variables from
# classification checkpoint.
test_graph_detection = tf.Graph()
with test_graph_detection.as_default():
inputs_shape = [2, 2, 2, 3]
inputs = tf.to_float(tf.random_uniform(
inputs_shape, minval=0, maxval=255, dtype=tf.int32))
preprocessed_inputs = self._model.preprocess(inputs)
prediction_dict = self._model.predict(preprocessed_inputs)
self._model.postprocess(prediction_dict)
restore_fn = self._model.restore_fn(saved_model_path,
from_detection_checkpoint=False)
with self.test_session() as sess:
restore_fn(sess)
for var in sess.run(tf.report_uninitialized_variables()):
self.assertNotIn('FeatureExtractor', var.name)
if __name__ == '__main__':
tf.test.main()
# Tensorflow Object Detection API: Models.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_library(
name = "feature_map_generators",
srcs = [
"feature_map_generators.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
],
)
py_test(
name = "feature_map_generators_test",
srcs = [
"feature_map_generators_test.py",
],
deps = [
":feature_map_generators",
"//tensorflow",
],
)
py_library(
name = "ssd_feature_extractor_test",
srcs = [
"ssd_feature_extractor_test.py",
],
deps = [
"//tensorflow",
],
)
py_library(
name = "ssd_inception_v2_feature_extractor",
srcs = [
"ssd_inception_v2_feature_extractor.py",
],
deps = [
":feature_map_generators",
"//tensorflow",
"//tensorflow_models/object_detection/meta_architectures:ssd_meta_arch",
"//tensorflow_models/slim:inception_v2",
],
)
py_library(
name = "ssd_mobilenet_v1_feature_extractor",
srcs = ["ssd_mobilenet_v1_feature_extractor.py"],
deps = [
":feature_map_generators",
"//tensorflow",
"//tensorflow_models/object_detection/meta_architectures:ssd_meta_arch",
"//tensorflow_models/slim:mobilenet_v1",
],
)
py_test(
name = "ssd_inception_v2_feature_extractor_test",
srcs = [
"ssd_inception_v2_feature_extractor_test.py",
],
deps = [
":ssd_feature_extractor_test",
":ssd_inception_v2_feature_extractor",
"//tensorflow",
],
)
py_test(
name = "ssd_mobilenet_v1_feature_extractor_test",
srcs = ["ssd_mobilenet_v1_feature_extractor_test.py"],
deps = [
":ssd_feature_extractor_test",
":ssd_mobilenet_v1_feature_extractor",
"//tensorflow",
],
)
py_library(
name = "faster_rcnn_inception_resnet_v2_feature_extractor",
srcs = [
"faster_rcnn_inception_resnet_v2_feature_extractor.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
"//tensorflow_models/object_detection/utils:variables_helper",
"//tensorflow_models/slim:inception_resnet_v2",
],
)
py_test(
name = "faster_rcnn_inception_resnet_v2_feature_extractor_test",
srcs = [
"faster_rcnn_inception_resnet_v2_feature_extractor_test.py",
],
deps = [
":faster_rcnn_inception_resnet_v2_feature_extractor",
"//tensorflow",
],
)
py_library(
name = "faster_rcnn_resnet_v1_feature_extractor",
srcs = [
"faster_rcnn_resnet_v1_feature_extractor.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/meta_architectures:faster_rcnn_meta_arch",
"//tensorflow_models/slim:resnet_utils",
"//tensorflow_models/slim:resnet_v1",
],
)
py_test(
name = "faster_rcnn_resnet_v1_feature_extractor_test",
srcs = [
"faster_rcnn_resnet_v1_feature_extractor_test.py",
],
deps = [
":faster_rcnn_resnet_v1_feature_extractor",
"//tensorflow",
],
)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inception Resnet v2 Faster R-CNN implementation.
See "Inception-v4, Inception-ResNet and the Impact of Residual Connections on
Learning" by Szegedy et al. (https://arxiv.org/abs/1602.07261)
as well as
"Speed/accuracy trade-offs for modern convolutional object detectors" by
Huang et al. (https://arxiv.org/abs/1611.10012)
"""
import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import variables_helper
from nets import inception_resnet_v2
slim = tf.contrib.slim
class FasterRCNNInceptionResnetV2FeatureExtractor(
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
"""Faster R-CNN with Inception Resnet v2 feature extractor implementation."""
def __init__(self,
is_training,
first_stage_features_stride,
reuse_weights=None,
weight_decay=0.0):
"""Constructor.
Args:
is_training: See base class.
first_stage_features_stride: See base class.
reuse_weights: See base class.
weight_decay: See base class.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16.
"""
if first_stage_features_stride != 8 and first_stage_features_stride != 16:
raise ValueError('`first_stage_features_stride` must be 8 or 16.')
super(FasterRCNNInceptionResnetV2FeatureExtractor, self).__init__(
is_training, first_stage_features_stride, reuse_weights, weight_decay)
def preprocess(self, resized_inputs):
"""Faster R-CNN with Inception Resnet v2 preprocessing.
Maps pixel values to the range [-1, 1].
Args:
resized_inputs: A [batch, height_in, width_in, channels] float32 tensor
representing a batch of images with values between 0 and 255.0.
Returns:
preprocessed_inputs: A [batch, height_out, width_out, channels] float32
tensor representing a batch of images.
"""
return (2.0 / 255.0) * resized_inputs - 1.0
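# Worked example of the mapping above: 0 -> -1.0, 127.5 -> 0.0 and
# 255 -> 1.0, i.e. (2.0 / 255.0) * x - 1.0 rescales [0, 255] linearly
# onto [-1, 1].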
def _extract_proposal_features(self, preprocessed_inputs, scope):
"""Extracts first stage RPN features.
Extracts features using the first half of the Inception Resnet v2 network.
We construct the network in `align_feature_maps=True` mode, which means
that all VALID paddings in the network are changed to SAME padding so that
the feature maps are aligned.
Args:
preprocessed_inputs: A [batch, height, width, channels] float32 tensor
representing a batch of images.
scope: A scope name.
Returns:
rpn_feature_map: A tensor with shape [batch, height, width, depth]
Raises:
InvalidArgumentError: If the spatial size of `preprocessed_inputs`
(height or width) is less than 33.
ValueError: If the created network is missing the required activation.
"""
if len(preprocessed_inputs.get_shape().as_list()) != 4:
raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a '
'tensor of shape %s' % preprocessed_inputs.get_shape())
with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope(
weight_decay=self._weight_decay)):
# Forces is_training to False to disable batch norm update.
with slim.arg_scope([slim.batch_norm], is_training=False):
with tf.variable_scope('InceptionResnetV2',
reuse=self._reuse_weights) as scope:
rpn_feature_map, _ = (
inception_resnet_v2.inception_resnet_v2_base(
preprocessed_inputs, final_endpoint='PreAuxLogits',
scope=scope, output_stride=self._first_stage_features_stride,
align_feature_maps=True))
return rpn_feature_map
def _extract_box_classifier_features(self, proposal_feature_maps, scope):
"""Extracts second stage box classifier features.
This function reconstructs the "second half" of the Inception ResNet v2
network after the part defined in `_extract_proposal_features`.
Args:
proposal_feature_maps: A 4-D float tensor with shape
[batch_size * self.max_num_proposals, crop_height, crop_width, depth]
representing the feature map cropped to each proposal.
scope: A scope name.
Returns:
proposal_classifier_features: A 4-D float tensor with shape
[batch_size * self.max_num_proposals, height, width, depth]
representing box classifier features for each proposal.
"""
with tf.variable_scope('InceptionResnetV2', reuse=self._reuse_weights):
with slim.arg_scope(inception_resnet_v2.inception_resnet_v2_arg_scope(
weight_decay=self._weight_decay)):
# Forces is_training to False to disable batch norm update.
with slim.arg_scope([slim.batch_norm], is_training=False):
with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
stride=1, padding='SAME'):
with tf.variable_scope('Mixed_7a'):
with tf.variable_scope('Branch_0'):
tower_conv = slim.conv2d(proposal_feature_maps,
256, 1, scope='Conv2d_0a_1x1')
tower_conv_1 = slim.conv2d(
tower_conv, 384, 3, stride=2,
padding='VALID', scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_1'):
tower_conv1 = slim.conv2d(
proposal_feature_maps, 256, 1, scope='Conv2d_0a_1x1')
tower_conv1_1 = slim.conv2d(
tower_conv1, 288, 3, stride=2,
padding='VALID', scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_2'):
tower_conv2 = slim.conv2d(
proposal_feature_maps, 256, 1, scope='Conv2d_0a_1x1')
tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
scope='Conv2d_0b_3x3')
tower_conv2_2 = slim.conv2d(
tower_conv2_1, 320, 3, stride=2,
padding='VALID', scope='Conv2d_1a_3x3')
with tf.variable_scope('Branch_3'):
tower_pool = slim.max_pool2d(
proposal_feature_maps, 3, stride=2, padding='VALID',
scope='MaxPool_1a_3x3')
net = tf.concat(
[tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool], 3)
net = slim.repeat(net, 9, inception_resnet_v2.block8, scale=0.20)
net = inception_resnet_v2.block8(net, activation_fn=None)
proposal_classifier_features = slim.conv2d(
net, 1536, 1, scope='Conv2d_7b_1x1')
return proposal_classifier_features
def restore_from_classification_checkpoint_fn(
self,
checkpoint_path,
first_stage_feature_extractor_scope,
second_stage_feature_extractor_scope):
"""Returns callable for loading a checkpoint into the tensorflow graph.
Note that this overrides the default implementation in
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor which does not work for
InceptionResnetV2 checkpoints.
TODO: revisit whether it's possible to force the `Repeat` namescope as
created in `_extract_box_classifier_features` to start counting at 2 (e.g.
`Repeat_2`) so that the default restore_fn can be used.
Args:
checkpoint_path: Path to checkpoint to restore.
first_stage_feature_extractor_scope: A scope name for the first stage
feature extractor.
second_stage_feature_extractor_scope: A scope name for the second stage
feature extractor.
Returns:
a callable which takes a tf.Session as input and loads a checkpoint when
run.
"""
variables_to_restore = {}
for variable in tf.global_variables():
if variable.op.name.startswith(
first_stage_feature_extractor_scope):
var_name = variable.op.name.replace(
first_stage_feature_extractor_scope + '/', '')
variables_to_restore[var_name] = variable
if variable.op.name.startswith(
second_stage_feature_extractor_scope):
var_name = variable.op.name.replace(
second_stage_feature_extractor_scope
+ '/InceptionResnetV2/Repeat', 'InceptionResnetV2/Repeat_2')
var_name = var_name.replace(
second_stage_feature_extractor_scope + '/', '')
variables_to_restore[var_name] = variable
variables_to_restore = (
variables_helper.get_variables_available_in_checkpoint(
variables_to_restore, checkpoint_path))
saver = tf.train.Saver(variables_to_restore)
def restore(sess):
saver.restore(sess, checkpoint_path)
return restore
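# A minimal sketch (hypothetical scope and variable names, illustration only)
# of the graph-to-checkpoint name remapping performed above, where the box
# classifier's `Repeat` scope corresponds to `Repeat_2` in the original
# InceptionResnetV2 checkpoint:
def _example_second_stage_name_remap(
    graph_name, second_stage_scope='SecondStageFeatureExtractor'):
  """Maps a second stage graph variable name to its checkpoint name."""
  name = graph_name.replace(
      second_stage_scope + '/InceptionResnetV2/Repeat',
      'InceptionResnetV2/Repeat_2')
  return name.replace(second_stage_scope + '/', '')
# e.g. _example_second_stage_name_remap(
#     'SecondStageFeatureExtractor/InceptionResnetV2/Repeat/block8_1/weights')
# returns 'InceptionResnetV2/Repeat_2/block8_1/weights'.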
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for models.faster_rcnn_inception_resnet_v2_feature_extractor."""
import tensorflow as tf
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
class FasterRcnnInceptionResnetV2FeatureExtractorTest(tf.test.TestCase):
def _build_feature_extractor(self, first_stage_features_stride):
return frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor(
is_training=False,
first_stage_features_stride=first_stage_features_stride,
reuse_weights=None,
weight_decay=0.0)
def test_extract_proposal_features_returns_expected_size(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16)
preprocessed_inputs = tf.random_uniform(
[1, 299, 299, 3], maxval=255, dtype=tf.float32)
rpn_feature_map = feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
features_shape = tf.shape(rpn_feature_map)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [1, 19, 19, 1088])
def test_extract_proposal_features_stride_eight(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=8)
preprocessed_inputs = tf.random_uniform(
[1, 224, 224, 3], maxval=255, dtype=tf.float32)
rpn_feature_map = feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
features_shape = tf.shape(rpn_feature_map)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [1, 28, 28, 1088])
def test_extract_proposal_features_half_size_input(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16)
preprocessed_inputs = tf.random_uniform(
[1, 112, 112, 3], maxval=255, dtype=tf.float32)
rpn_feature_map = feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
features_shape = tf.shape(rpn_feature_map)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [1, 7, 7, 1088])
def test_extract_proposal_features_dies_on_invalid_stride(self):
with self.assertRaises(ValueError):
self._build_feature_extractor(first_stage_features_stride=99)
def test_extract_proposal_features_dies_with_incorrect_rank_inputs(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16)
preprocessed_inputs = tf.random_uniform(
[224, 224, 3], maxval=255, dtype=tf.float32)
with self.assertRaises(ValueError):
feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
def test_extract_box_classifier_features_returns_expected_size(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16)
proposal_feature_maps = tf.random_uniform(
[2, 17, 17, 1088], maxval=255, dtype=tf.float32)
proposal_classifier_features = (
feature_extractor.extract_box_classifier_features(
proposal_feature_maps, scope='TestScope'))
features_shape = tf.shape(proposal_classifier_features)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [2, 8, 8, 1536])
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Resnet V1 Faster R-CNN implementation.
See "Deep Residual Learning for Image Recognition" by He et al., 2015.
https://arxiv.org/abs/1512.03385
Note: this implementation assumes that the classification checkpoint used
to finetune this model is trained using the same configuration as that of
the MSRA provided checkpoints
(see https://github.com/KaimingHe/deep-residual-networks), e.g., with the
same preprocessing, batch norm scaling, etc.
"""
import tensorflow as tf
from object_detection.meta_architectures import faster_rcnn_meta_arch
from nets import resnet_utils
from nets import resnet_v1
slim = tf.contrib.slim
class FasterRCNNResnetV1FeatureExtractor(
faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
"""Faster R-CNN Resnet V1 feature extractor implementation."""
def __init__(self,
architecture,
resnet_model,
is_training,
first_stage_features_stride,
reuse_weights=None,
weight_decay=0.0):
"""Constructor.
Args:
architecture: Architecture name of the Resnet V1 model.
resnet_model: Definition of the Resnet V1 model.
is_training: See base class.
first_stage_features_stride: See base class.
reuse_weights: See base class.
weight_decay: See base class.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16.
"""
if first_stage_features_stride != 8 and first_stage_features_stride != 16:
raise ValueError('`first_stage_features_stride` must be 8 or 16.')
self._architecture = architecture
self._resnet_model = resnet_model
super(FasterRCNNResnetV1FeatureExtractor, self).__init__(
is_training, first_stage_features_stride, reuse_weights, weight_decay)
def preprocess(self, resized_inputs):
"""Faster R-CNN Resnet V1 preprocessing.
VGG style channel mean subtraction as described here:
https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md
Args:
resized_inputs: A [batch, height_in, width_in, channels] float32 tensor
representing a batch of images with values between 0 and 255.0.
Returns:
preprocessed_inputs: A [batch, height_out, width_out, channels] float32
tensor representing a batch of images.
"""
channel_means = [123.68, 116.779, 103.939]
return resized_inputs - [[channel_means]]
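# Illustration: an RGB pixel [124.0, 117.0, 104.0] becomes approximately
# [0.32, 0.221, 0.061] after subtracting the means [123.68, 116.779,
# 103.939]; the extra nesting in [[channel_means]] broadcasts the means
# across the batch, height and width dimensions.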
def _extract_proposal_features(self, preprocessed_inputs, scope):
"""Extracts first stage RPN features.
Args:
preprocessed_inputs: A [batch, height, width, channels] float32 tensor
representing a batch of images.
scope: A scope name.
Returns:
rpn_feature_map: A tensor with shape [batch, height, width, depth]
Raises:
InvalidArgumentError: If the spatial size of `preprocessed_inputs`
(height or width) is less than 33.
ValueError: If the created network is missing the required activation.
"""
if len(preprocessed_inputs.get_shape().as_list()) != 4:
raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a '
'tensor of shape %s' % preprocessed_inputs.get_shape())
shape_assert = tf.Assert(
tf.logical_and(
tf.greater_equal(tf.shape(preprocessed_inputs)[1], 33),
tf.greater_equal(tf.shape(preprocessed_inputs)[2], 33)),
['image size must be at least 33 in both height and width.'])
with tf.control_dependencies([shape_assert]):
# Disables batchnorm for fine-tuning with smaller batch sizes.
# TODO: Figure out if it is needed when image batch size is bigger.
with slim.arg_scope(
resnet_utils.resnet_arg_scope(
batch_norm_epsilon=1e-5,
batch_norm_scale=True,
weight_decay=self._weight_decay)):
with tf.variable_scope(
self._architecture, reuse=self._reuse_weights) as var_scope:
_, activations = self._resnet_model(
preprocessed_inputs,
num_classes=None,
is_training=False,
global_pool=False,
output_stride=self._first_stage_features_stride,
scope=var_scope)
handle = scope + '/%s/block3' % self._architecture
return activations[handle]
def _extract_box_classifier_features(self, proposal_feature_maps, scope):
"""Extracts second stage box classifier features.
Args:
proposal_feature_maps: A 4-D float tensor with shape
[batch_size * self.max_num_proposals, crop_height, crop_width, depth]
representing the feature map cropped to each proposal.
scope: A scope name (unused).
Returns:
proposal_classifier_features: A 4-D float tensor with shape
[batch_size * self.max_num_proposals, height, width, depth]
representing box classifier features for each proposal.
"""
with tf.variable_scope(self._architecture, reuse=self._reuse_weights):
with slim.arg_scope(
resnet_utils.resnet_arg_scope(
batch_norm_epsilon=1e-5,
batch_norm_scale=True,
weight_decay=self._weight_decay)):
with slim.arg_scope([slim.batch_norm], is_training=False):
blocks = [
resnet_utils.Block('block4', resnet_v1.bottleneck, [{
'depth': 2048,
'depth_bottleneck': 512,
'stride': 1
}] * 3)
]
proposal_classifier_features = resnet_utils.stack_blocks_dense(
proposal_feature_maps, blocks)
return proposal_classifier_features
class FasterRCNNResnet50FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
"""Faster R-CNN Resnet 50 feature extractor implementation."""
def __init__(self,
is_training,
first_stage_features_stride,
reuse_weights=None,
weight_decay=0.0):
"""Constructor.
Args:
is_training: See base class.
first_stage_features_stride: See base class.
reuse_weights: See base class.
weight_decay: See base class.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16,
or if `architecture` is not supported.
"""
super(FasterRCNNResnet50FeatureExtractor, self).__init__(
'resnet_v1_50', resnet_v1.resnet_v1_50, is_training,
first_stage_features_stride, reuse_weights, weight_decay)
class FasterRCNNResnet101FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
"""Faster R-CNN Resnet 101 feature extractor implementation."""
def __init__(self,
is_training,
first_stage_features_stride,
reuse_weights=None,
weight_decay=0.0):
"""Constructor.
Args:
is_training: See base class.
first_stage_features_stride: See base class.
reuse_weights: See base class.
weight_decay: See base class.
Raises:
ValueError: If `first_stage_features_stride` is not 8 or 16,
or if `architecture` is not supported.
"""
super(FasterRCNNResnet101FeatureExtractor, self).__init__(
'resnet_v1_101', resnet_v1.resnet_v1_101, is_training,
first_stage_features_stride, reuse_weights, weight_decay)
class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
"""Faster R-CNN Resnet 152 feature extractor implementation."""
def __init__(self,
is_training,
first_stage_features_stride,
reuse_weights=None,
weight_decay=0.0):
"""Constructor.
Args:
is_training: See base class.
first_stage_features_stride: See base class.
reuse_weights: See base class.
weight_decay: See base class.
Raises:
      ValueError: If `first_stage_features_stride` is not 8 or 16.
"""
super(FasterRCNNResnet152FeatureExtractor, self).__init__(
'resnet_v1_152', resnet_v1.resnet_v1_152, is_training,
first_stage_features_stride, reuse_weights, weight_decay)
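# Editor's usage sketch (illustrative, not part of the original file). The
# scope name 'TestScope' is arbitrary; the output shape follows from a
# 224x224 input at first_stage_features_stride=16 with block3 depth 1024,
# matching the unit tests that follow.
#
#   import tensorflow as tf
#   extractor = FasterRCNNResnet101FeatureExtractor(
#       is_training=False, first_stage_features_stride=16)
#   images = tf.random_uniform([1, 224, 224, 3], maxval=255, dtype=tf.float32)
#   rpn_features = extractor.extract_proposal_features(
#       images, scope='TestScope')
#   # -> shape [1, 14, 14, 1024] once variables are initialized and run.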
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.models.faster_rcnn_resnet_v1_feature_extractor."""
import numpy as np
import tensorflow as tf
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as faster_rcnn_resnet_v1
class FasterRcnnResnetV1FeatureExtractorTest(tf.test.TestCase):
def _build_feature_extractor(self,
first_stage_features_stride,
architecture='resnet_v1_101'):
feature_extractor_map = {
'resnet_v1_50':
faster_rcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
'resnet_v1_101':
faster_rcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
'resnet_v1_152':
faster_rcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor
}
return feature_extractor_map[architecture](
is_training=False,
first_stage_features_stride=first_stage_features_stride,
reuse_weights=None,
weight_decay=0.0)
def test_extract_proposal_features_returns_expected_size(self):
for architecture in ['resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152']:
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16, architecture=architecture)
preprocessed_inputs = tf.random_uniform(
[4, 224, 224, 3], maxval=255, dtype=tf.float32)
rpn_feature_map = feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
features_shape = tf.shape(rpn_feature_map)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [4, 14, 14, 1024])
def test_extract_proposal_features_stride_eight(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=8)
preprocessed_inputs = tf.random_uniform(
[4, 224, 224, 3], maxval=255, dtype=tf.float32)
rpn_feature_map = feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
features_shape = tf.shape(rpn_feature_map)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [4, 28, 28, 1024])
def test_extract_proposal_features_half_size_input(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16)
preprocessed_inputs = tf.random_uniform(
[1, 112, 112, 3], maxval=255, dtype=tf.float32)
rpn_feature_map = feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
features_shape = tf.shape(rpn_feature_map)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [1, 7, 7, 1024])
def test_extract_proposal_features_dies_on_invalid_stride(self):
with self.assertRaises(ValueError):
self._build_feature_extractor(first_stage_features_stride=99)
def test_extract_proposal_features_dies_on_very_small_images(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16)
preprocessed_inputs = tf.placeholder(tf.float32, (4, None, None, 3))
rpn_feature_map = feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
features_shape = tf.shape(rpn_feature_map)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(
features_shape,
feed_dict={preprocessed_inputs: np.random.rand(4, 32, 32, 3)})
def test_extract_proposal_features_dies_with_incorrect_rank_inputs(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16)
preprocessed_inputs = tf.random_uniform(
[224, 224, 3], maxval=255, dtype=tf.float32)
with self.assertRaises(ValueError):
feature_extractor.extract_proposal_features(
preprocessed_inputs, scope='TestScope')
def test_extract_box_classifier_features_returns_expected_size(self):
feature_extractor = self._build_feature_extractor(
first_stage_features_stride=16)
proposal_feature_maps = tf.random_uniform(
[3, 7, 7, 1024], maxval=255, dtype=tf.float32)
proposal_classifier_features = (
feature_extractor.extract_box_classifier_features(
proposal_feature_maps, scope='TestScope'))
features_shape = tf.shape(proposal_classifier_features)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
features_shape_out = sess.run(features_shape)
self.assertAllEqual(features_shape_out, [3, 7, 7, 2048])
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to generate a list of feature maps based on image features.
Provides several feature map generators that can be used to build object
detection feature extractors.
Object detection feature extractors are usually built by stacking two
components: a base feature extractor, such as Inception V3, and a feature map
generator. Feature map generators build on the base feature extractor and
produce a list of final feature maps.
"""
import collections
import tensorflow as tf
from object_detection.utils import ops
slim = tf.contrib.slim
def get_depth_fn(depth_multiplier, min_depth):
"""Builds a callable to compute depth (output channels) of conv filters.
Args:
depth_multiplier: a multiplier for the nominal depth.
min_depth: a lower bound on the depth of filters.
Returns:
A callable that takes in a nominal depth and returns the depth to use.
"""
def multiply_depth(depth):
new_depth = int(depth * depth_multiplier)
return max(new_depth, min_depth)
return multiply_depth
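# For example (mirroring the unit tests in the feature map generator tests):
#
#   depth_fn = get_depth_fn(depth_multiplier=0.5, min_depth=16)
#   depth_fn(64)  # -> 32
#   depth_fn(16)  # -> 16 (0.5 * 16 = 8 is clamped up to min_depth)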
def multi_resolution_feature_maps(feature_map_layout, depth_multiplier,
min_depth, insert_1x1_conv, image_features):
"""Generates multi resolution feature maps from input image features.
Generates multi-scale feature maps for detection as in the SSD papers by
Liu et al: https://arxiv.org/pdf/1512.02325v2.pdf, See Sec 2.1.
More specifically, it performs the following two tasks:
1) If a layer name is provided in the configuration, returns that layer as a
feature map.
2) If a layer name is left as an empty string, constructs a new feature map
based on the spatial shape and depth configuration. Note that the current
     implementation only supports generating new layers using convolutions of
     stride 2, resulting in a spatial resolution reduction by a factor of 2.
An example of the configuration for Inception V3:
{
'from_layer': ['Mixed_5d', 'Mixed_6e', 'Mixed_7c', '', '', ''],
'layer_depth': [-1, -1, -1, 512, 256, 128],
'anchor_strides': [16, 32, 64, -1, -1, -1]
}
Args:
feature_map_layout: Dictionary of specifications for the feature map
layouts in the following format (Inception V2/V3 respectively):
{
'from_layer': ['Mixed_3c', 'Mixed_4c', 'Mixed_5c', '', '', ''],
'layer_depth': [-1, -1, -1, 512, 256, 128],
'anchor_strides': [16, 32, 64, -1, -1, -1]
}
or
{
        'from_layer': ['Mixed_5d', 'Mixed_6e', 'Mixed_7c', '', '', ''],
'layer_depth': [-1, -1, -1, 512, 256, 128],
'anchor_strides': [16, 32, 64, -1, -1, -1]
}
      If 'from_layer' is specified, the specified feature map is directly used
      as a box predictor layer, and the layer_depth is directly inferred from
      the feature map (instead of using the provided 'layer_depth' parameter);
      in this case, our convention is to set 'layer_depth' to -1 for clarity.
      Otherwise, if 'from_layer' is an empty string, the box predictor layer
      is built from the previous layer using convolution operations. Note that
      the current implementation only supports generating new layers using
      convolutions of stride 2 (resulting in a spatial resolution reduction by
      a factor of 2), and will be extended to a more flexible design. Finally,
      the optional 'anchor_strides' can be used to specify the anchor stride
      at each layer where 'from_layer' is specified. Our convention is to set
      'anchor_strides' to -1 at the positions where 'from_layer' is an empty
      string; anchor strides at these layers are then inferred from the
      previous layer's anchor stride and the current layer's stride length.
      If 'anchor_strides' is not specified, the anchor strides default to the
      image width and height divided by the number of anchors.
depth_multiplier: Depth multiplier for convolutional layers.
min_depth: Minimum depth for convolutional layers.
insert_1x1_conv: A boolean indicating whether an additional 1x1 convolution
should be inserted before shrinking the feature map.
image_features: A dictionary of handles to activation tensors from the
base feature extractor.
Returns:
feature_maps: an OrderedDict mapping keys (feature map names) to
tensors where each tensor has shape [batch, height_i, width_i, depth_i].
Raises:
    ValueError: if the number of entries in 'from_layer' does not match the
      number of entries in 'layer_depth'.
ValueError: if the generated layer does not have the same resolution
as specified.
"""
depth_fn = get_depth_fn(depth_multiplier, min_depth)
feature_map_keys = []
feature_maps = []
base_from_layer = ''
feature_map_strides = None
use_depthwise = False
if 'anchor_strides' in feature_map_layout:
    feature_map_strides = feature_map_layout['anchor_strides']
if 'use_depthwise' in feature_map_layout:
use_depthwise = feature_map_layout['use_depthwise']
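  # Walk the layout: endpoints named in 'from_layer' are reused directly as
  # feature maps; empty entries synthesize a new stride-2 layer (optionally
  # preceded by a 1x1 bottleneck conv) from the previous feature map.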
for index, (from_layer, layer_depth) in enumerate(
zip(feature_map_layout['from_layer'], feature_map_layout['layer_depth'])):
if from_layer:
feature_map = image_features[from_layer]
base_from_layer = from_layer
feature_map_keys.append(from_layer)
else:
pre_layer = feature_maps[-1]
intermediate_layer = pre_layer
if insert_1x1_conv:
layer_name = '{}_1_Conv2d_{}_1x1_{}'.format(
base_from_layer, index, depth_fn(layer_depth / 2))
intermediate_layer = slim.conv2d(
pre_layer,
depth_fn(layer_depth / 2), [1, 1],
padding='SAME',
stride=1,
scope=layer_name)
stride = 2
layer_name = '{}_2_Conv2d_{}_3x3_s2_{}'.format(
base_from_layer, index, depth_fn(layer_depth))
if use_depthwise:
feature_map = slim.separable_conv2d(
ops.pad_to_multiple(intermediate_layer, stride),
None, [3, 3],
depth_multiplier=1,
padding='SAME',
stride=stride,
scope=layer_name + '_depthwise')
feature_map = slim.conv2d(
feature_map,
depth_fn(layer_depth), [1, 1],
padding='SAME',
stride=1,
scope=layer_name)
else:
feature_map = slim.conv2d(
ops.pad_to_multiple(intermediate_layer, stride),
depth_fn(layer_depth), [3, 3],
padding='SAME',
stride=stride,
scope=layer_name)
if (index > 0 and feature_map_strides and
feature_map_strides[index - 1] > 0):
feature_map_strides[index] = (
stride * feature_map_strides[index - 1])
feature_map_keys.append(layer_name)
feature_maps.append(feature_map)
return collections.OrderedDict(
[(x, y) for (x, y) in zip(feature_map_keys, feature_maps)])
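# Editor's usage sketch (illustrative, not part of the original file): a
# shortened Inception V2 style layout; the endpoint names and input shapes are
# assumptions mirroring the unit tests that follow.
#
#   image_features = {
#       'Mixed_4c': tf.random_uniform([4, 14, 14, 576], dtype=tf.float32),
#       'Mixed_5c': tf.random_uniform([4, 7, 7, 1024], dtype=tf.float32),
#   }
#   feature_maps = multi_resolution_feature_maps(
#       feature_map_layout={
#           'from_layer': ['Mixed_4c', 'Mixed_5c', '', ''],
#           'layer_depth': [-1, -1, 512, 256],
#       },
#       depth_multiplier=1.0,
#       min_depth=16,
#       insert_1x1_conv=True,
#       image_features=image_features)
#   # OrderedDict keys: 'Mixed_4c', 'Mixed_5c',
#   # 'Mixed_5c_2_Conv2d_2_3x3_s2_512' (4x4) and
#   # 'Mixed_5c_2_Conv2d_3_3x3_s2_256' (2x2); each synthesized layer halves
#   # the spatial resolution of the previous one.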
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for feature map generators."""
import tensorflow as tf
from object_detection.models import feature_map_generators
INCEPTION_V2_LAYOUT = {
'from_layer': ['Mixed_3c', 'Mixed_4c', 'Mixed_5c', '', '', ''],
'layer_depth': [-1, -1, -1, 512, 256, 256],
'anchor_strides': [16, 32, 64, -1, -1, -1],
'layer_target_norm': [20.0, -1, -1, -1, -1, -1],
}
INCEPTION_V3_LAYOUT = {
'from_layer': ['Mixed_5d', 'Mixed_6e', 'Mixed_7c', '', '', ''],
'layer_depth': [-1, -1, -1, 512, 256, 128],
'anchor_strides': [16, 32, 64, -1, -1, -1],
'aspect_ratios': [1.0, 2.0, 1.0/2, 3.0, 1.0/3]
}
# TODO: add tests with different anchor strides.
class MultiResolutionFeatureMapGeneratorTest(tf.test.TestCase):
def test_get_expected_feature_map_shapes_with_inception_v2(self):
image_features = {
'Mixed_3c': tf.random_uniform([4, 28, 28, 256], dtype=tf.float32),
'Mixed_4c': tf.random_uniform([4, 14, 14, 576], dtype=tf.float32),
'Mixed_5c': tf.random_uniform([4, 7, 7, 1024], dtype=tf.float32)
}
feature_maps = feature_map_generators.multi_resolution_feature_maps(
feature_map_layout=INCEPTION_V2_LAYOUT,
depth_multiplier=1,
min_depth=32,
insert_1x1_conv=True,
image_features=image_features)
expected_feature_map_shapes = {
'Mixed_3c': (4, 28, 28, 256),
'Mixed_4c': (4, 14, 14, 576),
'Mixed_5c': (4, 7, 7, 1024),
'Mixed_5c_2_Conv2d_3_3x3_s2_512': (4, 4, 4, 512),
'Mixed_5c_2_Conv2d_4_3x3_s2_256': (4, 2, 2, 256),
'Mixed_5c_2_Conv2d_5_3x3_s2_256': (4, 1, 1, 256)}
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
out_feature_maps = sess.run(feature_maps)
out_feature_map_shapes = dict(
          (key, value.shape) for key, value in out_feature_maps.items())
self.assertDictEqual(out_feature_map_shapes, expected_feature_map_shapes)
def test_get_expected_feature_map_shapes_with_inception_v3(self):
image_features = {
'Mixed_5d': tf.random_uniform([4, 35, 35, 256], dtype=tf.float32),
'Mixed_6e': tf.random_uniform([4, 17, 17, 576], dtype=tf.float32),
'Mixed_7c': tf.random_uniform([4, 8, 8, 1024], dtype=tf.float32)
}
feature_maps = feature_map_generators.multi_resolution_feature_maps(
feature_map_layout=INCEPTION_V3_LAYOUT,
depth_multiplier=1,
min_depth=32,
insert_1x1_conv=True,
image_features=image_features)
expected_feature_map_shapes = {
'Mixed_5d': (4, 35, 35, 256),
'Mixed_6e': (4, 17, 17, 576),
'Mixed_7c': (4, 8, 8, 1024),
'Mixed_7c_2_Conv2d_3_3x3_s2_512': (4, 4, 4, 512),
'Mixed_7c_2_Conv2d_4_3x3_s2_256': (4, 2, 2, 256),
'Mixed_7c_2_Conv2d_5_3x3_s2_128': (4, 1, 1, 128)}
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
out_feature_maps = sess.run(feature_maps)
out_feature_map_shapes = dict(
          (key, value.shape) for key, value in out_feature_maps.items())
self.assertDictEqual(out_feature_map_shapes, expected_feature_map_shapes)
class GetDepthFunctionTest(tf.test.TestCase):
def test_return_min_depth_when_multiplier_is_small(self):
depth_fn = feature_map_generators.get_depth_fn(depth_multiplier=0.5,
min_depth=16)
self.assertEqual(depth_fn(16), 16)
def test_return_correct_depth_with_multiplier(self):
depth_fn = feature_map_generators.get_depth_fn(depth_multiplier=0.5,
min_depth=16)
self.assertEqual(depth_fn(64), 32)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base test class SSDFeatureExtractors."""
from abc import abstractmethod
import numpy as np
import tensorflow as tf
class SsdFeatureExtractorTestBase(object):
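  # Note: this class is a mixin. Concrete tests are expected to inherit from
  # both this class and tf.test.TestCase, which supplies test_session() and
  # the assert* methods used below.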
def _validate_features_shape(self,
feature_extractor,
preprocessed_inputs,
expected_feature_map_shapes):
"""Checks the extracted features are of correct shape.
Args:
feature_extractor: The feature extractor to test.
preprocessed_inputs: A [batch, height, width, 3] tensor to extract
features with.
      expected_feature_map_shapes: The expected shapes of the extracted
        feature maps.
"""
feature_maps = feature_extractor.extract_features(preprocessed_inputs)
feature_map_shapes = [tf.shape(feature_map) for feature_map in feature_maps]
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init_op)
feature_map_shapes_out = sess.run(feature_map_shapes)
for shape_out, exp_shape_out in zip(
feature_map_shapes_out, expected_feature_map_shapes):
self.assertAllEqual(shape_out, exp_shape_out)
@abstractmethod
def _create_feature_extractor(self, depth_multiplier):
"""Constructs a new feature extractor.
Args:
depth_multiplier: float depth multiplier for feature extractor
Returns:
an ssd_meta_arch.SSDFeatureExtractor object.
"""
pass
def check_extract_features_returns_correct_shape(
self,
image_height,
image_width,
depth_multiplier,
expected_feature_map_shapes_out):
feature_extractor = self._create_feature_extractor(depth_multiplier)
preprocessed_inputs = tf.random_uniform(
[4, image_height, image_width, 3], dtype=tf.float32)
self._validate_features_shape(
feature_extractor, preprocessed_inputs, expected_feature_map_shapes_out)
def check_extract_features_raises_error_with_invalid_image_size(
self,
image_height,
image_width,
depth_multiplier):
feature_extractor = self._create_feature_extractor(depth_multiplier)
preprocessed_inputs = tf.placeholder(tf.float32, (4, None, None, 3))
feature_maps = feature_extractor.extract_features(preprocessed_inputs)
test_preprocessed_image = np.random.rand(4, image_height, image_width, 3)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(feature_maps,
feed_dict={preprocessed_inputs: test_preprocessed_image})
def check_feature_extractor_variables_under_scope(self,
depth_multiplier,
scope_name):
g = tf.Graph()
with g.as_default():
feature_extractor = self._create_feature_extractor(depth_multiplier)
preprocessed_inputs = tf.placeholder(tf.float32, (4, None, None, 3))
feature_extractor.extract_features(preprocessed_inputs)
variables = g.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
for variable in variables:
self.assertTrue(variable.name.startswith(scope_name))
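# Editor's sketch of a concrete subclass (hypothetical; MyFeatureExtractor and
# the expected shape are placeholders, not real library names). It shows the
# intended mixin pattern: inherit from both the base above and
# tf.test.TestCase.
#
#   class MyFeatureExtractorTest(SsdFeatureExtractorTestBase,
#                                tf.test.TestCase):
#
#     def _create_feature_extractor(self, depth_multiplier):
#       return MyFeatureExtractor(depth_multiplier, min_depth=16,
#                                 conv_hyperparams=None)
#
#     def test_extract_features_returns_correct_shapes(self):
#       self.check_extract_features_returns_correct_shape(
#           image_height=256, image_width=256, depth_multiplier=1.0,
#           expected_feature_map_shapes_out=[(4, 16, 16, 576)])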
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SSDFeatureExtractor for InceptionV2 features."""
import tensorflow as tf
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import feature_map_generators
from nets import inception_v2
slim = tf.contrib.slim
class SSDInceptionV2FeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
"""SSD Feature Extractor using InceptionV2 features."""
def __init__(self,
depth_multiplier,
min_depth,
conv_hyperparams,
reuse_weights=None):
"""InceptionV2 Feature Extractor for SSD Models.
Args:
depth_multiplier: float depth multiplier for feature extractor.
min_depth: minimum feature extractor depth.
conv_hyperparams: tf slim arg_scope for conv2d and separable_conv2d ops.
reuse_weights: Whether to reuse variables. Default is None.
"""
super(SSDInceptionV2FeatureExtractor, self).__init__(
depth_multiplier, min_depth, conv_hyperparams, reuse_weights)
def preprocess(self, resized_inputs):
"""SSD preprocessing.
Maps pixel values to the range [-1, 1].
Args:
resized_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
Returns:
preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
"""
return (2.0 / 255.0) * resized_inputs - 1.0
def extract_features(self, preprocessed_inputs):
"""Extract features from preprocessed inputs.
Args:
preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
Returns:
feature_maps: a list of tensors where the ith tensor has shape
[batch, height_i, width_i, depth_i]
"""
preprocessed_inputs.get_shape().assert_has_rank(4)
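    # Height and width may be dynamic, so the minimum 33x33 input size is
    # checked at run time; the assert fires via the control dependency below.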
shape_assert = tf.Assert(
tf.logical_and(tf.greater_equal(tf.shape(preprocessed_inputs)[1], 33),
tf.greater_equal(tf.shape(preprocessed_inputs)[2], 33)),
        ['image size must be at least 33 in both height and width.'])
feature_map_layout = {
'from_layer': ['Mixed_4c', 'Mixed_5c', '', '', '', ''],
'layer_depth': [-1, -1, 512, 256, 256, 128],
}
with tf.control_dependencies([shape_assert]):
with slim.arg_scope(self._conv_hyperparams):
with tf.variable_scope('InceptionV2',
reuse=self._reuse_weights) as scope:
_, image_features = inception_v2.inception_v2_base(
preprocessed_inputs,
final_endpoint='Mixed_5c',
min_depth=self._min_depth,
depth_multiplier=self._depth_multiplier,
scope=scope)
feature_maps = feature_map_generators.multi_resolution_feature_maps(
feature_map_layout=feature_map_layout,
depth_multiplier=self._depth_multiplier,
min_depth=self._min_depth,
insert_1x1_conv=True,
image_features=image_features)
return feature_maps.values()
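# Editor's usage sketch (illustrative, not part of the original file).
# `conv_hyperparams` would normally come from a hyperparams builder; here a
# bare slim arg_scope is captured as a stand-in.
#
#   import tensorflow as tf
#   slim = tf.contrib.slim
#   with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu6) as sc:
#     conv_hyperparams = sc
#   extractor = SSDInceptionV2FeatureExtractor(
#       depth_multiplier=1.0, min_depth=16, conv_hyperparams=conv_hyperparams)
#   images = tf.random_uniform([4, 300, 300, 3], maxval=255, dtype=tf.float32)
#   feature_maps = extractor.extract_features(extractor.preprocess(images))
#   # Six feature maps for a 300x300 input, from 19x19 (Mixed_4c) down to 1x1.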