Unverified Commit 9bbf8015 authored by pkulzc, committed by GitHub

Merged commit includes the following changes: (#6932)

250447559  by Zhichao Lu:

    Update the expected file format for the Instance Segmentation challenge:
    - add ImageWidth and ImageHeight fields and store the values per prediction
    - for masks, store only the encoded image and assume its size is ImageWidth x ImageHeight
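    A hedged illustration of the per-prediction record implied by the points above; every field name other than ImageWidth and ImageHeight is a hypothetical placeholder, and the encoded-mask bytes are assumed to decode to an ImageHeight x ImageWidth binary mask:

      encoded_mask_bytes = b'<image-encoded binary mask>'  # placeholder bytes
      prediction = {
          'ImageWidth': 640,            # stored per prediction
          'ImageHeight': 480,           # stored per prediction
          'Mask': encoded_mask_bytes,   # only the encoded image is stored; its
                                        # decoded size is assumed to be
                                        # ImageHeight x ImageWidth
      }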

--
250402780  by rathodv:

    Fix failing Mask R-CNN TPU convergence test.

    Cast second stage prediction tensors from bfloat16 to float32 to prevent errors in the third target assignment (mask prediction): concatenating tensors of different types (bfloat16 and float32) isn't allowed.
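    A minimal sketch of the cast-before-concat pattern described above (TF 1.x API; the tensor names are hypothetical, not the actual model code):

      import tensorflow as tf

      # Second-stage predictions computed in bfloat16 (e.g. when training on TPU).
      box_predictions = tf.zeros([8, 4], dtype=tf.bfloat16)
      class_scores = tf.zeros([8, 3], dtype=tf.float32)

      # tf.concat requires matching dtypes, so cast bfloat16 -> float32 first.
      box_predictions = tf.cast(box_predictions, tf.float32)
      combined = tf.concat([box_predictions, class_scores], axis=1)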

--
250300240  by Zhichao Lu:

    Add Open Images Challenge 2019 object detection and instance segmentation
    support to the Estimator framework.

--
249944839  by rathodv:

    Modify exporter.py to add multiclass score nodes in exported inference graphs.

--
249935201  by rathodv:

    Modify postprocess methods to preserve multiclass scores after non max suppression.
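    A hedged sketch of the idea (hypothetical tensor names): gather the full per-class score rows for the boxes that survive NMS so they can be carried through postprocessing:

      import tensorflow as tf

      multiclass_scores = tf.constant([[0.1, 0.9], [0.8, 0.2], [0.4, 0.6]])
      kept_indices = tf.constant([1, 2])  # indices returned by NMS (hypothetical)
      kept_multiclass_scores = tf.gather(multiclass_scores, kept_indices)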

--
249878079  by Zhichao Lu:

    This CL slightly refactors some Object Detection helper functions for data creation, evaluation, and providing groundtruth.

    This will allow the eager+function custom loops to share code with the existing estimator training loops.

    Concretely we make the following changes:
    1. In input creation we separate dataset-creation into top-level helpers, and allow it to optionally accept a pre-constructed model directly instead of always creating a model from the config just for feature preprocessing.

    2. In coco evaluation we split the update_op creation into its own function, which the custom loops will call directly.

    3. In model_lib we move groundtruth providing / data-structure munging into a helper function.

    4. For now we put an escape hatch in `_summarize_target_assignment` when executing with TF 2.0 behavior, because the summary APIs used only work with TF 1.x.

--
249673507  by rathodv:

    Use explicit casts instead of tf.to_float and tf.to_int32 to avoid warnings.
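    The replacement pattern, shown on dummy tensors (the diffs below apply the same change throughout the codebase):

      import tensorflow as tf

      x = tf.constant([1, 2, 3])
      y = tf.cast(x, dtype=tf.float32)  # instead of the deprecated tf.to_float(x)
      z = tf.cast(y, dtype=tf.int32)    # instead of the deprecated tf.to_int32(y)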

--
249656006  by Zhichao Lu:

    Add a named "raw_keypoint_locations" node that corresponds to the "raw_box_locations" node.

--
249651674  by rathodv:

    Keep proposal boxes in float format. MatMulCropAndResize can handle the type even when the features themselves are bfloat16.

--
249568633  by rathodv:

    Support q > 1 in class agnostic NMS.
    Break post_processing_test.py into 3 separate files to avoid linter errors.

--
249535530  by rathodv:

    Update some deprecated arguments to tf ops.

--
249368223  by rathodv:

    Modify MatMulCropAndResize to use MultiLevelRoIAlign method and move the tests to spatial_transform_ops.py module.

    This CL establishes that CropAndResize and RoIAlign are equivalent and differ only in the sampling point grid within the boxes. CropAndResize uses a uniform size x size point grid such that the corner points exactly overlap the box corners, while RoIAlign divides boxes into size x size cells and uses their centers as sampling points. In this CL, we switch MatMulCropAndResize to use the MultiLevelRoIAlign implementation with the `align_corner` option, since the MultiLevelRoIAlign implementation is more memory efficient on TPU than the original MatMulCropAndResize.
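    An illustrative NumPy sketch (not the library implementation) of the two sampling-point grids along one axis of a box, for a box spanning [0, 4] with 4 output points:

      import numpy as np

      def crop_and_resize_points(start, end, size):
        # Corner-aligned grid: the first/last points coincide with the box corners.
        return np.linspace(start, end, size)

      def roi_align_points(start, end, size):
        # Divide the box into `size` equal cells and sample each cell center.
        cell = (end - start) / size
        return start + cell * (np.arange(size) + 0.5)

      print(crop_and_resize_points(0.0, 4.0, 4))  # [0.    1.333 2.667 4.   ]
      print(roi_align_points(0.0, 4.0, 4))        # [0.5 1.5 2.5 3.5]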

--
249337338  by chowdhery:

    Add class-agnostic non-max-suppression in post_processing
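    A hedged NumPy sketch of the class-agnostic idea (a plain greedy NMS, not the library's static-shape implementation): each box is ranked by its single best class score, and the top max_classes_per_detection class scores are kept for the surviving boxes:

      import numpy as np

      def box_iou(a, b):
        top_left = np.maximum(a[:2], b[:2])
        bottom_right = np.minimum(a[2:], b[2:])
        inter = np.prod(np.clip(bottom_right - top_left, 0.0, None))
        area = lambda box: np.prod(box[2:] - box[:2])
        return inter / (area(a) + area(b) - inter)

      def class_agnostic_nms(boxes, scores, iou_thresh=0.5, max_classes_per_detection=1):
        order = np.argsort(-scores.max(axis=1))  # rank each box by its best class score
        keep = []
        for i in order:
          if all(box_iou(boxes[i], boxes[j]) < iou_thresh for j in keep):
            keep.append(i)
        top_classes = np.argsort(-scores[keep], axis=1)[:, :max_classes_per_detection]
        return np.array(keep), top_classes

      boxes = np.array([[0, 0, 1, 1], [0, 0, 1.1, 1.1], [2, 2, 3, 3]], dtype=float)
      scores = np.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.6]])
      print(class_agnostic_nms(boxes, scores))  # keeps boxes 0 and 2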

--
249139196  by Zhichao Lu:

    Fix positional argument bug in export_tflite_ssd_graph

--
249120219  by Zhichao Lu:

    Add evaluator for computing precision limited to a given recall range.
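    A hedged NumPy sketch of restricting a precision/recall curve to a recall operating range before computing AP (a plain mean stands in for the library's interpolated average-precision computation):

      import numpy as np

      def average_precision_in_recall_range(precision, recall, lower=0.0, upper=1.0):
        keep = (recall >= lower) & (recall <= upper)
        if not np.any(keep):
          return float('nan')
        return float(np.mean(precision[keep]))

      precision = np.array([1.0, 0.5, 0.67, 0.5])
      recall = np.array([0.25, 0.25, 0.5, 0.5])
      print(average_precision_in_recall_range(precision, recall, 0.0, 0.3))  # 0.75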

--
249030593  by Zhichao Lu:

    Evaluation util to run segmentation and detection challenge evaluation.

--
248554358  by Zhichao Lu:

    This change contains the auxiliary changes required for TF 2.0 style training with eager+functions+dist strat loops, but not the loops themselves.

    It includes:
    - Updates to shape usage to support both TensorShape v1 and TensorShape v2
    - A fix to FreezableBatchNorm to not override the `training` arg in call when `None` was passed to the constructor (not an issue in the estimator loops, but it was in the custom loops)
    - Puts some constants in init_scope so they work in eager + functions
    - Makes learning rate schedules return a callable in eager mode (required so they update when the global_step changes)
    - Makes DetectionModel a tf.Module so it tracks variables (e.g. ones nested in layers)
    - Removes some references to `op.name` for some losses and replaces them with explicit names
    - A small part of the change to allow the coco evaluation metrics to work in eager mode

--
248271226  by rathodv:

    Add MultiLevel RoIAlign op.

--
248229103  by rathodv:

    Add functions to 1. pad feature maps 2. ravel 5-D indices.
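    For the index-raveling part, the underlying arithmetic is the standard row-major flattening that np.ravel_multi_index performs; a small sketch with hypothetical dimension sizes:

      import numpy as np

      dims = (2, 3, 4, 5, 6)     # hypothetical 5-D shape
      indices = (1, 2, 3, 4, 5)  # one index per dimension
      flat = np.ravel_multi_index(indices, dims)
      # Equivalent to (((1*3 + 2)*4 + 3)*5 + 4)*6 + 5
      print(flat)  # 719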

--
248206769  by rathodv:

    Add utilities needed to introduce RoI Align op.

--
248177733  by pengchong:

    Internal changes

--
247742582  by Zhichao Lu:

    Open Images Challenge 2019 instance segmentation metric: part 2

--
247525401  by Zhichao Lu:

    Update comments on max_classes_per_detection.

--
247520753  by rathodv:

    Add multilevel crop and resize operation that builds on top of matmul_crop_and_resize.

--
247391600  by Zhichao Lu:

    Open Images Challenge 2019 instance segmentation metric

--
247325813  by chowdhery:

    Quantized MobileNet v2 SSD FPNLite config with depth multiplier 0.75

--

PiperOrigin-RevId: 250447559
parent f42fddee
......@@ -22,6 +22,20 @@ message BatchNonMaxSuppression {
// Whether to use the implementation of NMS that guarantees static shapes.
optional bool use_static_shapes = 6 [default = false];
// Whether to use class agnostic NMS.
// Class-agnostic NMS function implements a class-agnostic version
// of Non Maximal Suppression where if max_classes_per_detection=k,
// 1) we keep the top-k scores for each detection and
// 2) during NMS, each detection only uses the highest class score for sorting.
// 3) Compared to regular NMS, the worst runtime of this version is O(N^2)
// instead of O(KN^2) where N is the number of detections and K the number of
// classes.
optional bool use_class_agnostic_nms = 7 [default = false];
// Number of classes retained per detection in class agnostic NMS.
optional int32 max_classes_per_detection = 8 [default = 1];
}
// Configuration proto for post-processing predicted boxes and
......
......@@ -87,7 +87,7 @@ def get_prediction_tensor_shapes(pipeline_config):
_, input_tensors = exporter.input_placeholder_fn_map['image_tensor']()
inputs = tf.to_float(input_tensors)
inputs = tf.cast(input_tensors, dtype=tf.float32)
preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)
prediction_dict = detection_model.predict(preprocessed_inputs,
......@@ -125,7 +125,7 @@ def build_graph(pipeline_config,
exporter.input_placeholder_fn_map[input_type]()
# CPU pre-processing
inputs = tf.to_float(input_tensors)
inputs = tf.cast(input_tensors, dtype=tf.float32)
preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)
# Dimshuffle: [b, h, w, c] -> [b, c, h, w]
......
......@@ -57,7 +57,7 @@ def get_prediction_tensor_shapes(pipeline_config):
detection_model = model_builder.build(
pipeline_config.model, is_training=False)
_, input_tensors = exporter.input_placeholder_fn_map['image_tensor']()
inputs = tf.to_float(input_tensors)
inputs = tf.cast(input_tensors, dtype=tf.float32)
preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)
prediction_dict = detection_model.predict(preprocessed_inputs,
true_image_shapes)
......@@ -138,7 +138,7 @@ def build_graph(pipeline_config,
placeholder_tensor, input_tensors = \
exporter.input_placeholder_fn_map[input_type]()
inputs = tf.to_float(input_tensors)
inputs = tf.cast(input_tensors, dtype=tf.float32)
preprocessed_inputs, true_image_shapes = detection_model.preprocess(inputs)
# Dimshuffle: (b, h, w, c) -> (b, c, h, w)
......
......@@ -47,20 +47,34 @@ def exponential_decay_with_burnin(global_step,
staircase: whether to use staircase decay.
Returns:
a (scalar) float tensor representing learning rate
If executing eagerly:
returns a no-arg callable that outputs the (scalar)
float tensor learning rate given the current value of global_step.
If in a graph:
immediately returns a (scalar) float tensor representing learning rate.
"""
if burnin_learning_rate == 0:
burnin_learning_rate = learning_rate_base
post_burnin_learning_rate = tf.train.exponential_decay(
learning_rate_base,
global_step - burnin_steps,
learning_rate_decay_steps,
learning_rate_decay_factor,
staircase=staircase)
return tf.maximum(tf.where(
tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
tf.constant(burnin_learning_rate),
post_burnin_learning_rate), min_learning_rate, name='learning_rate')
def eager_decay_rate():
"""Callable to compute the learning rate."""
post_burnin_learning_rate = tf.train.exponential_decay(
learning_rate_base,
global_step - burnin_steps,
learning_rate_decay_steps,
learning_rate_decay_factor,
staircase=staircase)
if callable(post_burnin_learning_rate):
post_burnin_learning_rate = post_burnin_learning_rate()
return tf.maximum(tf.where(
tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
tf.constant(burnin_learning_rate),
post_burnin_learning_rate), min_learning_rate, name='learning_rate')
if tf.executing_eagerly():
return eager_decay_rate
else:
return eager_decay_rate()
def cosine_decay_with_warmup(global_step,
......@@ -88,7 +102,11 @@ def cosine_decay_with_warmup(global_step,
before decaying.
Returns:
a (scalar) float tensor representing learning rate.
If executing eagerly:
returns a no-arg callable that outputs the (scalar)
float tensor learning rate given the current value of global_step.
If in a graph:
immediately returns a (scalar) float tensor representing learning rate.
Raises:
ValueError: if warmup_learning_rate is larger than learning_rate_base,
......@@ -97,24 +115,32 @@ def cosine_decay_with_warmup(global_step,
if total_steps < warmup_steps:
raise ValueError('total_steps must be larger or equal to '
'warmup_steps.')
learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
np.pi *
(tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps
) / float(total_steps - warmup_steps - hold_base_rate_steps)))
if hold_base_rate_steps > 0:
learning_rate = tf.where(global_step > warmup_steps + hold_base_rate_steps,
learning_rate, learning_rate_base)
if warmup_steps > 0:
if learning_rate_base < warmup_learning_rate:
raise ValueError('learning_rate_base must be larger or equal to '
'warmup_learning_rate.')
slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
warmup_rate = slope * tf.cast(global_step,
tf.float32) + warmup_learning_rate
learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
learning_rate)
return tf.where(global_step > total_steps, 0.0, learning_rate,
name='learning_rate')
def eager_decay_rate():
"""Callable to compute the learning rate."""
learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
np.pi *
(tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps
) / float(total_steps - warmup_steps - hold_base_rate_steps)))
if hold_base_rate_steps > 0:
learning_rate = tf.where(
global_step > warmup_steps + hold_base_rate_steps,
learning_rate, learning_rate_base)
if warmup_steps > 0:
if learning_rate_base < warmup_learning_rate:
raise ValueError('learning_rate_base must be larger or equal to '
'warmup_learning_rate.')
slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
warmup_rate = slope * tf.cast(global_step,
tf.float32) + warmup_learning_rate
learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
learning_rate)
return tf.where(global_step > total_steps, 0.0, learning_rate,
name='learning_rate')
if tf.executing_eagerly():
return eager_decay_rate
else:
return eager_decay_rate()
def manual_stepping(global_step, boundaries, rates, warmup=False):
......@@ -138,7 +164,11 @@ def manual_stepping(global_step, boundaries, rates, warmup=False):
[0, boundaries[0]].
Returns:
a (scalar) float tensor representing learning rate
If executing eagerly:
returns a no-arg callable that outputs the (scalar)
float tensor learning rate given the current value of global_step.
If in a graph:
immediately returns a (scalar) float tensor representing learning rate.
Raises:
ValueError: if one of the following checks fails:
1. boundaries is a strictly increasing list of positive integers
......@@ -168,8 +198,16 @@ def manual_stepping(global_step, boundaries, rates, warmup=False):
else:
boundaries = [0] + boundaries
num_boundaries = len(boundaries)
rate_index = tf.reduce_max(tf.where(tf.greater_equal(global_step, boundaries),
list(range(num_boundaries)),
[0] * num_boundaries))
return tf.reduce_sum(rates * tf.one_hot(rate_index, depth=num_boundaries),
name='learning_rate')
def eager_decay_rate():
"""Callable to compute the learning rate."""
rate_index = tf.reduce_max(tf.where(
tf.greater_equal(global_step, boundaries),
list(range(num_boundaries)),
[0] * num_boundaries))
return tf.reduce_sum(rates * tf.one_hot(rate_index, depth=num_boundaries),
name='learning_rate')
if tf.executing_eagerly():
return eager_decay_rate
else:
return eager_decay_rate()
......@@ -70,6 +70,26 @@ class DetectionEvaluator(object):
"""
self._categories = categories
def observe_result_dict_for_single_example(self, eval_dict):
"""Observes an evaluation result dict for a single example.
When executing eagerly, once all observations have been observed by this
method you can use `.evaluate()` to get the final metrics.
When using `tf.estimator.Estimator` for evaluation this function is used by
`get_estimator_eval_metric_ops()` to construct the metric update op.
Args:
eval_dict: A dictionary that holds tensors for evaluating an object
detection model, returned from
eval_util.result_dict_for_single_example().
Returns:
None when executing eagerly, or an update_op that can be used to update
the eval metrics in `tf.estimator.EstimatorSpec`.
"""
raise NotImplementedError('Not implemented for this evaluator!')
@abstractmethod
def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
"""Adds groundtruth for a single image to be used for evaluation.
......@@ -126,6 +146,8 @@ class ObjectDetectionEvaluator(DetectionEvaluator):
def __init__(self,
categories,
matching_iou_threshold=0.5,
recall_lower_bound=0.0,
recall_upper_bound=1.0,
evaluate_corlocs=False,
evaluate_precision_recall=False,
metric_prefix=None,
......@@ -140,6 +162,8 @@ class ObjectDetectionEvaluator(DetectionEvaluator):
'name': (required) string representing category name e.g., 'cat', 'dog'.
matching_iou_threshold: IOU threshold to use for matching groundtruth
boxes to detection boxes.
recall_lower_bound: lower bound of recall operating area.
recall_upper_bound: upper bound of recall operating area.
evaluate_corlocs: (optional) boolean which determines if corloc scores
are to be returned or not.
evaluate_precision_recall: (optional) boolean which determines if
......@@ -166,6 +190,8 @@ class ObjectDetectionEvaluator(DetectionEvaluator):
if min(cat['id'] for cat in categories) < 1:
raise ValueError('Classes should be 1-indexed.')
self._matching_iou_threshold = matching_iou_threshold
self._recall_lower_bound = recall_lower_bound
self._recall_upper_bound = recall_upper_bound
self._use_weighted_mean_ap = use_weighted_mean_ap
self._label_id_offset = 1
self._evaluate_masks = evaluate_masks
......@@ -173,6 +199,8 @@ class ObjectDetectionEvaluator(DetectionEvaluator):
self._evaluation = ObjectDetectionEvaluation(
num_groundtruth_classes=self._num_classes,
matching_iou_threshold=self._matching_iou_threshold,
recall_lower_bound=self._recall_lower_bound,
recall_upper_bound=self._recall_upper_bound,
use_weighted_mean_ap=self._use_weighted_mean_ap,
label_id_offset=self._label_id_offset,
group_of_weight=self._group_of_weight)
......@@ -195,11 +223,18 @@ class ObjectDetectionEvaluator(DetectionEvaluator):
def _build_metric_names(self):
"""Builds a list with metric names."""
self._metric_names = [
self._metric_prefix + 'Precision/mAP@{}IOU'.format(
self._matching_iou_threshold)
]
if self._recall_lower_bound > 0.0 or self._recall_upper_bound < 1.0:
self._metric_names = [
self._metric_prefix +
'Precision/mAP@{}IOU@[{:.1f},{:.1f}]Recall'.format(
self._matching_iou_threshold, self._recall_lower_bound,
self._recall_upper_bound)
]
else:
self._metric_names = [
self._metric_prefix +
'Precision/mAP@{}IOU'.format(self._matching_iou_threshold)
]
if self._evaluate_corlocs:
self._metric_names.append(
self._metric_prefix +
......@@ -493,6 +528,24 @@ class WeightedPascalDetectionEvaluator(ObjectDetectionEvaluator):
use_weighted_mean_ap=True)
class PrecisionAtRecallDetectionEvaluator(ObjectDetectionEvaluator):
"""A class to evaluate detections using precision@recall metrics."""
def __init__(self,
categories,
matching_iou_threshold=0.5,
recall_lower_bound=0.0,
recall_upper_bound=1.0):
super(PrecisionAtRecallDetectionEvaluator, self).__init__(
categories,
matching_iou_threshold=matching_iou_threshold,
recall_lower_bound=recall_lower_bound,
recall_upper_bound=recall_upper_bound,
evaluate_corlocs=False,
metric_prefix='PrecisionAtRecallBoxes',
use_weighted_mean_ap=False)
class PascalInstanceSegmentationEvaluator(ObjectDetectionEvaluator):
"""A class to evaluate instance masks using PASCAL metrics."""
......@@ -540,6 +593,7 @@ class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
def __init__(self,
categories,
matching_iou_threshold=0.5,
evaluate_masks=False,
evaluate_corlocs=False,
metric_prefix='OpenImagesV2',
group_of_weight=0.0):
......@@ -551,6 +605,7 @@ class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
'name': (required) string representing category name e.g., 'cat', 'dog'.
matching_iou_threshold: IOU threshold to use for matching groundtruth
boxes to detection boxes.
evaluate_masks: if True, evaluator evaluates masks.
evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
metric_prefix: Prefix name of the metric.
group_of_weight: Weight of the group-of bounding box. If set to 0 (default
......@@ -561,12 +616,14 @@ class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
detection falls within a group-of box, weight group_of_weight is added
to false negatives.
"""
super(OpenImagesDetectionEvaluator, self).__init__(
categories,
matching_iou_threshold,
evaluate_corlocs,
metric_prefix=metric_prefix,
group_of_weight=group_of_weight)
group_of_weight=group_of_weight,
evaluate_masks=evaluate_masks)
self._expected_keys = set([
standard_fields.InputDataFields.key,
standard_fields.InputDataFields.groundtruth_boxes,
......@@ -576,6 +633,11 @@ class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
standard_fields.DetectionResultFields.detection_scores,
standard_fields.DetectionResultFields.detection_classes,
])
if evaluate_masks:
self._expected_keys.add(
standard_fields.InputDataFields.groundtruth_instance_masks)
self._expected_keys.add(
standard_fields.DetectionResultFields.detection_masks)
def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
"""Adds groundtruth for a single image to be used for evaluation.
......@@ -617,17 +679,26 @@ class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
logging.warn(
'image %s does not have groundtruth group_of flag specified',
image_id)
if self._evaluate_masks:
groundtruth_masks = groundtruth_dict[
standard_fields.InputDataFields.groundtruth_instance_masks]
else:
groundtruth_masks = None
self._evaluation.add_single_ground_truth_image_info(
image_id,
groundtruth_dict[standard_fields.InputDataFields.groundtruth_boxes],
groundtruth_classes,
groundtruth_is_difficult_list=None,
groundtruth_is_group_of_list=groundtruth_group_of)
groundtruth_is_group_of_list=groundtruth_group_of,
groundtruth_masks=groundtruth_masks)
self._image_ids.update([image_id])
class OpenImagesDetectionChallengeEvaluator(OpenImagesDetectionEvaluator):
"""A class implements Open Images Challenge Detection metrics.
class OpenImagesChallengeEvaluator(OpenImagesDetectionEvaluator):
"""A class implements Open Images Challenge metrics.
Both Detection and Instance Segmentation evaluation metrics are implemented.
Open Images Challenge Detection metric has two major changes in comparison
with Open Images V2 detection metric:
......@@ -637,10 +708,25 @@ class OpenImagesDetectionChallengeEvaluator(OpenImagesDetectionEvaluator):
evaluation: in case an image has neither positive nor negative image-level
label of class c, all detections of this class on this image will be
ignored.
Open Images Challenge Instance Segmentation metric allows measuring the
performance of models in the case of incomplete annotations: some instances
are annotated only at box level and some only at image level. In addition,
image-level labels are taken into account as in the detection metric.
Open Images Challenge Detection metric default parameters:
evaluate_masks = False
group_of_weight = 1.0
Open Images Challenge Instance Segmentation metric default parameters:
evaluate_masks = True
(group_of_weight will not matter)
"""
def __init__(self,
categories,
evaluate_masks=False,
matching_iou_threshold=0.5,
evaluate_corlocs=False,
group_of_weight=1.0):
......@@ -650,35 +736,34 @@ class OpenImagesDetectionChallengeEvaluator(OpenImagesDetectionEvaluator):
categories: A list of dicts, each of which has the following keys -
'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name e.g., 'cat', 'dog'.
evaluate_masks: set to true for instance segmentation metric and to false
for detection metric.
matching_iou_threshold: IOU threshold to use for matching groundtruth
boxes to detection boxes.
evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
group_of_weight: weight of a group-of box. If set to 0, detections of the
correct class within a group-of box are ignored. If weight is > 0
(default for Open Images Detection Challenge 2018), then if at least one
(default for Open Images Detection Challenge), then if at least one
detection falls within a group-of box with matching_iou_threshold,
weight group_of_weight is added to true positives. Consequently, if no
detection falls within a group-of box, weight group_of_weight is added
to false negatives.
"""
super(OpenImagesDetectionChallengeEvaluator, self).__init__(
if not evaluate_masks:
metrics_prefix = 'OpenImagesDetectionChallenge'
else:
metrics_prefix = 'OpenImagesInstanceSegmentationChallenge'
super(OpenImagesChallengeEvaluator, self).__init__(
categories,
matching_iou_threshold,
evaluate_corlocs,
metric_prefix='OpenImagesChallenge2018',
group_of_weight=group_of_weight)
evaluate_masks=evaluate_masks,
evaluate_corlocs=evaluate_corlocs,
group_of_weight=group_of_weight,
metric_prefix=metrics_prefix)
self._evaluatable_labels = {}
self._expected_keys = set([
standard_fields.InputDataFields.key,
standard_fields.InputDataFields.groundtruth_boxes,
standard_fields.InputDataFields.groundtruth_classes,
standard_fields.InputDataFields.groundtruth_group_of,
standard_fields.InputDataFields.groundtruth_image_classes,
standard_fields.DetectionResultFields.detection_boxes,
standard_fields.DetectionResultFields.detection_scores,
standard_fields.DetectionResultFields.detection_classes,
])
self._expected_keys.add(
standard_fields.InputDataFields.groundtruth_image_classes)
def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
"""Adds groundtruth for a single image to be used for evaluation.
......@@ -701,7 +786,7 @@ class OpenImagesDetectionChallengeEvaluator(OpenImagesDetectionEvaluator):
Raises:
ValueError: On adding groundtruth for an image more than once.
"""
super(OpenImagesDetectionChallengeEvaluator,
super(OpenImagesChallengeEvaluator,
self).add_single_ground_truth_image_info(image_id, groundtruth_dict)
groundtruth_classes = (
groundtruth_dict[standard_fields.InputDataFields.groundtruth_classes] -
......@@ -747,16 +832,22 @@ class OpenImagesDetectionChallengeEvaluator(OpenImagesDetectionEvaluator):
detected_scores = detections_dict[
standard_fields.DetectionResultFields.detection_scores][allowed_classes]
if self._evaluate_masks:
detection_masks = detections_dict[standard_fields.DetectionResultFields
.detection_masks][allowed_classes]
else:
detection_masks = None
self._evaluation.add_single_detected_image_info(
image_key=image_id,
detected_boxes=detected_boxes,
detected_scores=detected_scores,
detected_class_labels=detection_classes)
detected_class_labels=detection_classes,
detected_masks=detection_masks)
def clear(self):
"""Clears stored data."""
super(OpenImagesDetectionChallengeEvaluator, self).clear()
super(OpenImagesChallengeEvaluator, self).clear()
self._evaluatable_labels.clear()
......@@ -767,6 +858,73 @@ ObjectDetectionEvalMetrics = collections.namedtuple(
])
class OpenImagesDetectionChallengeEvaluator(OpenImagesChallengeEvaluator):
"""A class implements Open Images Detection Challenge metric."""
def __init__(self,
categories,
matching_iou_threshold=0.5,
evaluate_corlocs=False,
group_of_weight=1.0):
"""Constructor.
Args:
categories: A list of dicts, each of which has the following keys -
'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name e.g., 'cat', 'dog'.
matching_iou_threshold: IOU threshold to use for matching groundtruth
boxes to detection boxes.
evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
group_of_weight: weight of a group-of box. If set to 0, detections of the
correct class within a group-of box are ignored. If weight is > 0
(default for Open Images Detection Challenge), then if at least one
detection falls within a group-of box with matching_iou_threshold,
weight group_of_weight is added to true positives. Consequently, if no
detection falls within a group-of box, weight group_of_weight is added
to false negatives.
"""
super(OpenImagesDetectionChallengeEvaluator, self).__init__(
categories=categories,
evaluate_masks=False,
matching_iou_threshold=matching_iou_threshold,
evaluate_corlocs=False,
group_of_weight=1.0)
class OpenImagesInstanceSegmentationChallengeEvaluator(
OpenImagesChallengeEvaluator):
"""A class implements Open Images Instance Segmentation Challenge metric."""
def __init__(self,
categories,
matching_iou_threshold=0.5,
evaluate_corlocs=False,
group_of_weight=1.0):
"""Constructor.
Args:
categories: A list of dicts, each of which has the following keys -
'id': (required) an integer id uniquely identifying this category.
'name': (required) string representing category name e.g., 'cat', 'dog'.
matching_iou_threshold: IOU threshold to use for matching groundtruth
boxes to detection boxes.
evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
group_of_weight: weight of a group-of box. If set to 0, detections of the
correct class within a group-of box are ignored. If weight is > 0
(default for Open Images Detection Challenge), then if at least one
detection falls within a group-of box with matching_iou_threshold,
weight group_of_weight is added to true positives. Consequently, if no
detection falls within a group-of box, weight group_of_weight is added
to false negatives.
"""
super(OpenImagesInstanceSegmentationChallengeEvaluator, self).__init__(
categories=categories,
evaluate_masks=True,
matching_iou_threshold=matching_iou_threshold,
evaluate_corlocs=False,
group_of_weight=1.0)
class ObjectDetectionEvaluation(object):
"""Internal implementation of Pascal object detection metrics."""
......@@ -775,6 +933,8 @@ class ObjectDetectionEvaluation(object):
matching_iou_threshold=0.5,
nms_iou_threshold=1.0,
nms_max_output_boxes=10000,
recall_lower_bound=0.0,
recall_upper_bound=1.0,
use_weighted_mean_ap=False,
label_id_offset=0,
group_of_weight=0.0,
......@@ -788,6 +948,8 @@ class ObjectDetectionEvaluation(object):
nms_iou_threshold: IOU threshold used for non-maximum suppression.
nms_max_output_boxes: Maximum number of boxes returned by non-maximum
suppression.
recall_lower_bound: lower bound of recall operating area
recall_upper_bound: upper bound of recall operating area
use_weighted_mean_ap: (optional) boolean which determines if the mean
average precision is computed directly from the scores and tp_fp_labels
of all classes.
......@@ -813,6 +975,8 @@ class ObjectDetectionEvaluation(object):
nms_iou_threshold=nms_iou_threshold,
nms_max_output_boxes=nms_max_output_boxes,
group_of_weight=group_of_weight)
self.recall_lower_bound = recall_lower_bound
self.recall_upper_bound = recall_upper_bound
self.group_of_weight = group_of_weight
self.num_class = num_groundtruth_classes
self.use_weighted_mean_ap = use_weighted_mean_ap
......@@ -1036,10 +1200,17 @@ class ObjectDetectionEvaluation(object):
all_tp_fp_labels = np.append(all_tp_fp_labels, tp_fp_labels)
precision, recall = metrics.compute_precision_recall(
scores, tp_fp_labels, self.num_gt_instances_per_class[class_index])
self.precisions_per_class[class_index] = precision
self.recalls_per_class[class_index] = recall
average_precision = metrics.compute_average_precision(precision, recall)
recall_within_bound_indices = [
index for index, value in enumerate(recall) if
value >= self.recall_lower_bound and value <= self.recall_upper_bound
]
recall_within_bound = recall[recall_within_bound_indices]
precision_within_bound = precision[recall_within_bound_indices]
self.precisions_per_class[class_index] = precision_within_bound
self.recalls_per_class[class_index] = recall_within_bound
average_precision = metrics.compute_average_precision(
precision_within_bound, recall_within_bound)
self.average_precision_per_class[class_index] = average_precision
logging.info('average_precision: %f', average_precision)
......@@ -1051,7 +1222,14 @@ class ObjectDetectionEvaluation(object):
num_gt_instances = np.sum(self.num_gt_instances_per_class)
precision, recall = metrics.compute_precision_recall(
all_scores, all_tp_fp_labels, num_gt_instances)
mean_ap = metrics.compute_average_precision(precision, recall)
recall_within_bound_indices = [
index for index, value in enumerate(recall) if
value >= self.recall_lower_bound and value <= self.recall_upper_bound
]
recall_within_bound = recall[recall_within_bound_indices]
precision_within_bound = precision[recall_within_bound_indices]
mean_ap = metrics.compute_average_precision(precision_within_bound,
recall_within_bound)
else:
mean_ap = np.nanmean(self.average_precision_per_class)
mean_corloc = np.nanmean(self.corloc_per_class)
......
......@@ -101,9 +101,9 @@ class OpenImagesV2EvaluationTest(tf.test.TestCase):
self.assertFalse(oiv2_evaluator._image_ids)
class OpenImagesDetectionChallengeEvaluatorTest(tf.test.TestCase):
class OpenImagesChallengeEvaluatorTest(tf.test.TestCase):
def test_returns_correct_metric_values(self):
def test_returns_correct_detection_metric_values(self):
categories = [{
'id': 1,
'name': 'cat'
......@@ -115,8 +115,8 @@ class OpenImagesDetectionChallengeEvaluatorTest(tf.test.TestCase):
'name': 'elephant'
}]
oivchallenge_evaluator = (
object_detection_evaluation.OpenImagesDetectionChallengeEvaluator(
categories, group_of_weight=0.5))
object_detection_evaluation.OpenImagesChallengeEvaluator(
categories, evaluate_masks=False, group_of_weight=0.5))
image_key = 'img1'
groundtruth_boxes = np.array(
......@@ -203,19 +203,124 @@ class OpenImagesDetectionChallengeEvaluatorTest(tf.test.TestCase):
detected_class_labels
})
metrics = oivchallenge_evaluator.evaluate()
expected_metric_name = 'OpenImagesDetectionChallenge'
self.assertAlmostEqual(
metrics['OpenImagesChallenge2018_PerformanceByCategory/AP@0.5IOU/dog'],
metrics[
expected_metric_name + '_PerformanceByCategory/AP@0.5IOU/dog'],
0.3333333333)
self.assertAlmostEqual(
metrics[
'OpenImagesChallenge2018_PerformanceByCategory/AP@0.5IOU/elephant'],
expected_metric_name + '_PerformanceByCategory/AP@0.5IOU/elephant'],
0.333333333333)
self.assertAlmostEqual(
metrics['OpenImagesChallenge2018_PerformanceByCategory/AP@0.5IOU/cat'],
metrics[
expected_metric_name + '_PerformanceByCategory/AP@0.5IOU/cat'],
0.142857142857)
self.assertAlmostEqual(
metrics['OpenImagesChallenge2018_Precision/mAP@0.5IOU'], 0.269841269)
metrics[expected_metric_name + '_Precision/mAP@0.5IOU'],
0.269841269)
oivchallenge_evaluator.clear()
self.assertFalse(oivchallenge_evaluator._image_ids)
def test_returns_correct_instance_segm_metric_values(self):
categories = [{'id': 1, 'name': 'cat'}, {'id': 2, 'name': 'dog'}]
oivchallenge_evaluator = (
object_detection_evaluation.OpenImagesChallengeEvaluator(
categories, evaluate_masks=True))
image_key = 'img1'
groundtruth_boxes = np.array([[0, 0, 1, 1], [0, 0, 2, 2], [0, 0, 3, 3]],
dtype=float)
groundtruth_class_labels = np.array([1, 2, 1], dtype=int)
groundtruth_is_group_of_list = np.array([False, False, True], dtype=bool)
groundtruth_verified_labels = np.array([1, 2, 3], dtype=int)
groundtruth_mask_0 = np.array([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
zero_mask = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_masks = np.stack([groundtruth_mask_0, zero_mask, zero_mask],
axis=0)
oivchallenge_evaluator.add_single_ground_truth_image_info(
image_key, {
standard_fields.InputDataFields.groundtruth_boxes:
groundtruth_boxes,
standard_fields.InputDataFields.groundtruth_classes:
groundtruth_class_labels,
standard_fields.InputDataFields.groundtruth_group_of:
groundtruth_is_group_of_list,
standard_fields.InputDataFields.groundtruth_image_classes:
groundtruth_verified_labels,
standard_fields.InputDataFields.groundtruth_instance_masks:
groundtruth_masks
})
image_key = 'img3'
groundtruth_boxes = np.array([[0, 0, 1, 1]], dtype=float)
groundtruth_class_labels = np.array([2], dtype=int)
groundtruth_mask_0 = np.array([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_masks = np.stack([groundtruth_mask_0], axis=0)
oivchallenge_evaluator.add_single_ground_truth_image_info(
image_key, {
standard_fields.InputDataFields.groundtruth_boxes:
groundtruth_boxes,
standard_fields.InputDataFields.groundtruth_classes:
groundtruth_class_labels,
standard_fields.InputDataFields.groundtruth_instance_masks:
groundtruth_masks
})
image_key = 'img1'
detected_boxes = np.array([[0, 0, 2, 2], [2, 2, 3, 3]], dtype=float)
detection_mask_0 = np.array([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks = np.stack([detection_mask_0, zero_mask], axis=0)
detected_class_labels = np.array([2, 1], dtype=int)
detected_scores = np.array([0.7, 0.8], dtype=float)
oivchallenge_evaluator.add_single_detected_image_info(
image_key, {
standard_fields.DetectionResultFields.detection_boxes:
detected_boxes,
standard_fields.DetectionResultFields.detection_scores:
detected_scores,
standard_fields.DetectionResultFields.detection_classes:
detected_class_labels,
standard_fields.DetectionResultFields.detection_masks:
detected_masks
})
image_key = 'img3'
detected_boxes = np.array([[0, 0, 1, 1]], dtype=float)
detected_class_labels = np.array([2], dtype=int)
detected_scores = np.array([0.5], dtype=float)
detected_mask_0 = np.array([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks = np.stack([detected_mask_0], axis=0)
oivchallenge_evaluator.add_single_detected_image_info(
image_key, {
standard_fields.DetectionResultFields.detection_boxes:
detected_boxes,
standard_fields.DetectionResultFields.detection_scores:
detected_scores,
standard_fields.DetectionResultFields.detection_classes:
detected_class_labels,
standard_fields.DetectionResultFields.detection_masks:
detected_masks
})
metrics = oivchallenge_evaluator.evaluate()
expected_metric_name = 'OpenImagesInstanceSegmentationChallenge'
self.assertAlmostEqual(
metrics[
expected_metric_name + '_PerformanceByCategory/AP@0.5IOU/dog'],
0.5)
self.assertAlmostEqual(
metrics[
expected_metric_name + '_PerformanceByCategory/AP@0.5IOU/cat'],
0)
self.assertAlmostEqual(
metrics[
expected_metric_name + '_Precision/mAP@0.5IOU'],
0.25)
oivchallenge_evaluator.clear()
self.assertFalse(oivchallenge_evaluator._image_ids)
......@@ -572,6 +677,157 @@ class WeightedPascalEvaluationTest(tf.test.TestCase):
groundtruth_class_labels1})
class PrecisionAtRecallEvaluationTest(tf.test.TestCase):
def setUp(self):
self.categories = [{
'id': 1,
'name': 'cat'
}, {
'id': 2,
'name': 'dog'
}, {
'id': 3,
'name': 'elephant'
}]
def create_and_add_common_ground_truth(self):
# Add groundtruth
self.wp_eval = (
object_detection_evaluation.PrecisionAtRecallDetectionEvaluator(
self.categories, recall_lower_bound=0.0, recall_upper_bound=0.5))
image_key1 = 'img1'
groundtruth_boxes1 = np.array([[0, 0, 1, 1], [0, 0, 2, 2], [0, 0, 3, 3]],
dtype=float)
groundtruth_class_labels1 = np.array([1, 3, 1], dtype=int)
self.wp_eval.add_single_ground_truth_image_info(
image_key1, {
standard_fields.InputDataFields.groundtruth_boxes:
groundtruth_boxes1,
standard_fields.InputDataFields.groundtruth_classes:
groundtruth_class_labels1
})
# add 'img2' separately
image_key3 = 'img3'
groundtruth_boxes3 = np.array([[0, 0, 1, 1]], dtype=float)
groundtruth_class_labels3 = np.array([2], dtype=int)
self.wp_eval.add_single_ground_truth_image_info(
image_key3, {
standard_fields.InputDataFields.groundtruth_boxes:
groundtruth_boxes3,
standard_fields.InputDataFields.groundtruth_classes:
groundtruth_class_labels3
})
def add_common_detected(self):
image_key = 'img2'
detected_boxes = np.array(
[[10, 10, 11, 11], [100, 100, 120, 120], [100, 100, 220, 220]],
dtype=float)
detected_class_labels = np.array([1, 1, 3], dtype=int)
detected_scores = np.array([0.7, 0.8, 0.9], dtype=float)
self.wp_eval.add_single_detected_image_info(
image_key, {
standard_fields.DetectionResultFields.detection_boxes:
detected_boxes,
standard_fields.DetectionResultFields.detection_scores:
detected_scores,
standard_fields.DetectionResultFields.detection_classes:
detected_class_labels
})
def test_returns_correct_metric_values(self):
self.create_and_add_common_ground_truth()
image_key2 = 'img2'
groundtruth_boxes2 = np.array(
[[10, 10, 11, 11], [500, 500, 510, 510], [10, 10, 12, 12]], dtype=float)
groundtruth_class_labels2 = np.array([1, 1, 3], dtype=int)
self.wp_eval.add_single_ground_truth_image_info(
image_key2, {
standard_fields.InputDataFields.groundtruth_boxes:
groundtruth_boxes2,
standard_fields.InputDataFields.groundtruth_classes:
groundtruth_class_labels2
})
self.add_common_detected()
metrics = self.wp_eval.evaluate()
self.assertAlmostEqual(
metrics[self.wp_eval._metric_prefix +
'PerformanceByCategory/AP@0.5IOU/dog'], 0.0)
self.assertAlmostEqual(
metrics[self.wp_eval._metric_prefix +
'PerformanceByCategory/AP@0.5IOU/elephant'], 0.0)
self.assertAlmostEqual(
metrics[self.wp_eval._metric_prefix +
'PerformanceByCategory/AP@0.5IOU/cat'], 0.5 / 4)
self.assertAlmostEqual(
metrics[self.wp_eval._metric_prefix +
'Precision/mAP@0.5IOU@[0.0,0.5]Recall'], 1. / (3 + 1 + 2) / 4)
self.wp_eval.clear()
self.assertFalse(self.wp_eval._image_ids)
def test_returns_correct_metric_values_with_difficult_list(self):
self.create_and_add_common_ground_truth()
image_key2 = 'img2'
groundtruth_boxes2 = np.array(
[[10, 10, 11, 11], [500, 500, 510, 510], [10, 10, 12, 12]], dtype=float)
groundtruth_class_labels2 = np.array([1, 1, 3], dtype=int)
groundtruth_is_difficult_list2 = np.array([False, True, False], dtype=bool)
self.wp_eval.add_single_ground_truth_image_info(
image_key2, {
standard_fields.InputDataFields.groundtruth_boxes:
groundtruth_boxes2,
standard_fields.InputDataFields.groundtruth_classes:
groundtruth_class_labels2,
standard_fields.InputDataFields.groundtruth_difficult:
groundtruth_is_difficult_list2
})
self.add_common_detected()
metrics = self.wp_eval.evaluate()
self.assertAlmostEqual(
metrics[self.wp_eval._metric_prefix +
'PerformanceByCategory/AP@0.5IOU/dog'], 0.0)
self.assertAlmostEqual(
metrics[self.wp_eval._metric_prefix +
'PerformanceByCategory/AP@0.5IOU/elephant'], 0.0)
self.assertAlmostEqual(
metrics[self.wp_eval._metric_prefix +
'PerformanceByCategory/AP@0.5IOU/cat'], 0.5 / 3)
self.assertAlmostEqual(
metrics[self.wp_eval._metric_prefix +
'Precision/mAP@0.5IOU@[0.0,0.5]Recall'], 1. / (3 + 1 + 2) / 3)
self.wp_eval.clear()
self.assertFalse(self.wp_eval._image_ids)
def test_value_error_on_duplicate_images(self):
# Add groundtruth
self.wp_eval = (
object_detection_evaluation.PrecisionAtRecallDetectionEvaluator(
self.categories, recall_lower_bound=0.0, recall_upper_bound=0.5))
image_key1 = 'img1'
groundtruth_boxes1 = np.array([[0, 0, 1, 1], [0, 0, 2, 2], [0, 0, 3, 3]],
dtype=float)
groundtruth_class_labels1 = np.array([1, 3, 1], dtype=int)
self.wp_eval.add_single_ground_truth_image_info(
image_key1, {
standard_fields.InputDataFields.groundtruth_boxes:
groundtruth_boxes1,
standard_fields.InputDataFields.groundtruth_classes:
groundtruth_class_labels1
})
with self.assertRaises(ValueError):
self.wp_eval.add_single_ground_truth_image_info(
image_key1, {
standard_fields.InputDataFields.groundtruth_boxes:
groundtruth_boxes1,
standard_fields.InputDataFields.groundtruth_classes:
groundtruth_class_labels1
})
class ObjectDetectionEvaluationTest(tf.test.TestCase):
def setUp(self):
......
......@@ -22,9 +22,15 @@ import tensorflow as tf
from object_detection.core import standard_fields as fields
from object_detection.utils import shape_utils
from object_detection.utils import spatial_transform_ops as spatial_ops
from object_detection.utils import static_shape
matmul_crop_and_resize = spatial_ops.matmul_crop_and_resize
multilevel_roi_align = spatial_ops.multilevel_roi_align
native_crop_and_resize = spatial_ops.native_crop_and_resize
def expanded_shape(orig_shape, start_dim, num_dims):
"""Inserts multiple ones into a shape vector.
......@@ -176,16 +182,22 @@ def pad_to_multiple(tensor, multiple):
if tensor_height is None:
tensor_height = tf.shape(tensor)[1]
padded_tensor_height = tf.to_int32(
tf.ceil(tf.to_float(tensor_height) / tf.to_float(multiple))) * multiple
padded_tensor_height = tf.cast(
tf.ceil(
tf.cast(tensor_height, dtype=tf.float32) /
tf.cast(multiple, dtype=tf.float32)),
dtype=tf.int32) * multiple
else:
padded_tensor_height = int(
math.ceil(float(tensor_height) / multiple) * multiple)
if tensor_width is None:
tensor_width = tf.shape(tensor)[2]
padded_tensor_width = tf.to_int32(
tf.ceil(tf.to_float(tensor_width) / tf.to_float(multiple))) * multiple
padded_tensor_width = tf.cast(
tf.ceil(
tf.cast(tensor_width, dtype=tf.float32) /
tf.cast(multiple, dtype=tf.float32)),
dtype=tf.int32) * multiple
else:
padded_tensor_width = int(
math.ceil(float(tensor_width) / multiple) * multiple)
......@@ -309,11 +321,11 @@ def indices_to_dense_vector(indices,
dense 1D Tensor of shape [size] with indices set to indices_values and the
rest set to default_value.
"""
size = tf.to_int32(size)
size = tf.cast(size, dtype=tf.int32)
zeros = tf.ones([size], dtype=dtype) * default_value
values = tf.ones_like(indices, dtype=dtype) * indices_value
return tf.dynamic_stitch([tf.range(size), tf.to_int32(indices)],
return tf.dynamic_stitch([tf.range(size), tf.cast(indices, dtype=tf.int32)],
[zeros, values])
......@@ -469,8 +481,8 @@ def filter_groundtruth_with_nan_box_coordinates(tensor_dict):
boxes.
"""
groundtruth_boxes = tensor_dict[fields.InputDataFields.groundtruth_boxes]
nan_indicator_vector = tf.greater(tf.reduce_sum(tf.to_int32(
tf.is_nan(groundtruth_boxes)), reduction_indices=[1]), 0)
nan_indicator_vector = tf.greater(tf.reduce_sum(tf.cast(
tf.is_nan(groundtruth_boxes), dtype=tf.int32), reduction_indices=[1]), 0)
valid_indicator_vector = tf.logical_not(nan_indicator_vector)
valid_indices = tf.where(valid_indicator_vector)
......@@ -576,7 +588,6 @@ def normalize_to_target(inputs,
trainable=trainable)
if summarize:
mean = tf.reduce_mean(target_norm)
mean = tf.Print(mean, ['NormalizeToTarget:', mean])
tf.summary.scalar(tf.get_variable_scope().name, mean)
lengths = epsilon + tf.sqrt(tf.reduce_sum(tf.square(inputs), dim, True))
mult_shape = input_rank*[1]
......@@ -754,7 +765,7 @@ def position_sensitive_crop_regions(image,
position_sensitive_features = tf.add_n(image_crops) / len(image_crops)
# Then average over spatial positions within the bins.
position_sensitive_features = tf.reduce_mean(
position_sensitive_features, [1, 2], keep_dims=True)
position_sensitive_features, [1, 2], keepdims=True)
else:
# Reorder height/width to depth channel.
block_size = bin_crop_size[0]
......@@ -770,7 +781,7 @@ def position_sensitive_crop_regions(image,
tf.batch_to_space_nd(position_sensitive_features,
block_shape=[1] + num_spatial_bins,
crops=tf.zeros((3, 2), dtype=tf.int32)),
squeeze_dims=[0])
axis=[0])
# Reorder back the depth channel.
if block_size >= 2:
......@@ -908,7 +919,7 @@ def merge_boxes_with_multiple_labels(boxes,
dtype=(tf.int64, tf.float32))
merged_boxes = tf.reshape(merged_boxes, [-1, 4])
class_encodings = tf.to_int32(class_encodings)
class_encodings = tf.cast(class_encodings, dtype=tf.int32)
class_encodings = tf.reshape(class_encodings, [-1, num_classes])
confidence_encodings = tf.reshape(confidence_encodings, [-1, num_classes])
merged_box_indices = tf.reshape(merged_box_indices, [-1])
......@@ -983,132 +994,39 @@ def matmul_gather_on_zeroth_axis(params, indices, scope=None):
tf.stack(indices_shape + params_shape[1:]))
def matmul_crop_and_resize(image, boxes, crop_size, scope=None):
"""Matrix multiplication based implementation of the crop and resize op.
Extracts crops from the input image tensor and bilinearly resizes them
(possibly with aspect ratio change) to a common output size specified by
crop_size. This is more general than the crop_to_bounding_box op which
extracts a fixed size slice from the input image and does not allow
resizing or aspect ratio change.
Returns a tensor with crops from the input image at positions defined at
the bounding box locations in boxes. The cropped boxes are all resized
(with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`.
The result is a 5-D tensor `[batch, num_boxes, crop_height, crop_width,
depth]`.
Running time complexity:
O((# channels) * (# boxes) * (crop_size)^2 * M), where M is the number
of pixels of the longer edge of the image.
Note that this operation is meant to replicate the behavior of the standard
tf.image.crop_and_resize operation but there are a few differences.
Specifically:
1) The extrapolation value (the values that are interpolated from outside
the bounds of the image window) is always zero
2) Only XLA supported operations are used (e.g., matrix multiplication).
3) There is no `box_indices` argument --- to run this op on multiple images,
one must currently call this op independently on each image.
4) The `crop_size` parameter is assumed to be statically defined.
Moreover, the number of boxes must be strictly nonzero.
def fpn_feature_levels(num_levels, unit_scale_index, image_ratio, boxes):
"""Returns fpn feature level for each box based on its area.
See section 4.2 of https://arxiv.org/pdf/1612.03144.pdf for details.
Args:
image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
`int16`, `int32`, `int64`, `half`, 'bfloat16', `float32`, `float64`.
A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
Both `image_height` and `image_width` need to be positive.
boxes: A `Tensor` of type `float32` or 'bfloat16'.
A 3-D tensor of shape `[batch, num_boxes, 4]`. The boxes are specified in
normalized coordinates and are of the form `[y1, x1, y2, x2]`. A
normalized coordinate value of `y` is mapped to the image coordinate at
`y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image
height is mapped to `[0, image_height - 1] in image height coordinates.
We do allow y1 > y2, in which case the sampled crop is an up-down flipped
version of the original image. The width dimension is treated similarly.
Normalized coordinates outside the `[0, 1]` range are allowed, in which
case we use `extrapolation_value` to extrapolate the input image values.
crop_size: A list of two integers `[crop_height, crop_width]`. All
cropped image patches are resized to this size. The aspect ratio of the
image content is not preserved. Both `crop_height` and `crop_width` need
to be positive.
scope: A name for the operation (optional).
num_levels: An integer indicating the number of feature levels to crop boxes
from.
unit_scale_index: A 0-based integer indicating the index of the feature map
which most closely matches the resolution of the pretrained model.
image_ratio: A float indicating the ratio of input image area to pretraining
image area.
boxes: A float tensor of shape [batch, num_boxes, 4] containing boxes of the
form [ymin, xmin, ymax, xmax] in normalized coordinates.
Returns:
A 5-D tensor of shape `[batch, num_boxes, crop_height, crop_width, depth]`
An int32 tensor of shape [batch_size, num_boxes] containing feature indices.
"""
img_shape = tf.shape(image)
img_height = img_shape[1]
img_width = img_shape[2]
def _lin_space_weights(num, img_size):
if num > 1:
start_weights = tf.linspace(tf.to_float(img_size) - 1.0, 0.0, num)
stop_weights = tf.to_float(img_size) - 1.0 - start_weights
else:
start_weights = tf.ones([num], dtype=tf.float32) * \
.5 * (tf.to_float(img_size) - 1.0)
stop_weights = tf.ones([num], dtype=tf.float32) * \
.5 * (tf.to_float(img_size) - 1.0)
return (start_weights, stop_weights)
with tf.name_scope(scope, 'MatMulCropAndResize'):
y1_weights, y2_weights = _lin_space_weights(crop_size[0], img_height)
x1_weights, x2_weights = _lin_space_weights(crop_size[1], img_width)
y1_weights = tf.cast(y1_weights, boxes.dtype)
y2_weights = tf.cast(y2_weights, boxes.dtype)
x1_weights = tf.cast(x1_weights, boxes.dtype)
x2_weights = tf.cast(x2_weights, boxes.dtype)
[y1, x1, y2, x2] = tf.unstack(boxes, axis=2)
# Pixel centers of input image and grid points along height and width
image_idx_h = tf.cast(
tf.reshape(tf.range(img_height), (1, 1, 1, img_height)),
dtype=boxes.dtype)
image_idx_w = tf.cast(
tf.reshape(tf.range(img_width), (1, 1, 1, img_width)),
dtype=boxes.dtype)
grid_pos_h = tf.expand_dims(
tf.einsum('ab,c->abc', y1, y1_weights) +
tf.einsum('ab,c->abc', y2, y2_weights),
axis=3)
grid_pos_w = tf.expand_dims(
tf.einsum('ab,c->abc', x1, x1_weights) +
tf.einsum('ab,c->abc', x2, x2_weights),
axis=3)
# Create kernel matrices of pairwise kernel evaluations between pixel
# centers of image and grid points.
kernel_h = tf.nn.relu(1 - tf.abs(image_idx_h - grid_pos_h))
kernel_w = tf.nn.relu(1 - tf.abs(image_idx_w - grid_pos_w))
# Compute matrix multiplication between
# the spatial dimensions of the image
# and height-wise kernel using einsum.
intermediate_image = tf.einsum('abci,aiop->abcop', kernel_h, image)
# Compute matrix multiplication between the spatial dimensions of the
# intermediate_image and width-wise kernel using einsum.
return tf.einsum('abno,abcop->abcnp', kernel_w, intermediate_image)
def native_crop_and_resize(image, boxes, crop_size, scope=None):
"""Same as `matmul_crop_and_resize` but uses tf.image.crop_and_resize."""
def get_box_inds(proposals):
proposals_shape = proposals.get_shape().as_list()
if any(dim is None for dim in proposals_shape):
proposals_shape = tf.shape(proposals)
ones_mat = tf.ones(proposals_shape[:2], dtype=tf.int32)
multiplier = tf.expand_dims(
tf.range(start=0, limit=proposals_shape[0]), 1)
return tf.reshape(ones_mat * multiplier, [-1])
with tf.name_scope(scope, 'CropAndResize'):
cropped_regions = tf.image.crop_and_resize(
image, tf.reshape(boxes, [-1] + boxes.shape.as_list()[2:]),
get_box_inds(boxes), crop_size)
final_shape = tf.concat([tf.shape(boxes)[:2],
tf.shape(cropped_regions)[1:]], axis=0)
return tf.reshape(cropped_regions, final_shape)
assert num_levels > 0, (
'`num_levels` must be > 0. Found {}'.format(num_levels))
assert unit_scale_index < num_levels and unit_scale_index >= 0, (
'`unit_scale_index` must be in [0, {}). Found {}.'.format(
num_levels, unit_scale_index))
box_height_width = boxes[:, :, 2:4] - boxes[:, :, 0:2]
areas_sqrt = tf.sqrt(tf.reduce_prod(box_height_width, axis=2))
log_2 = tf.cast(tf.log(2.0), dtype=boxes.dtype)
levels = tf.cast(
tf.floordiv(tf.log(areas_sqrt * image_ratio), log_2)
+
unit_scale_index,
dtype=tf.int32)
levels = tf.maximum(0, tf.minimum(num_levels - 1, levels))
return levels
def bfloat16_to_float32_nested(tensor_nested):
......
......@@ -851,7 +851,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
# work as the usual crop and resize for just one channel.
crop = tf.image.crop_and_resize(tf.expand_dims(image, axis=0), boxes,
box_ind, crop_size)
crop_and_pool = tf.reduce_mean(crop, [1, 2], keep_dims=True)
crop_and_pool = tf.reduce_mean(crop, [1, 2], keepdims=True)
ps_crop_and_pool = ops.position_sensitive_crop_regions(
tiled_image,
......@@ -937,8 +937,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
image, boxes, crop_size, num_spatial_bins, global_pool=False)
with self.test_session() as sess:
output = sess.run(ps_crop)
self.assertAllEqual(output, expected_output[crop_size_mult - 1])
self.assertAllClose(output, expected_output[crop_size_mult - 1])
def test_position_sensitive_with_global_pool_false_and_do_global_pool(self):
num_spatial_bins = [3, 2]
......@@ -981,7 +980,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
ps_crop = ops.position_sensitive_crop_regions(
image, boxes, crop_size, num_spatial_bins, global_pool=False)
ps_crop_and_pool = tf.reduce_mean(
ps_crop, reduction_indices=(1, 2), keep_dims=True)
ps_crop, reduction_indices=(1, 2), keepdims=True)
with self.test_session() as sess:
output = sess.run(ps_crop_and_pool)
......@@ -1349,154 +1348,31 @@ class MatmulGatherOnZerothAxis(test_case.TestCase):
self.assertAllClose(gather_output, expected_output)
class OpsTestMatMulCropAndResize(test_case.TestCase):
def testMatMulCropAndResize2x2To1x1(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize2x2To1x1Flipped(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[[1, 1, 0, 0]]], dtype=np.float32)
expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize2x2To3x3(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
expected_output = [[[[[1.0], [1.5], [2.0]],
[[2.0], [2.5], [3.0]],
[[3.0], [3.5], [4.0]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize2x2To3x3Flipped(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[[1, 1, 0, 0]]], dtype=np.float32)
expected_output = [[[[[4.0], [3.5], [3.0]],
[[3.0], [2.5], [2.0]],
[[2.0], [1.5], [1.0]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize3x3To2x2(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]]], dtype=np.float32)
expected_output = [[[[[1], [3]], [[7], [9]]],
[[[1], [2]], [[4], [5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]]], dtype=np.float32)
expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testBatchMatMulCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]],
[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]],
[[1, 1, 0, 0],
[.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
[[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
[[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize3x3To2x2Flipped(self):
def graph_fn(image, boxes):
return ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]]]], dtype=np.float32)
boxes = np.array([[[1, 1, 0, 0],
[.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[[9], [7]], [[3], [1]]],
[[[5], [4]], [[2], [1]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testInvalidInputShape(self):
image = tf.constant([[[1], [2]], [[3], [4]]], dtype=tf.float32)
boxes = tf.constant([[-1, -1, 1, 1]], dtype=tf.float32)
crop_size = [4, 4]
with self.assertRaises(ValueError):
_ = ops.matmul_crop_and_resize(image, boxes, crop_size)
class OpsTestCropAndResize(test_case.TestCase):
def testBatchCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes):
return ops.native_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]],
[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]],
[[1, 1, 0, 0],
[.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
[[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
[[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
crop_output = self.execute_cpu(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
class FpnFeatureLevelsTest(test_case.TestCase):
def test_correct_fpn_levels(self):
image_size = 640
pretraining_image_size = 224
image_ratio = image_size * 1.0 / pretraining_image_size
boxes = np.array(
[
[
[0, 0, 111, 111], # Level 0.
[0, 0, 113, 113], # Level 1.
[0, 0, 223, 223], # Level 1.
[0, 0, 225, 225], # Level 2.
[0, 0, 449, 449] # Level 3.
],
],
dtype=np.float32) / image_size
def graph_fn(boxes):
return ops.fpn_feature_levels(
num_levels=5, unit_scale_index=2, image_ratio=image_ratio,
boxes=boxes)
levels = self.execute(graph_fn, [boxes])
self.assertAllEqual([[0, 1, 1, 2, 3]], levels)
class TestBfloat16ToFloat32(test_case.TestCase):
@@ -42,7 +42,7 @@ class PerImageEvaluation(object):
Args:
num_groundtruth_classes: Number of ground truth object classes
matching_iou_threshold: A ratio of area intersection to union, which is
the threshold to consider whether a detection is true positive or not
nms_iou_threshold: IOU threshold used in Non Maximum Suppression.
nms_max_output_boxes: Number of maximum output boxes in NMS.
group_of_weight: Weight of the group-of boxes.
@@ -53,11 +53,16 @@ class PerImageEvaluation(object):
self.num_groundtruth_classes = num_groundtruth_classes
self.group_of_weight = group_of_weight
def compute_object_detection_metrics(
self, detected_boxes, detected_scores, detected_class_labels,
groundtruth_boxes, groundtruth_class_labels,
groundtruth_is_difficult_list, groundtruth_is_group_of_list,
detected_masks=None, groundtruth_masks=None):
def compute_object_detection_metrics(self,
detected_boxes,
detected_scores,
detected_class_labels,
groundtruth_boxes,
groundtruth_class_labels,
groundtruth_is_difficult_list,
groundtruth_is_group_of_list,
detected_masks=None,
groundtruth_masks=None):
"""Evaluates detections as being tp, fp or weighted from a single image.
The evaluation is done in two stages:
@@ -68,25 +73,24 @@ class PerImageEvaluation(object):
Args:
detected_boxes: A float numpy array of shape [N, 4], representing N
regions of detected object regions.
Each row is of the format [y_min, x_min, y_max, x_max]
detected_scores: A float numpy array of shape [N, 1], representing
the confidence scores of the detected N object instances.
regions of detected object regions. Each row is of the format [y_min,
x_min, y_max, x_max]
detected_scores: A float numpy array of shape [N, 1], representing the
confidence scores of the detected N object instances.
detected_class_labels: An integer numpy array of shape [N, 1], representing
the class labels of the detected N object instances.
groundtruth_boxes: A float numpy array of shape [M, 4], representing M
regions of object instances in ground truth
groundtruth_class_labels: An integer numpy array of shape [M, 1],
representing M class labels of object instances in ground truth
groundtruth_is_difficult_list: A boolean numpy array of length M denoting
whether a ground truth box is a difficult instance or not
groundtruth_is_group_of_list: A boolean numpy array of length M denoting
whether a ground truth box has group-of tag
detected_masks: (optional) A uint8 numpy array of shape
[N, height, width]. If not None, the metrics will be computed based
on masks.
groundtruth_masks: (optional) A uint8 numpy array of shape
[M, height, width].
whether a ground truth box has group-of tag
detected_masks: (optional) A uint8 numpy array of shape [N, height,
width]. If not None, the metrics will be computed based on masks.
groundtruth_masks: (optional) A uint8 numpy array of shape [M, height,
width]. Can have empty masks, i.e. where all values are 0.
Returns:
scores: A list of C float numpy arrays. Each numpy array is of
@@ -124,29 +128,32 @@ class PerImageEvaluation(object):
return scores, tp_fp_labels, is_class_correctly_detected_in_image
def _compute_cor_loc(self, detected_boxes, detected_scores,
detected_class_labels, groundtruth_boxes,
groundtruth_class_labels, detected_masks=None,
def _compute_cor_loc(self,
detected_boxes,
detected_scores,
detected_class_labels,
groundtruth_boxes,
groundtruth_class_labels,
detected_masks=None,
groundtruth_masks=None):
"""Compute CorLoc score for object detection result.
Args:
detected_boxes: A float numpy array of shape [N, 4], representing N
regions of detected object regions.
Each row is of the format [y_min, x_min, y_max, x_max]
detected_scores: A float numpy array of shape [N, 1], representing
the confidence scores of the detected N object instances.
regions of detected object regions. Each row is of the format [y_min,
x_min, y_max, x_max]
detected_scores: A float numpy array of shape [N, 1], representing the
confidence scores of the detected N object instances.
detected_class_labels: An integer numpy array of shape [N, 1], representing
the class labels of the detected N object instances.
groundtruth_boxes: A float numpy array of shape [M, 4], representing M
regions of object instances in ground truth
groundtruth_class_labels: An integer numpy array of shape [M, 1],
representing M class labels of object instances in ground truth
detected_masks: (optional) A uint8 numpy array of shape
[N, height, width]. If not None, the scores will be computed based
on masks.
groundtruth_masks: (optional) A uint8 numpy array of shape
[M, height, width].
representing M class labels of object instances in ground truth
detected_masks: (optional) A uint8 numpy array of shape [N, height,
width]. If not None, the scores will be computed based on masks.
groundtruth_masks: (optional) A uint8 numpy array of shape [M, height,
width].
Returns:
is_class_correctly_detected_in_image: a numpy integer array of
@@ -162,8 +169,7 @@ class PerImageEvaluation(object):
groundtruth_masks is not None):
raise ValueError(
'If `detected_masks` is provided, then `groundtruth_masks` should '
'also be provided.'
)
'also be provided.')
is_class_correctly_detected_in_image = np.zeros(
self.num_groundtruth_classes, dtype=int)
@@ -184,23 +190,25 @@ class PerImageEvaluation(object):
return is_class_correctly_detected_in_image
def _compute_is_class_correctly_detected_in_image(
self, detected_boxes, detected_scores, groundtruth_boxes,
detected_masks=None, groundtruth_masks=None):
def _compute_is_class_correctly_detected_in_image(self,
detected_boxes,
detected_scores,
groundtruth_boxes,
detected_masks=None,
groundtruth_masks=None):
"""Compute CorLoc score for a single class.
Args:
detected_boxes: A numpy array of shape [N, 4] representing detected box
coordinates
detected_scores: A 1-d numpy array of length N representing classification
score
groundtruth_boxes: A numpy array of shape [M, 4] representing ground truth
box coordinates
detected_masks: (optional) A np.uint8 numpy array of shape
[N, height, width]. If not None, the scores will be computed based
on masks.
groundtruth_masks: (optional) A np.uint8 numpy array of shape
[M, height, width].
box coordinates
detected_masks: (optional) A np.uint8 numpy array of shape [N, height,
width]. If not None, the scores will be computed based on masks.
groundtruth_masks: (optional) A np.uint8 numpy array of shape [M, height,
width].
Returns:
is_class_correctly_detected_in_image: An integer 1 or 0 denoting whether a
@@ -228,34 +236,38 @@ class PerImageEvaluation(object):
return 1
return 0
def _compute_tp_fp(self, detected_boxes, detected_scores,
detected_class_labels, groundtruth_boxes,
groundtruth_class_labels, groundtruth_is_difficult_list,
def _compute_tp_fp(self,
detected_boxes,
detected_scores,
detected_class_labels,
groundtruth_boxes,
groundtruth_class_labels,
groundtruth_is_difficult_list,
groundtruth_is_group_of_list,
detected_masks=None, groundtruth_masks=None):
detected_masks=None,
groundtruth_masks=None):
"""Labels true/false positives of detections of an image across all classes.
Args:
detected_boxes: A float numpy array of shape [N, 4], representing N
regions of detected object regions.
Each row is of the format [y_min, x_min, y_max, x_max]
detected_scores: A float numpy array of shape [N, 1], representing
the confidence scores of the detected N object instances.
regions of detected object regions. Each row is of the format [y_min,
x_min, y_max, x_max]
detected_scores: A float numpy array of shape [N, 1], representing the
confidence scores of the detected N object instances.
detected_class_labels: An integer numpy array of shape [N, 1], representing
the class labels of the detected N object instances.
groundtruth_boxes: A float numpy array of shape [M, 4], representing M
regions of object instances in ground truth
groundtruth_class_labels: An integer numpy array of shape [M, 1],
representing M class labels of object instances in ground truth
groundtruth_is_difficult_list: A boolean numpy array of length M denoting
whether a ground truth box is a difficult instance or not
groundtruth_is_group_of_list: A boolean numpy array of length M denoting
whether a ground truth box has group-of tag
detected_masks: (optional) A np.uint8 numpy array of shape
[N, height, width]. If not None, the scores will be computed based
on masks.
groundtruth_masks: (optional) A np.uint8 numpy array of shape
[M, height, width].
whether a ground truth box has group-of tag
detected_masks: (optional) A np.uint8 numpy array of shape [N, height,
width]. If not None, the scores will be computed based on masks.
groundtruth_masks: (optional) A np.uint8 numpy array of shape [M, height,
width].
Returns:
result_scores: A list of float numpy arrays. Each numpy array is of
@@ -293,34 +305,33 @@ class PerImageEvaluation(object):
detected_boxes=detected_boxes_at_ith_class,
detected_scores=detected_scores_at_ith_class,
groundtruth_boxes=gt_boxes_at_ith_class,
groundtruth_is_difficult_list=
groundtruth_is_difficult_list_at_ith_class,
groundtruth_is_group_of_list=
groundtruth_is_group_of_list_at_ith_class,
groundtruth_is_difficult_list=groundtruth_is_difficult_list_at_ith_class,
groundtruth_is_group_of_list=groundtruth_is_group_of_list_at_ith_class,
detected_masks=detected_masks_at_ith_class,
groundtruth_masks=gt_masks_at_ith_class)
result_scores.append(scores)
result_tp_fp_labels.append(tp_fp_labels)
return result_scores, result_tp_fp_labels
def _get_overlaps_and_scores_mask_mode(
self, detected_boxes, detected_scores, detected_masks, groundtruth_boxes,
groundtruth_masks, groundtruth_is_group_of_list):
def _get_overlaps_and_scores_mask_mode(self, detected_boxes, detected_scores,
detected_masks, groundtruth_boxes,
groundtruth_masks,
groundtruth_is_group_of_list):
"""Computes overlaps and scores between detected and groudntruth masks.
Args:
detected_boxes: A numpy array of shape [N, 4] representing detected box
coordinates
detected_scores: A 1-d numpy array of length N representing classification
score
detected_masks: A uint8 numpy array of shape [N, height, width]. If not
None, the scores will be computed based on masks.
groundtruth_boxes: A numpy array of shape [M, 4] representing ground truth
box coordinates
groundtruth_masks: A uint8 numpy array of shape [M, height, width].
groundtruth_is_group_of_list: A boolean numpy array of length M denoting
whether a ground truth box has group-of tag. If a groundtruth box
is group-of box, every detection matching this box is ignored.
whether a ground truth box has group-of tag. If a groundtruth box is
group-of box, every detection matching this box is ignored.
Returns:
iou: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
@@ -348,24 +359,21 @@ class PerImageEvaluation(object):
num_boxes = detected_boxlist.num_boxes()
return iou, ioa, scores, num_boxes
def _get_overlaps_and_scores_box_mode(
self,
detected_boxes,
detected_scores,
groundtruth_boxes,
groundtruth_is_group_of_list):
def _get_overlaps_and_scores_box_mode(self, detected_boxes, detected_scores,
groundtruth_boxes,
groundtruth_is_group_of_list):
"""Computes overlaps and scores between detected and groudntruth boxes.
Args:
detected_boxes: A numpy array of shape [N, 4] representing detected box
coordinates
detected_scores: A 1-d numpy array of length N representing classification
score
groundtruth_boxes: A numpy array of shape [M, 4] representing ground truth
box coordinates
groundtruth_is_group_of_list: A boolean numpy array of length M denoting
whether a ground truth box has group-of tag. If a groundtruth box
is group-of box, every detection matching this box is ignored.
whether a ground truth box has group-of tag. If a groundtruth box is
group-of box, every detection matching this box is ignored.
Returns:
iou: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If
@@ -390,31 +398,34 @@ class PerImageEvaluation(object):
num_boxes = detected_boxlist.num_boxes()
return iou, ioa, scores, num_boxes
def _compute_tp_fp_for_single_class(
self, detected_boxes, detected_scores, groundtruth_boxes,
groundtruth_is_difficult_list, groundtruth_is_group_of_list,
detected_masks=None, groundtruth_masks=None):
def _compute_tp_fp_for_single_class(self,
detected_boxes,
detected_scores,
groundtruth_boxes,
groundtruth_is_difficult_list,
groundtruth_is_group_of_list,
detected_masks=None,
groundtruth_masks=None):
"""Labels boxes detected with the same class from the same image as tp/fp.
Args:
detected_boxes: A numpy array of shape [N, 4] representing detected box
coordinates
detected_scores: A 1-d numpy array of length N representing classification
score
groundtruth_boxes: A numpy array of shape [M, 4] representing ground truth
box coordinates
groundtruth_is_difficult_list: A boolean numpy array of length M denoting
whether a ground truth box is a difficult instance or not. If a
groundtruth box is difficult, every detection matching this box
is ignored.
whether a ground truth box is a difficult instance or not. If a
groundtruth box is difficult, every detection matching this box is
ignored.
groundtruth_is_group_of_list: A boolean numpy array of length M denoting
whether a ground truth box has group-of tag. If a groundtruth box
is group-of box, every detection matching this box is ignored.
detected_masks: (optional) A uint8 numpy array of shape
[N, height, width]. If not None, the scores will be computed based
on masks.
groundtruth_masks: (optional) A uint8 numpy array of shape
[M, height, width].
whether a ground truth box has group-of tag. If a groundtruth box is
group-of box, every detection matching this box is ignored.
detected_masks: (optional) A uint8 numpy array of shape [N, height,
width]. If not None, the scores will be computed based on masks.
groundtruth_masks: (optional) A uint8 numpy array of shape [M, height,
width].
Returns:
Two arrays of the same size, containing all boxes that were evaluated as
@@ -432,16 +443,39 @@ class PerImageEvaluation(object):
if detected_masks is not None and groundtruth_masks is not None:
mask_mode = True
iou = np.ndarray([0, 0])
ioa = np.ndarray([0, 0])
iou_mask = np.ndarray([0, 0])
ioa_mask = np.ndarray([0, 0])
if mask_mode:
(iou, ioa, scores,
# For Instance Segmentation Evaluation on Open Images V5, not all boxed
# instances have corresponding segmentation annotations. Those boxes that
# don't have segmentation annotations are represented as empty masks in
# the groundtruth_masks ndarray.
mask_presence_indicator = (np.sum(groundtruth_masks, axis=(1, 2)) > 0)
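# For instance (illustrative): with two groundtruth masks where the second
# one is all zeros, mask_presence_indicator evaluates to [True, False], so
# only the first groundtruth is matched by mask IoU below while the second
# falls back to box matching.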
(iou_mask, ioa_mask, scores,
num_detected_boxes) = self._get_overlaps_and_scores_mask_mode(
detected_boxes=detected_boxes,
detected_scores=detected_scores,
detected_masks=detected_masks,
groundtruth_boxes=groundtruth_boxes,
groundtruth_masks=groundtruth_masks,
groundtruth_is_group_of_list=groundtruth_is_group_of_list)
groundtruth_boxes=groundtruth_boxes[mask_presence_indicator, :],
groundtruth_masks=groundtruth_masks[mask_presence_indicator, :],
groundtruth_is_group_of_list=groundtruth_is_group_of_list[
mask_presence_indicator])
if sum(mask_presence_indicator) < len(mask_presence_indicator):
# Not all masks are present - some masks are empty
(iou, ioa, _,
num_detected_boxes) = self._get_overlaps_and_scores_box_mode(
detected_boxes=detected_boxes,
detected_scores=detected_scores,
groundtruth_boxes=groundtruth_boxes[~mask_presence_indicator, :],
groundtruth_is_group_of_list=groundtruth_is_group_of_list[
~mask_presence_indicator])
num_detected_boxes = detected_boxes.shape[0]
else:
mask_presence_indicator = np.zeros(
groundtruth_is_group_of_list.shape, dtype=bool)
(iou, ioa, scores,
num_detected_boxes) = self._get_overlaps_and_scores_box_mode(
detected_boxes=detected_boxes,
@@ -453,55 +487,135 @@ class PerImageEvaluation(object):
return scores, np.zeros(num_detected_boxes, dtype=bool)
tp_fp_labels = np.zeros(num_detected_boxes, dtype=bool)
is_matched_to_difficult_box = np.zeros(num_detected_boxes, dtype=bool)
is_matched_to_group_of_box = np.zeros(num_detected_boxes, dtype=bool)
# The evaluation is done in two stages:
# 1. All detections are matched to non group-of boxes; true positives are
# determined and detections matched to difficult boxes are ignored.
# 2. Detections that are determined as false positives are matched against
#    group-of boxes and scored with weight w per matched ground truth box.
# Tp-fp evaluation for non-group of boxes (if any).
if iou.shape[1] > 0:
groundtruth_nongroup_of_is_difficult_list = groundtruth_is_difficult_list[
~groundtruth_is_group_of_list]
is_matched_to_box = np.zeros(num_detected_boxes, dtype=bool)
is_matched_to_difficult = np.zeros(num_detected_boxes, dtype=bool)
is_matched_to_group_of = np.zeros(num_detected_boxes, dtype=bool)
def compute_match_iou(iou, groundtruth_nongroup_of_is_difficult_list,
is_box):
"""Computes TP/FP for non group-of box matching.
The function updates the following local variables:
tp_fp_labels - if a detection is a true positive match to a non group-of box.
is_matched_to_difficult - the detections that were processed at this stage
and matched to a difficult box.
is_matched_to_box - the detections that were processed at this stage are
marked as is_box.
Args:
iou: intersection-over-union matrix of shape [num_det_boxes, num_gt_boxes].
groundtruth_nongroup_of_is_difficult_list: boolean that specifies if gt
box is difficult.
is_box: boolean that specifies if currently boxes or masks are
processed.
"""
max_overlap_gt_ids = np.argmax(iou, axis=1)
is_gt_box_detected = np.zeros(iou.shape[1], dtype=bool)
is_gt_detected = np.zeros(iou.shape[1], dtype=bool)
for i in range(num_detected_boxes):
gt_id = max_overlap_gt_ids[i]
if iou[i, gt_id] >= self.matching_iou_threshold:
is_evaluatable = (not tp_fp_labels[i] and
not is_matched_to_difficult[i] and
iou[i, gt_id] >= self.matching_iou_threshold and
not is_matched_to_group_of[i])
if is_evaluatable:
if not groundtruth_nongroup_of_is_difficult_list[gt_id]:
if not is_gt_box_detected[gt_id]:
if not is_gt_detected[gt_id]:
tp_fp_labels[i] = True
is_gt_box_detected[gt_id] = True
is_gt_detected[gt_id] = True
is_matched_to_box[i] = is_box
else:
is_matched_to_difficult_box[i] = True
scores_group_of = np.zeros(ioa.shape[1], dtype=float)
tp_fp_labels_group_of = self.group_of_weight * np.ones(
ioa.shape[1], dtype=float)
# Tp-fp evaluation for group of boxes.
if ioa.shape[1] > 0:
is_matched_to_difficult[i] = True
def compute_match_ioa(ioa, is_box):
"""Computes TP/FP for group-of box matching.
The function updates the following local variables:
is_matched_to_group_of - if a box is matched to group-of
is_matched_to_box - the detections that were processed at this stage are
marked as is_box.
Args:
ioa: intersection-over-area matrix of shape [num_det_boxes, num_gt_boxes].
is_box: boolean that specifies if currently boxes or masks are
processed.
Returns:
scores_group_of: scores of detections matched to group-of boxes, of size
[num_groupof_matched].
tp_fp_labels_group_of: float array of size [num_groupof_matched], with all
values equal to group_of_weight.
"""
scores_group_of = np.zeros(ioa.shape[1], dtype=float)
tp_fp_labels_group_of = self.group_of_weight * np.ones(
ioa.shape[1], dtype=float)
max_overlap_group_of_gt_ids = np.argmax(ioa, axis=1)
for i in range(num_detected_boxes):
gt_id = max_overlap_group_of_gt_ids[i]
if (not tp_fp_labels[i] and not is_matched_to_difficult_box[i] and
ioa[i, gt_id] >= self.matching_iou_threshold):
is_matched_to_group_of_box[i] = True
is_evaluatable = (not tp_fp_labels[i] and
not is_matched_to_difficult[i] and
ioa[i, gt_id] >= self.matching_iou_threshold and
not is_matched_to_group_of[i])
if is_evaluatable:
is_matched_to_group_of[i] = True
is_matched_to_box[i] = is_box
scores_group_of[gt_id] = max(scores_group_of[gt_id], scores[i])
selector = np.where((scores_group_of > 0) & (tp_fp_labels_group_of > 0))
scores_group_of = scores_group_of[selector]
tp_fp_labels_group_of = tp_fp_labels_group_of[selector]
return np.concatenate(
(scores[~is_matched_to_difficult_box
& ~is_matched_to_group_of_box],
scores_group_of)), np.concatenate(
(tp_fp_labels[~is_matched_to_difficult_box
& ~is_matched_to_group_of_box].astype(float),
tp_fp_labels_group_of))
return scores_group_of, tp_fp_labels_group_of
# The evaluation is done in two stages:
# 1. Evaluate all objects that actually have instance level masks.
# 2. Evaluate all objects that are not already evaluated as boxes.
if iou_mask.shape[1] > 0:
groundtruth_is_difficult_mask_list = groundtruth_is_difficult_list[
mask_presence_indicator]
groundtruth_is_group_of_mask_list = groundtruth_is_group_of_list[
mask_presence_indicator]
compute_match_iou(
iou_mask,
groundtruth_is_difficult_mask_list[
~groundtruth_is_group_of_mask_list],
is_box=False)
scores_mask_group_of = np.ndarray([0], dtype=float)
tp_fp_labels_mask_group_of = np.ndarray([0], dtype=float)
if ioa_mask.shape[1] > 0:
scores_mask_group_of, tp_fp_labels_mask_group_of = compute_match_ioa(
ioa_mask, is_box=False)
# Tp-fp evaluation for non-group of boxes (if any).
if iou.shape[1] > 0:
groundtruth_is_difficult_box_list = groundtruth_is_difficult_list[
~mask_presence_indicator]
groundtruth_is_group_of_box_list = groundtruth_is_group_of_list[
~mask_presence_indicator]
compute_match_iou(
iou,
groundtruth_is_difficult_box_list[~groundtruth_is_group_of_box_list],
is_box=True)
scores_box_group_of = np.ndarray([0], dtype=float)
tp_fp_labels_box_group_of = np.ndarray([0], dtype=float)
if ioa.shape[1] > 0:
scores_box_group_of, tp_fp_labels_box_group_of = compute_match_ioa(
ioa, is_box=True)
if mask_mode:
# Note: here crowds are treated as ignore regions.
valid_entries = (~is_matched_to_difficult & ~is_matched_to_group_of
& ~is_matched_to_box)
return np.concatenate(
(scores[valid_entries], scores_mask_group_of)), np.concatenate(
(tp_fp_labels[valid_entries].astype(float),
tp_fp_labels_mask_group_of))
else:
valid_entries = (~is_matched_to_difficult & ~is_matched_to_group_of)
return np.concatenate(
(scores[valid_entries], scores_box_group_of)), np.concatenate(
(tp_fp_labels[valid_entries].astype(float),
tp_fp_labels_box_group_of))
def _get_ith_class_arrays(self, detected_boxes, detected_scores,
detected_masks, detected_class_labels,
@@ -549,8 +663,11 @@ class PerImageEvaluation(object):
detected_boxes_at_ith_class, detected_scores_at_ith_class,
detected_masks_at_ith_class)
def _remove_invalid_boxes(self, detected_boxes, detected_scores,
detected_class_labels, detected_masks=None):
def _remove_invalid_boxes(self,
detected_boxes,
detected_scores,
detected_class_labels,
detected_masks=None):
"""Removes entries with invalid boxes.
A box is invalid if either its xmax is smaller than its xmin, or its ymax
@@ -472,6 +472,158 @@ class SingleClassTpFpNoDifficultBoxesTest(tf.test.TestCase):
self.assertTrue(np.allclose(expected_tp_fp_labels, tp_fp_labels))
class SingleClassTpFpEmptyMaskAndBoxesTest(tf.test.TestCase):
def setUp(self):
num_groundtruth_classes = 1
matching_iou_threshold_iou = 0.5
nms_iou_threshold = 1.0
nms_max_output_boxes = 10000
self.eval = per_image_evaluation.PerImageEvaluation(
num_groundtruth_classes, matching_iou_threshold_iou, nms_iou_threshold,
nms_max_output_boxes)
def test_mask_tp_and_ignore(self):
# GT: one box with mask, one without
# Det: One mask matches gt1, one matches box gt2 and is ignored
groundtruth_boxes = np.array([[0, 0, 2, 3], [0, 0, 2, 2]], dtype=float)
groundtruth_mask_0 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_mask_1 = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_masks = np.stack([groundtruth_mask_0, groundtruth_mask_1],
axis=0)
groundtruth_groundtruth_is_difficult_list = np.zeros(2, dtype=bool)
groundtruth_groundtruth_is_group_of_list = np.array([False, False],
dtype=bool)
detected_boxes = np.array([[0, 0, 2, 3], [0, 0, 2, 2]], dtype=float)
detected_scores = np.array([0.6, 0.8], dtype=float)
detected_masks_0 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks_1 = np.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks = np.stack([detected_masks_0, detected_masks_1], axis=0)
scores, tp_fp_labels = self.eval._compute_tp_fp_for_single_class(
detected_boxes, detected_scores, groundtruth_boxes,
groundtruth_groundtruth_is_difficult_list,
groundtruth_groundtruth_is_group_of_list, detected_masks,
groundtruth_masks)
expected_scores = np.array([0.6], dtype=float)
expected_tp_fp_labels = np.array([True], dtype=bool)
self.assertTrue(np.allclose(expected_scores, scores))
self.assertTrue(np.allclose(expected_tp_fp_labels, tp_fp_labels))
def test_mask_one_tp_one_fp(self):
# GT: one box with mask, one without
# Det: one mask matches gt1, one is fp (box does not match)
groundtruth_boxes = np.array([[0, 0, 2, 3], [2, 2, 4, 4]], dtype=float)
groundtruth_mask_0 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_mask_1 = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_masks = np.stack([groundtruth_mask_0, groundtruth_mask_1],
axis=0)
groundtruth_groundtruth_is_difficult_list = np.zeros(2, dtype=bool)
groundtruth_groundtruth_is_group_of_list = np.array([False, False],
dtype=bool)
detected_boxes = np.array([[0, 0, 2, 3], [0, 0, 2, 2]], dtype=float)
detected_scores = np.array([0.6, 0.8], dtype=float)
detected_masks_0 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks_1 = np.array([[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks = np.stack([detected_masks_0, detected_masks_1], axis=0)
scores, tp_fp_labels = self.eval._compute_tp_fp_for_single_class(
detected_boxes,
detected_scores,
groundtruth_boxes,
groundtruth_groundtruth_is_difficult_list,
groundtruth_groundtruth_is_group_of_list,
detected_masks=detected_masks,
groundtruth_masks=groundtruth_masks)
expected_scores = np.array([0.8, 0.6], dtype=float)
expected_tp_fp_labels = np.array([False, True], dtype=bool)
self.assertTrue(np.allclose(expected_scores, scores))
self.assertTrue(np.allclose(expected_tp_fp_labels, tp_fp_labels))
def test_two_mask_one_gt_one_ignore(self):
# GT: one box with mask, one without.
# Det: two mask matches same gt, one is tp, one is passed down to box match
# and ignored.
groundtruth_boxes = np.array([[0, 0, 2, 3], [0, 0, 2, 3]], dtype=float)
groundtruth_mask_0 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_mask_1 = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_masks = np.stack([groundtruth_mask_0, groundtruth_mask_1],
axis=0)
groundtruth_groundtruth_is_difficult_list = np.zeros(2, dtype=bool)
groundtruth_groundtruth_is_group_of_list = np.array([False, False],
dtype=bool)
detected_boxes = np.array([[0, 0, 2, 3], [0, 0, 2, 3]], dtype=float)
detected_scores = np.array([0.6, 0.8], dtype=float)
detected_masks_0 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks_1 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks = np.stack([detected_masks_0, detected_masks_1], axis=0)
scores, tp_fp_labels = self.eval._compute_tp_fp_for_single_class(
detected_boxes,
detected_scores,
groundtruth_boxes,
groundtruth_groundtruth_is_difficult_list,
groundtruth_groundtruth_is_group_of_list,
detected_masks=detected_masks,
groundtruth_masks=groundtruth_masks)
expected_scores = np.array([0.8], dtype=float)
expected_tp_fp_labels = np.array([True], dtype=bool)
self.assertTrue(np.allclose(expected_scores, scores))
self.assertTrue(np.allclose(expected_tp_fp_labels, tp_fp_labels))
def test_two_mask_one_gt_one_fp(self):
# GT: one box with mask, one without.
# Det: two mask matches same gt, one is tp, one is passed down to box match
# and is fp.
groundtruth_boxes = np.array([[0, 0, 2, 3], [2, 3, 4, 6]], dtype=float)
groundtruth_mask_0 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_mask_1 = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=np.uint8)
groundtruth_masks = np.stack([groundtruth_mask_0, groundtruth_mask_1],
axis=0)
groundtruth_groundtruth_is_difficult_list = np.zeros(2, dtype=bool)
groundtruth_groundtruth_is_group_of_list = np.array([False, False],
dtype=bool)
detected_boxes = np.array([[0, 0, 2, 3], [0, 0, 2, 3]], dtype=float)
detected_scores = np.array([0.6, 0.8], dtype=float)
detected_masks_0 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks_1 = np.array([[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
dtype=np.uint8)
detected_masks = np.stack([detected_masks_0, detected_masks_1], axis=0)
scores, tp_fp_labels = self.eval._compute_tp_fp_for_single_class(
detected_boxes,
detected_scores,
groundtruth_boxes,
groundtruth_groundtruth_is_difficult_list,
groundtruth_groundtruth_is_group_of_list,
detected_masks=detected_masks,
groundtruth_masks=groundtruth_masks)
expected_scores = np.array([0.8, 0.6], dtype=float)
expected_tp_fp_labels = np.array([True, False], dtype=bool)
self.assertTrue(np.allclose(expected_scores, scores))
self.assertTrue(np.allclose(expected_tp_fp_labels, tp_fp_labels))
class MultiClassesTpFpTest(tf.test.TestCase):
def test_tp_fp(self):
@@ -20,6 +20,9 @@ import tensorflow as tf
from object_detection.utils import static_shape
get_dim_as_int = static_shape.get_dim_as_int
def _is_tensor(t):
"""Returns a boolean indicating whether the input is a tensor.
@@ -365,3 +368,95 @@ def assert_box_normalized(boxes, maximum_normalized_coordinate=1.1):
tf.less_equal(box_maximum, maximum_normalized_coordinate),
tf.greater_equal(box_minimum, 0)),
[boxes])
def flatten_dimensions(inputs, first, last):
"""Flattens `K-d` tensor along [first, last) dimensions.
Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
[D0, D1, ..., D(first) * D(first+1) * ... * D(last-1), D(last), ..., D(K-1)].
Example:
`inputs` is a tensor with initial shape [10, 5, 20, 20, 3].
new_tensor = flatten_dimensions(inputs, first=1, last=3)
new_tensor.shape -> [10, 100, 20, 3].
Args:
inputs: a tensor with shape [D0, D1, ..., D(K-1)].
first: first value for the range of dimensions to flatten.
last: last value for the range of dimensions to flatten. Note that the last
dimension itself is excluded.
Returns:
a tensor with shape
[D0, D1, ..., D(first) * D(first + 1) * ... * D(last - 1), D(last), ...,
D(K-1)].
Raises:
ValueError: if first and last arguments are incorrect.
"""
if first >= inputs.shape.ndims or last > inputs.shape.ndims:
raise ValueError('`first` and `last` must be less than inputs.shape.ndims. '
'found {} and {} respectively while ndims is {}'.format(
first, last, inputs.shape.ndims))
shape = combined_static_and_dynamic_shape(inputs)
flattened_dim_prod = tf.reduce_prod(shape[first:last],
keepdims=True)
new_shape = tf.concat([shape[:first], flattened_dim_prod,
shape[last:]], axis=0)
return tf.reshape(inputs, new_shape)
def flatten_first_n_dimensions(inputs, n):
"""Flattens `K-d` tensor along first n dimension to be a `(K-n+1)-d` tensor.
Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
[D0 * D1 * ... * D(n-1), D(n), ... D(K-1)].
Example:
`inputs` is a tensor with initial shape [10, 5, 20, 20, 3].
new_tensor = flatten_first_n_dimensions(inputs, 2)
new_tensor.shape -> [50, 20, 20, 3].
Args:
inputs: a tensor with shape [D0, D1, ..., D(K-1)].
n: The number of dimensions to flatten.
Returns:
a tensor with shape [D0 * D1 * ... * D(n-1), D(n), ... D(K-1)].
"""
return flatten_dimensions(inputs, first=0, last=n)
def expand_first_dimension(inputs, dims):
"""Expands `K-d` tensor along first dimension to be a `(K+n-1)-d` tensor.
Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
[dims[0], dims[1], ..., dims[-1], D1, ..., D(k-1)].
Example:
`inputs` is a tensor with shape [50, 20, 20, 3].
new_tensor = expand_first_dimension(inputs, [10, 5]).
new_tensor.shape -> [10, 5, 20, 20, 3].
Args:
inputs: a tensor with shape [D0, D1, ..., D(K-1)].
dims: List with new dimensions to expand first axis into. The length of
`dims` is typically 2 or larger.
Returns:
a tensor with shape [dims[0], dims[1], ..., dims[-1], D1, ..., D(k-1)].
"""
inputs_shape = combined_static_and_dynamic_shape(inputs)
expanded_shape = tf.stack(dims + inputs_shape[1:])
# Verify that it is possible to expand the first axis of inputs.
assert_op = tf.assert_equal(
inputs_shape[0], tf.reduce_prod(tf.stack(dims)),
message=('First dimension of `inputs` cannot be expanded into provided '
'`dims`'))
with tf.control_dependencies([assert_op]):
inputs_reshaped = tf.reshape(inputs, expanded_shape)
return inputs_reshaped
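# Illustrative round trip (a sketch; shapes follow the docstring examples
# above): flatten_first_n_dimensions and expand_first_dimension invert each
# other when the expanded dims multiply back to the flattened first dimension.
#   x = tf.zeros([10, 5, 20, 20, 3])
#   y = flatten_first_n_dimensions(x, 2)    # shape [50, 20, 20, 3]
#   z = expand_first_dimension(y, [10, 5])  # shape [10, 5, 20, 20, 3]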
@@ -333,5 +333,79 @@ class AssertShapeEqualTest(tf.test.TestCase):
tensor_b: np.zeros([5])})
class FlattenExpandDimensionTest(tf.test.TestCase):
def test_flatten_given_dims(self):
inputs = tf.random_uniform([5, 2, 10, 10, 3])
actual_flattened = shape_utils.flatten_dimensions(inputs, first=1, last=3)
expected_flattened = tf.reshape(inputs, [5, 20, 10, 3])
with self.test_session() as sess:
(actual_flattened_np,
expected_flattened_np) = sess.run([actual_flattened, expected_flattened])
self.assertAllClose(expected_flattened_np, actual_flattened_np)
def test_raises_value_error_incorrect_dimensions(self):
inputs = tf.random_uniform([5, 2, 10, 10, 3])
with self.assertRaises(ValueError):
shape_utils.flatten_dimensions(inputs, first=0, last=6)
def test_flatten_first_two_dimensions(self):
inputs = tf.constant(
[
[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
[[9, 10], [11, 12]]
], dtype=tf.int32)
flattened_tensor = shape_utils.flatten_first_n_dimensions(
inputs, 2)
with self.test_session() as sess:
flattened_tensor_out = sess.run(flattened_tensor)
expected_output = [[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12]]
self.assertAllEqual(expected_output, flattened_tensor_out)
def test_expand_first_dimension(self):
inputs = tf.constant(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12]
], dtype=tf.int32)
dims = [3, 2]
expanded_tensor = shape_utils.expand_first_dimension(
inputs, dims)
with self.test_session() as sess:
expanded_tensor_out = sess.run(expanded_tensor)
expected_output = [
[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
[[9, 10], [11, 12]]]
self.assertAllEqual(expected_output, expanded_tensor_out)
def test_expand_first_dimension_with_incompatible_dims(self):
inputs_default = tf.constant(
[
[[1, 2]],
[[3, 4]],
[[5, 6]],
], dtype=tf.int32)
inputs = tf.placeholder_with_default(inputs_default, [None, 1, 2])
dims = [3, 2]
expanded_tensor = shape_utils.expand_first_dimension(
inputs, dims)
with self.test_session() as sess:
with self.assertRaises(tf.errors.InvalidArgumentError):
sess.run(expanded_tensor)
if __name__ == '__main__':
tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Spatial transformation ops like RoIAlign, CropAndResize."""
import tensorflow as tf
def _coordinate_vector_1d(start, end, size, align_endpoints):
"""Generates uniformly spaced coordinate vector.
Args:
start: A float tensor of shape [batch, num_boxes] indicating start values.
end: A float tensor of shape [batch, num_boxes] indicating end values.
size: Number of points in coordinate vector.
align_endpoints: Whether to align first and last points exactly to
endpoints.
Returns:
A 3D float tensor of shape [batch, num_boxes, size] containing grid
coordinates.
"""
start = tf.expand_dims(start, -1)
end = tf.expand_dims(end, -1)
length = tf.cast(end - start, dtype=tf.float32)
if align_endpoints:
relative_grid_spacing = tf.linspace(0.0, 1.0, size)
offset = 0 if size > 1 else length / 2
else:
relative_grid_spacing = tf.linspace(0.0, 1.0, size + 1)[:-1]
offset = length / (2 * size)
relative_grid_spacing = tf.reshape(relative_grid_spacing, [1, 1, size])
absolute_grid = start + offset + relative_grid_spacing * length
return absolute_grid
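# Example (values mirror BoxGridCoordinateTest.test_4x4_grid below): for a
# 1-D span [start, end] = [0, 6] with size=4 and align_endpoints=False, the
# cell centers are [0.75, 2.25, 3.75, 5.25]; with align_endpoints=True the
# points are [0, 2, 4, 6], so the first and last points sit on the endpoints.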
def box_grid_coordinate_vectors(boxes, size_y, size_x, align_corners=False):
"""Generates coordinate vectors for a `size x size` grid in boxes.
Each box is subdivided uniformly into a grid consisting of size x size
rectangular cells. This function returns coordinate vectors describing
the center of each cell.
If `align_corners` is true, grid points are uniformly spread such that the
corner points on the grid exactly overlap corners of the boxes.
Note that output coordinates are expressed in the same coordinate frame as
input boxes.
Args:
boxes: A float tensor of shape [batch, num_boxes, 4] containing boxes of the
form [ymin, xmin, ymax, xmax].
size_y: Size of the grid in y axis.
size_x: Size of the grid in x axis.
align_corners: Whether to align the corner grid points exactly with box
corners.
Returns:
box_grid_y: A float tensor of shape [batch, num_boxes, size_y] containing y
coordinates for grid points.
box_grid_x: A float tensor of shape [batch, num_boxes, size_x] containing x
coordinates for grid points.
"""
ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=-1)
box_grid_y = _coordinate_vector_1d(ymin, ymax, size_y, align_corners)
box_grid_x = _coordinate_vector_1d(xmin, xmax, size_x, align_corners)
return box_grid_y, box_grid_x
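# Example (see BoxGridCoordinateTest.test_2x2_grid below): a box
# [ymin, xmin, ymax, xmax] = [0, 0, 6, 3] with size_y=2, size_x=2 yields
# box_grid_y = [1.5, 4.5] and box_grid_x = [0.75, 2.25], i.e. the centers of
# the 2x2 cells the box is divided into.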
def feature_grid_coordinate_vectors(box_grid_y, box_grid_x):
"""Returns feature grid point coordinate vectors for bilinear interpolation.
Box grid is specified in absolute coordinate system with origin at left top
(0, 0). The returned coordinate vectors contain 0-based feature point indices.
This function snaps each point in the box grid to nearest 4 points on the
feature map.
In this function we also follow the convention of treating feature pixels as
point objects with no spatial extent.
Args:
box_grid_y: A float tensor of shape [batch, num_boxes, size] containing y
coordinate vector of the box grid.
box_grid_x: A float tensor of shape [batch, num_boxes, size] containing x
coordinate vector of the box grid.
Returns:
feature_grid_y0: An int32 tensor of shape [batch, num_boxes, size]
containing y coordinate vector for the top neighbors.
feature_grid_x0: An int32 tensor of shape [batch, num_boxes, size]
containing x coordinate vector for the left neighbors.
feature_grid_y1: An int32 tensor of shape [batch, num_boxes, size]
containing y coordinate vector for the bottom neighbors.
feature_grid_x1: An int32 tensor of shape [batch, num_boxes, size]
containing x coordinate vector for the right neighbors.
"""
feature_grid_y0 = tf.floor(box_grid_y)
feature_grid_x0 = tf.floor(box_grid_x)
feature_grid_y1 = tf.floor(box_grid_y + 1)
feature_grid_x1 = tf.floor(box_grid_x + 1)
feature_grid_y0 = tf.cast(feature_grid_y0, dtype=tf.int32)
feature_grid_y1 = tf.cast(feature_grid_y1, dtype=tf.int32)
feature_grid_x0 = tf.cast(feature_grid_x0, dtype=tf.int32)
feature_grid_x1 = tf.cast(feature_grid_x1, dtype=tf.int32)
return (feature_grid_y0, feature_grid_x0, feature_grid_y1, feature_grid_x1)
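# Example: for box grid points y = [1.5, 4.6] and x = [2.4, 5.3] this returns
# y0 = [1, 4], x0 = [2, 5] (floor) and y1 = [2, 5], x1 = [3, 6] (floor + 1),
# i.e. the two nearest feature rows/columns used for bilinear interpolation.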
def _valid_indicator(feature_grid_y, feature_grid_x, true_feature_shapes):
"""Computes a indicator vector for valid indices.
Computes an indicator vector which is true for points on feature map and
false for points off feature map.
Args:
feature_grid_y: An int32 tensor of shape [batch, num_boxes, size_y]
containing y coordinate vector.
feature_grid_x: An int32 tensor of shape [batch, num_boxes, size_x]
containing x coordinate vector.
true_feature_shapes: An int32 tensor of shape [batch, num_boxes, 2]
containing valid height and width of feature maps. Feature maps are
assumed to be aligned to the left top corner.
Returns:
indices: A 1D bool tensor indicating valid feature indices.
"""
height = tf.cast(true_feature_shapes[:, :, 0:1], dtype=feature_grid_y.dtype)
width = tf.cast(true_feature_shapes[:, :, 1:2], dtype=feature_grid_x.dtype)
valid_indicator = tf.logical_and(
tf.expand_dims(
tf.logical_and(feature_grid_y >= 0, tf.less(feature_grid_y, height)),
3),
tf.expand_dims(
tf.logical_and(feature_grid_x >= 0, tf.less(feature_grid_x, width)),
2))
return tf.reshape(valid_indicator, [-1])
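# A grid point is valid only if 0 <= y < height and 0 <= x < width of its
# assigned feature level; e.g. with a true height of 3, a sampling row at
# y = 3 is flagged invalid because it lies in the zero padding.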
def ravel_indices(feature_grid_y, feature_grid_x, num_levels, height, width,
box_levels):
"""Returns grid indices in a flattened feature map of shape [-1, channels].
The returned 1-D array can be used to gather feature grid points from a
feature map that has been flattened from [batch, num_levels, max_height,
max_width, channels] to [batch * num_levels * max_height * max_width,
channels].
Args:
feature_grid_y: An int32 tensor of shape [batch, num_boxes, size_y]
containing y coordinate vector.
feature_grid_x: An int32 tensor of shape [batch, num_boxes, size_x]
containing x coordinate vector.
num_levels: Number of feature levels.
height: An integer indicating the padded height of feature maps.
width: An integer indicating the padded width of feature maps.
box_levels: An int32 tensor of shape [batch, num_boxes] indicating
feature level assigned to each box.
Returns:
indices: A 1D int32 tensor containing feature point indices in a flattened
feature grid.
"""
assert feature_grid_y.shape[0] == feature_grid_x.shape[0]
assert feature_grid_y.shape[1] == feature_grid_x.shape[1]
num_boxes = feature_grid_y.shape[1].value
batch_size = feature_grid_y.shape[0].value
size_y = feature_grid_y.shape[2]
size_x = feature_grid_x.shape[2]
height_dim_offset = width
level_dim_offset = height * height_dim_offset
batch_dim_offset = num_levels * level_dim_offset
batch_dim_indices = (
tf.reshape(
tf.range(batch_size) * batch_dim_offset, [batch_size, 1, 1, 1]) *
tf.ones([1, num_boxes, size_y, size_x], dtype=tf.int32))
box_level_indices = (
tf.reshape(box_levels * level_dim_offset, [batch_size, num_boxes, 1, 1]) *
tf.ones([1, 1, size_y, size_x], dtype=tf.int32))
height_indices = (
tf.reshape(feature_grid_y * height_dim_offset,
[batch_size, num_boxes, size_y, 1]) *
tf.ones([1, 1, 1, size_x], dtype=tf.int32))
width_indices = (
tf.reshape(feature_grid_x, [batch_size, num_boxes, 1, size_x])
* tf.ones([1, 1, size_y, 1], dtype=tf.int32))
indices = (
batch_dim_indices + box_level_indices + height_indices + width_indices)
flattened_indices = tf.reshape(indices, [-1])
return flattened_indices
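# Example of the index arithmetic (a sketch): with num_levels=2, height=3 and
# width=4, the offsets are batch_dim_offset=24, level_dim_offset=12 and
# height_dim_offset=4, so the grid point (batch=1, level=1, y=2, x=3) maps to
# flattened index 1*24 + 1*12 + 2*4 + 3 = 47.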
def pad_to_max_size(features):
"""Pads features to max height and max width and stacks them up.
Args:
features: A list of num_levels 4D float tensors of shape [batch, height_i,
width_i, channels] containing feature maps.
Returns:
stacked_features: A 5D float tensor of shape [batch, num_levels, max_height,
max_width, channels] containing stacked features.
true_feature_shapes: A 2D int32 tensor of shape [num_levels, 2] containing
height and width of the feature maps before padding.
"""
heights = [feature.shape[1].value for feature in features]
widths = [feature.shape[2].value for feature in features]
max_height = max(heights)
max_width = max(widths)
features_all = [
tf.image.pad_to_bounding_box(feature, 0, 0, max_height,
max_width) for feature in features
]
features_all = tf.stack(features_all, axis=1)
true_feature_shapes = tf.stack([feature.shape[1:3] for feature in features])
return features_all, true_feature_shapes
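# Example: given two levels with shapes [batch, 8, 8, C] and [batch, 4, 4, C],
# the smaller map is zero padded to 8x8 and the result is stacked to
# [batch, 2, 8, 8, C] with true_feature_shapes = [[8, 8], [4, 4]].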
def _gather_valid_indices(tensor, indices, padding_value=0.0):
"""Gather values for valid indices.
TODO(rathodv): We can't use ops.gather_with_padding_values due to cyclic
dependency. Start using it after migrating all users of spatial ops to import
this module directly rather than util/ops.py
Args:
tensor: A tensor to gather valid values from.
indices: A 1-D int32 tensor containing indices along axis 0 of `tensor`.
Invalid indices must be marked with -1.
padding_value: Value to return for invalid indices.
Returns:
A tensor sliced based on indices. For indices that are equal to -1, returns
rows of padding value.
"""
padded_tensor = tf.concat(
[
padding_value *
tf.ones([1, tensor.shape[-1].value], dtype=tensor.dtype), tensor
],
axis=0,
)
# tf.concat gradient op uses tf.where(condition) (which is not
# supported on TPU) when the inputs to it are tf.IndexedSlices instead of
# tf.Tensor. Since gradient op for tf.gather returns tf.IndexedSlices,
# we add a dummy op in between tf.concat and tf.gather to ensure tf.concat
# gradient function gets tf.Tensor inputs and not tf.IndexedSlices.
padded_tensor *= 1.0
return tf.gather(padded_tensor, indices + 1)
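# Example: for tensor = [[1., 2.], [3., 4.]] and indices = [-1, 1, 0] with
# padding_value=0.0, the result is [[0., 0.], [3., 4.], [1., 2.]]; the -1
# index selects the padding row that was prepended at position 0.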
def multilevel_roi_align(features, boxes, box_levels, output_size,
num_samples_per_cell_y=1, num_samples_per_cell_x=1,
align_corners=False, extrapolation_value=0.0,
scope=None):
"""Applies RoI Align op and returns feature for boxes.
Given multiple features maps indexed by different levels, and a set of boxes
where each box is mapped to a certain level, this function selectively crops
and resizes boxes from the corresponding feature maps.
We follow the RoI Align technique in https://arxiv.org/pdf/1703.06870.pdf
figure 3. Specifically, each box is subdivided uniformly into a grid
consisting of output_size[0] x output_size[1] rectangular cells. Within each
cell we sample num_samples_per_cell_y * num_samples_per_cell_x points uniformly and compute feature values using
bilinear interpolation. Finally, we average pool the interpolated values in
each cell to obtain a [output_size[0], output_size[1], channels] feature.
If `align_corners` is true, sampling points are uniformly spread such that
corner points exactly overlap corners of the boxes.
In this function we also follow the convention of treating feature pixels as
point objects with no spatial extent.
Args:
features: A list of 4D float tensors of shape [batch_size, max_height,
max_width, channels] containing features.
boxes: A 3D float tensor of shape [batch_size, num_boxes, 4] containing
boxes of the form [ymin, xmin, ymax, xmax] in normalized coordinates.
box_levels: A 2D int32 tensor of shape [batch_size, num_boxes] representing
the feature level index for each box.
output_size: A list of two integers [size_y, size_x] indicating the output
feature size for each box.
num_samples_per_cell_y: Number of grid points to sample along y axis in each
cell.
num_samples_per_cell_x: Number of grid points to sample along x axis in each
cell.
align_corners: Whether to align the corner grid points exactly with box
corners.
extrapolation_value: a float value to use for extrapolation.
scope: Scope name to use for this op.
Returns:
A 5D float tensor of shape [batch_size, num_boxes, output_size[0],
output_size[1], channels] representing the cropped features.
"""
with tf.name_scope(scope, 'MultiLevelRoIAlign'):
features, true_feature_shapes = pad_to_max_size(features)
(batch_size, num_levels, max_feature_height, max_feature_width,
num_filters) = features.get_shape().as_list()
_, num_boxes, _ = boxes.get_shape().as_list()
# Convert boxes to absolute co-ordinates.
true_feature_shapes = tf.cast(true_feature_shapes, dtype=boxes.dtype)
true_feature_shapes = tf.gather(true_feature_shapes, box_levels)
boxes *= tf.concat([true_feature_shapes - 1] * 2, axis=-1)
size_y = output_size[0] * num_samples_per_cell_y
size_x = output_size[1] * num_samples_per_cell_x
box_grid_y, box_grid_x = box_grid_coordinate_vectors(
boxes, size_y=size_y, size_x=size_x, align_corners=align_corners)
(feature_grid_y0, feature_grid_x0, feature_grid_y1,
feature_grid_x1) = feature_grid_coordinate_vectors(box_grid_y, box_grid_x)
feature_grid_y = tf.reshape(
tf.stack([feature_grid_y0, feature_grid_y1], axis=3),
[batch_size, num_boxes, -1])
feature_grid_x = tf.reshape(
tf.stack([feature_grid_x0, feature_grid_x1], axis=3),
[batch_size, num_boxes, -1])
feature_coordinates = ravel_indices(feature_grid_y, feature_grid_x,
num_levels, max_feature_height,
max_feature_width, box_levels)
valid_indices = _valid_indicator(feature_grid_y, feature_grid_x,
true_feature_shapes)
feature_coordinates = tf.where(valid_indices, feature_coordinates,
-1 * tf.ones_like(feature_coordinates))
flattened_features = tf.reshape(features, [-1, num_filters])
flattened_feature_values = _gather_valid_indices(flattened_features,
feature_coordinates,
extrapolation_value)
features_per_box = tf.reshape(
flattened_feature_values,
[batch_size, num_boxes, size_y * 2, size_x * 2, num_filters])
# Cast tensors into dtype of features.
box_grid_y = tf.cast(box_grid_y, dtype=features_per_box.dtype)
box_grid_x = tf.cast(box_grid_x, dtype=features_per_box.dtype)
feature_grid_y0 = tf.cast(feature_grid_y0, dtype=features_per_box.dtype)
feature_grid_x0 = tf.cast(feature_grid_x0, dtype=features_per_box.dtype)
# RoI Align operation is a bilinear interpolation of four
# neighboring feature points f0, f1, f2, and f3 onto point y, x given by
# f(y, x) = [hy, ly] * [[f00, f01], [f10, f11]] * [hx, lx]^T
#
# Unrolling the matrix multiplies gives us:
# f(y, x) = (hy * hx) f00 + (hy * lx) f01 + (ly * hx) f10 + (lx * ly) f11
# f(y, x) = w00 * f00 + w01 * f01 + w10 * f10 + w11 * f11
#
# This can be computed by applying pointwise multiplication and sum_pool in
# a 2x2 window.
ly = box_grid_y - feature_grid_y0
lx = box_grid_x - feature_grid_x0
hy = 1.0 - ly
hx = 1.0 - lx
kernel_y = tf.reshape(
tf.stack([hy, ly], axis=3), [batch_size, num_boxes, size_y * 2, 1])
kernel_x = tf.reshape(
tf.stack([hx, lx], axis=3), [batch_size, num_boxes, 1, size_x * 2])
# Multiplier 4 is to make tf.nn.avg_pool behave like sum_pool.
interpolation_kernel = kernel_y * kernel_x * 4
# Interpolate the gathered features with computed interpolation kernels.
features_per_box *= tf.expand_dims(interpolation_kernel, axis=4)
features_per_box = tf.reshape(
features_per_box,
[batch_size * num_boxes, size_y * 2, size_x * 2, num_filters])
# This combines the two pooling operations - sum_pool to perform bilinear
# interpolation and avg_pool to pool the values in each bin.
features_per_box = tf.nn.avg_pool(
features_per_box,
[1, num_samples_per_cell_y * 2, num_samples_per_cell_x * 2, 1],
[1, num_samples_per_cell_y * 2, num_samples_per_cell_x * 2, 1], 'VALID')
features_per_box = tf.reshape(
features_per_box,
[batch_size, num_boxes, output_size[0], output_size[1], num_filters])
return features_per_box
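# Sketch of typical usage (assumed shapes, not taken from a specific caller):
#   features = [p2, p3]                  # e.g. [2, 32, 32, 64] and [2, 16, 16, 64]
#   boxes: [2, num_boxes, 4] normalized  # [ymin, xmin, ymax, xmax]
#   box_levels: [2, num_boxes] int32     # 0 selects p2, 1 selects p3
#   crops = multilevel_roi_align(features, boxes, box_levels, [7, 7])
#   # crops has shape [2, num_boxes, 7, 7, 64]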
def native_crop_and_resize(image, boxes, crop_size, scope=None):
"""Same as `matmul_crop_and_resize` but uses tf.image.crop_and_resize."""
def get_box_inds(proposals):
proposals_shape = proposals.get_shape().as_list()
if any(dim is None for dim in proposals_shape):
proposals_shape = tf.shape(proposals)
ones_mat = tf.ones(proposals_shape[:2], dtype=tf.int32)
multiplier = tf.expand_dims(
tf.range(start=0, limit=proposals_shape[0]), 1)
return tf.reshape(ones_mat * multiplier, [-1])
with tf.name_scope(scope, 'CropAndResize'):
cropped_regions = tf.image.crop_and_resize(
image, tf.reshape(boxes, [-1] + boxes.shape.as_list()[2:]),
get_box_inds(boxes), crop_size)
final_shape = tf.concat([tf.shape(boxes)[:2],
tf.shape(cropped_regions)[1:]], axis=0)
return tf.reshape(cropped_regions, final_shape)
def matmul_crop_and_resize(image, boxes, crop_size, extrapolation_value=0.0,
scope=None):
"""Matrix multiplication based implementation of the crop and resize op.
Extracts crops from the input image tensor and bilinearly resizes them
(possibly with aspect ratio change) to a common output size specified by
crop_size. This is more general than the crop_to_bounding_box op which
extracts a fixed size slice from the input image and does not allow
resizing or aspect ratio change.
Returns a tensor with crops from the input image at positions defined at
the bounding box locations in boxes. The cropped boxes are all resized
(with bilinear interpolation) to a fixed size = `[crop_height, crop_width]`.
The result is a 5-D tensor `[batch, num_boxes, crop_height, crop_width,
depth]`.
Note that this operation is meant to replicate the behavior of the standard
tf.image.crop_and_resize operation but there are a few differences.
Specifically:
1) There is no `box_indices` argument --- to run this op on multiple images,
one must currently call this op independently on each image.
2) The `crop_size` parameter is assumed to be statically defined.
Moreover, the number of boxes must be strictly nonzero.
Args:
image: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
`int16`, `int32`, `int64`, `half`, 'bfloat16', `float32`, `float64`.
A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
Both `image_height` and `image_width` need to be positive.
boxes: A `Tensor` of type `float32` or 'bfloat16'.
A 3-D tensor of shape `[batch, num_boxes, 4]`. The boxes are specified in
normalized coordinates and are of the form `[y1, x1, y2, x2]`. A
normalized coordinate value of `y` is mapped to the image coordinate at
`y * (image_height - 1)`, so as the `[0, 1]` interval of normalized image
height is mapped to `[0, image_height - 1] in image height coordinates.
We do allow y1 > y2, in which case the sampled crop is an up-down flipped
version of the original image. The width dimension is treated similarly.
Normalized coordinates outside the `[0, 1]` range are allowed, in which
case we use `extrapolation_value` to extrapolate the input image values.
crop_size: A list of two integers `[crop_height, crop_width]`. All
cropped image patches are resized to this size. The aspect ratio of the
image content is not preserved. Both `crop_height` and `crop_width` need
to be positive.
extrapolation_value: a float value to use for extrapolation.
scope: A name for the operation (optional).
Returns:
A 5-D tensor of shape `[batch, num_boxes, crop_height, crop_width, depth]`
"""
with tf.name_scope(scope, 'MatMulCropAndResize'):
box_levels = tf.zeros(boxes.shape.as_list()[:2], dtype=tf.int32)
return multilevel_roi_align([image],
boxes,
box_levels,
crop_size,
align_corners=True,
extrapolation_value=extrapolation_value)
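# Example (mirrors OpsTestMatMulCropAndResize.testMatMulCropAndResize2x2To3x3
# above): for a 2x2 single-channel image [[1, 2], [3, 4]], the box [0, 0, 1, 1]
# and crop_size=[3, 3], the bilinearly resized crop is
# [[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]].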
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.spatial_transform_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from object_detection.utils import spatial_transform_ops as spatial_ops
from object_detection.utils import test_case
class BoxGridCoordinateTest(test_case.TestCase):
def test_4x4_grid(self):
boxes = np.array([[[0., 0., 6., 6.]]], dtype=np.float32)
def graph_fn(boxes):
return spatial_ops.box_grid_coordinate_vectors(boxes, size_y=4, size_x=4)
grid_y, grid_x = self.execute(graph_fn, [boxes])
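# The 6x6 box is divided into a 4x4 grid of 1.5-wide cells; the sampling
# points are the cell centers: 0.75, 2.25, 3.75, 5.25 along each axis.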
expected_grid_y = np.array([[[0.75, 2.25, 3.75, 5.25]]])
expected_grid_x = np.array([[[0.75, 2.25, 3.75, 5.25]]])
self.assertAllClose(expected_grid_y, grid_y)
self.assertAllClose(expected_grid_x, grid_x)
def test_2x2_grid(self):
def graph_fn(boxes):
return spatial_ops.box_grid_coordinate_vectors(boxes, size_x=2, size_y=2)
boxes = np.array([[[0., 0., 6., 3.],
[0., 0., 3., 6.]]], dtype=np.float32)
grid_y, grid_x = self.execute(graph_fn, [boxes])
expected_grid_y = np.array([[[1.5, 4.5],
[0.75, 2.25]]])
expected_grid_x = np.array([[[0.75, 2.25],
[1.5, 4.5]]])
self.assertAllClose(expected_grid_y, grid_y)
self.assertAllClose(expected_grid_x, grid_x)
def test_2x4_grid(self):
boxes = np.array([[[0., 0., 6., 6.]]], dtype=np.float32)
def graph_fn(boxes):
return spatial_ops.box_grid_coordinate_vectors(boxes, size_y=2, size_x=4)
grid_y, grid_x = self.execute(graph_fn, [boxes])
expected_grid_y = np.array([[[1.5, 4.5]]])
expected_grid_x = np.array([[[0.75, 2.25, 3.75, 5.25]]])
self.assertAllClose(expected_grid_y, grid_y)
self.assertAllClose(expected_grid_x, grid_x)
def test_2x4_grid_with_aligned_corner(self):
boxes = np.array([[[0., 0., 6., 6.]]], dtype=np.float32)
def graph_fn(boxes):
return spatial_ops.box_grid_coordinate_vectors(boxes, size_y=2, size_x=4,
align_corners=True)
grid_y, grid_x = self.execute(graph_fn, [boxes])
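# With align_corners=True the first and last sampling points coincide with the
# box corners, so the grid is evenly spaced from 0 to 6 along each axis.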
expected_grid_y = np.array([[[0, 6]]])
expected_grid_x = np.array([[[0, 2, 4, 6]]])
self.assertAllClose(expected_grid_y, grid_y)
self.assertAllClose(expected_grid_x, grid_x)
def test_offgrid_boxes(self):
boxes = np.array([[[1.2, 2.3, 7.2, 8.3]]], dtype=np.float32)
def graph_fn(boxes):
return spatial_ops.box_grid_coordinate_vectors(boxes, size_y=4, size_x=4)
grid_y, grid_x = self.execute(graph_fn, [boxes])
expected_grid_y = np.array([[[0.75, 2.25, 3.75, 5.25]]]) + 1.2
expected_grid_x = np.array([[[0.75, 2.25, 3.75, 5.25]]]) + 2.3
self.assertAllClose(expected_grid_y, grid_y)
self.assertAllClose(expected_grid_x, grid_x)
class FeatureGridCoordinateTest(test_case.TestCase):
def test_snap_box_points_to_nearest_4_pixels(self):
box_grid_y = np.array([[[1.5, 4.6]]], dtype=np.float32)
box_grid_x = np.array([[[2.4, 5.3]]], dtype=np.float32)
def graph_fn(box_grid_y, box_grid_x):
return spatial_ops.feature_grid_coordinate_vectors(box_grid_y, box_grid_x)
(feature_grid_y0,
feature_grid_x0, feature_grid_y1, feature_grid_x1) = self.execute(
graph_fn, [box_grid_y, box_grid_x])
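# Each box grid coordinate is bracketed by its two nearest feature pixels:
# y0/x0 = floor(coordinate) and y1/x1 = floor(coordinate) + 1,
# e.g. 4.6 -> (4, 5) and 2.4 -> (2, 3).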
expected_grid_y0 = np.array([[[1, 4]]])
expected_grid_y1 = np.array([[[2, 5]]])
expected_grid_x0 = np.array([[[2, 5]]])
expected_grid_x1 = np.array([[[3, 6]]])
self.assertAllEqual(expected_grid_y0, feature_grid_y0)
self.assertAllEqual(expected_grid_y1, feature_grid_y1)
self.assertAllEqual(expected_grid_x0, feature_grid_x0)
self.assertAllEqual(expected_grid_x1, feature_grid_x1)
def test_snap_box_points_outside_pixel_grid_to_nearest_neighbor(self):
box_grid_y = np.array([[[0.33, 1., 1.66]]], dtype=np.float32)
box_grid_x = np.array([[[-0.5, 1., 1.66]]], dtype=np.float32)
def graph_fn(box_grid_y, box_grid_x):
return spatial_ops.feature_grid_coordinate_vectors(box_grid_y, box_grid_x)
(feature_grid_y0,
feature_grid_x0, feature_grid_y1, feature_grid_x1) = self.execute(
graph_fn, [box_grid_y, box_grid_x])
expected_grid_y0 = np.array([[[0, 1, 1]]])
expected_grid_y1 = np.array([[[1, 2, 2]]])
expected_grid_x0 = np.array([[[-1, 1, 1]]])
expected_grid_x1 = np.array([[[0, 2, 2]]])
self.assertAllEqual(expected_grid_y0, feature_grid_y0)
self.assertAllEqual(expected_grid_y1, feature_grid_y1)
self.assertAllEqual(expected_grid_x0, feature_grid_x0)
self.assertAllEqual(expected_grid_x1, feature_grid_x1)
class RavelIndicesTest(test_case.TestCase):
def test_feature_point_indices(self):
feature_grid_y = np.array([[[1, 2, 4, 5],
[2, 3, 4, 5]]], dtype=np.int32)
feature_grid_x = np.array([[[1, 3, 4],
[2, 3, 4]]], dtype=np.int32)
num_feature_levels = 2
feature_height = 6
feature_width = 5
box_levels = np.array([[0, 1]], dtype=np.int32)
def graph_fn(feature_grid_y, feature_grid_x, box_levels):
return spatial_ops.ravel_indices(feature_grid_y, feature_grid_x,
num_feature_levels, feature_height,
feature_width, box_levels)
indices = self.execute(graph_fn,
[feature_grid_y, feature_grid_x, box_levels])
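# The expected indices flatten a [batch * num_levels, height, width] grid:
# index = (batch_index * num_levels + level) * height * width + y * width + x,
# e.g. box 0 (level 0): y=1, x=1 -> 6; box 1 (level 1): offset 30, y=2, x=2 -> 42.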
expected_indices = np.array([[[[6, 8, 9],
[11, 13, 14],
[21, 23, 24],
[26, 28, 29]],
[[42, 43, 44],
[47, 48, 49],
[52, 53, 54],
[57, 58, 59]]]])
self.assertAllEqual(expected_indices.flatten(), indices)
class MultiLevelRoIAlignTest(test_case.TestCase):
def test_perfectly_aligned_cell_center_and_feature_pixels(self):
def graph_fn(image, boxes, levels):
return spatial_ops.multilevel_roi_align([image],
boxes,
levels,
output_size=[2, 2])
image = np.arange(25).reshape(1, 5, 5, 1).astype(np.float32)
boxes = np.array([[[0, 0, 1.0, 1.0]]], dtype=np.float32)
box_levels = np.array([[0]], dtype=np.int32)
expected_output = [[[[[6], [8]],
[[16], [18]]]]]
crop_output = self.execute(graph_fn, [image, boxes, box_levels])
self.assertAllClose(crop_output, expected_output)
def test_interpolation_with_4_points_per_bin(self):
def graph_fn(image, boxes, levels):
return spatial_ops.multilevel_roi_align([image],
boxes,
levels,
output_size=[1, 1],
num_samples_per_cell_y=2,
num_samples_per_cell_x=2)
image = np.array([[[[1], [2], [3], [4]],
[[5], [6], [7], [8]],
[[9], [10], [11], [12]],
[[13], [14], [15], [16]]]],
dtype=np.float32)
boxes = np.array([[[1./3, 1./3, 2./3, 2./3]]], dtype=np.float32)
box_levels = np.array([[0]], dtype=np.int32)
expected_output = [[[[[(7.25 + 7.75 + 9.25 + 9.75) / 4]]]]]
crop_output = self.execute(graph_fn, [image, boxes, box_levels])
self.assertAllClose(expected_output, crop_output)
def test_1x1_crop_on_2x2_features(self):
def graph_fn(image, boxes, levels):
return spatial_ops.multilevel_roi_align([image],
boxes,
levels,
output_size=[1, 1])
image = np.array([[[[1], [2]],
[[3], [4]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
box_levels = np.array([[0]], dtype=np.int32)
expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes, box_levels])
self.assertAllClose(crop_output, expected_output)
def test_3x3_crops_on_2x2_features(self):
def graph_fn(image, boxes, levels):
return spatial_ops.multilevel_roi_align([image],
boxes,
levels,
output_size=[3, 3])
image = np.array([[[[1], [2]],
[[3], [4]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
box_levels = np.array([[0]], dtype=np.int32)
expected_output = [[[[[9./6], [11./6], [13./6]],
[[13./6], [15./6], [17./6]],
[[17./6], [19./6], [21./6]]]]]
crop_output = self.execute(graph_fn, [image, boxes, box_levels])
self.assertAllClose(crop_output, expected_output)
def test_2x2_crops_on_3x3_features(self):
def graph_fn(image, boxes, levels):
return spatial_ops.multilevel_roi_align([image],
boxes,
levels,
output_size=[2, 2])
image = np.array([[[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]]]],
dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]]],
dtype=np.float32)
box_levels = np.array([[0, 0]], dtype=np.int32)
expected_output = [[[[[3], [4]],
[[6], [7]]],
[[[2.], [2.5]],
[[3.5], [4.]]]]]
crop_output = self.execute(graph_fn, [image, boxes, box_levels])
self.assertAllClose(crop_output, expected_output)
def test_2x2_crop_on_4x4_features(self):
def graph_fn(image, boxes, levels):
return spatial_ops.multilevel_roi_align([image],
boxes,
levels,
output_size=[2, 2])
image = np.array([[[[0], [1], [2], [3]],
[[4], [5], [6], [7]],
[[8], [9], [10], [11]],
[[12], [13], [14], [15]]]],
dtype=np.float32)
boxes = np.array([[[0, 0, 2./3, 2./3],
[0, 0, 2./3, 1.0]]],
dtype=np.float32)
box_levels = np.array([[0, 0]], dtype=np.int32)
expected_output = np.array([[[[[2.5], [3.5]],
[[6.5], [7.5]]],
[[[2.75], [4.25]],
[[6.75], [8.25]]]]])
crop_output = self.execute(graph_fn, [image, boxes, box_levels])
self.assertAllClose(expected_output, crop_output)
def test_extrapolate_3x3_crop_on_2x2_features(self):
def graph_fn(image, boxes, levels):
return spatial_ops.multilevel_roi_align([image],
boxes,
levels,
output_size=[3, 3])
image = np.array([[[[1], [2]],
[[3], [4]]]], dtype=np.float32)
boxes = np.array([[[-1, -1, 2, 2]]], dtype=np.float32)
box_levels = np.array([[0]], dtype=np.int32)
expected_output = np.array([[[[[0.25], [0.75], [0.5]],
[[1.0], [2.5], [1.5]],
[[0.75], [1.75], [1]]]]])
crop_output = self.execute(graph_fn, [image, boxes, box_levels])
self.assertAllClose(expected_output, crop_output)
def test_extrapolate_with_non_zero_value(self):
def graph_fn(image, boxes, levels):
return spatial_ops.multilevel_roi_align([image],
boxes,
levels,
output_size=[3, 3],
extrapolation_value=2.0)
image = np.array([[[[4], [4]],
[[4], [4]]]], dtype=np.float32)
boxes = np.array([[[-1, -1, 2, 2]]], dtype=np.float32)
box_levels = np.array([[0]], dtype=np.int32)
expected_output = np.array([[[[[2.5], [3.0], [2.5]],
[[3.0], [4.0], [3.0]],
[[2.5], [3.0], [2.5]]]]])
crop_output = self.execute(graph_fn, [image, boxes, box_levels])
self.assertAllClose(expected_output, crop_output)
def test_multilevel_roi_align(self):
image_size = 640
fpn_min_level = 2
fpn_max_level = 5
batch_size = 1
output_size = [2, 2]
num_filters = 1
features = []
for level in range(fpn_min_level, fpn_max_level + 1):
feat_size = int(image_size / 2**level)
features.append(
float(level) *
np.ones([batch_size, feat_size, feat_size, num_filters],
dtype=np.float32))
boxes = np.array(
[
[
[0, 0, 111, 111], # Level 2.
[0, 0, 113, 113], # Level 3.
[0, 0, 223, 223], # Level 3.
[0, 0, 225, 225], # Level 4.
[0, 0, 449, 449] # Level 5.
],
],
dtype=np.float32) / image_size
levels = np.array([[0, 1, 1, 2, 3]], dtype=np.int32)
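# Each feature map is filled with its pyramid level value (2.0 through 5.0), so
# the pooled output for every box should equal the level it is assigned to via
# `levels` (index 0 corresponds to level 2).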
def graph_fn(feature1, feature2, feature3, feature4, boxes, levels):
roi_features = spatial_ops.multilevel_roi_align(
[feature1, feature2, feature3, feature4],
boxes,
levels,
output_size)
return roi_features
roi_features = self.execute(graph_fn, features + [boxes, levels])
self.assertAllClose(roi_features[0][0], 2 * np.ones((2, 2, 1)))
self.assertAllClose(roi_features[0][1], 3 * np.ones((2, 2, 1)))
self.assertAllClose(roi_features[0][2], 3 * np.ones((2, 2, 1)))
self.assertAllClose(roi_features[0][3], 4 * np.ones((2, 2, 1)))
self.assertAllClose(roi_features[0][4], 5 * np.ones((2, 2, 1)))
def test_large_input(self):
if test_case.FLAGS.tpu_test:
input_size = 1408
min_level = 2
max_level = 6
batch_size = 2
num_boxes = 512
num_filters = 256
output_size = [7, 7]
with self.test_session() as sess:
features = []
for level in range(min_level, max_level + 1):
feat_size = int(input_size / 2**level)
features.append(tf.constant(
np.reshape(
np.arange(
batch_size * feat_size * feat_size * num_filters,
dtype=np.float32),
[batch_size, feat_size, feat_size, num_filters]),
dtype=tf.bfloat16))
boxes = np.array([
[[0, 0, 256, 256]]*num_boxes,
], dtype=np.float32) / input_size
boxes = np.tile(boxes, [batch_size, 1, 1])
tf_boxes = tf.constant(boxes)
tf_levels = tf.random_uniform([batch_size, num_boxes], maxval=5,
dtype=tf.int32)
def crop_and_resize_fn():
return spatial_ops.multilevel_roi_align(
features, tf_boxes, tf_levels, output_size)
tpu_crop_and_resize_fn = tf.contrib.tpu.rewrite(crop_and_resize_fn)
sess.run(tf.contrib.tpu.initialize_system())
sess.run(tf.global_variables_initializer())
roi_features = sess.run(tpu_crop_and_resize_fn)
self.assertEqual(roi_features[0].shape,
(batch_size, num_boxes, output_size[0], output_size[1],
num_filters))
sess.run(tf.contrib.tpu.shutdown_system())
class MatMulCropAndResizeTest(test_case.TestCase):
def testMatMulCropAndResize2x2To1x1(self):
def graph_fn(image, boxes):
return spatial_ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize2x2To1x1Flipped(self):
def graph_fn(image, boxes):
return spatial_ops.matmul_crop_and_resize(image, boxes, crop_size=[1, 1])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[[1, 1, 0, 0]]], dtype=np.float32)
expected_output = [[[[[2.5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize2x2To3x3(self):
def graph_fn(image, boxes):
return spatial_ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1]]], dtype=np.float32)
expected_output = [[[[[1.0], [1.5], [2.0]],
[[2.0], [2.5], [3.0]],
[[3.0], [3.5], [4.0]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize2x2To3x3Flipped(self):
def graph_fn(image, boxes):
return spatial_ops.matmul_crop_and_resize(image, boxes, crop_size=[3, 3])
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)
boxes = np.array([[[1, 1, 0, 0]]], dtype=np.float32)
expected_output = [[[[[4.0], [3.5], [3.0]],
[[3.0], [2.5], [2.0]],
[[2.0], [1.5], [1.0]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize3x3To2x2(self):
def graph_fn(image, boxes):
return spatial_ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]]], dtype=np.float32)
expected_output = [[[[[1], [3]], [[7], [9]]],
[[[1], [2]], [[4], [5]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes):
return spatial_ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]]], dtype=np.float32)
expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testBatchMatMulCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes):
return spatial_ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]],
[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]],
[[1, 1, 0, 0],
[.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
[[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
[[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testMatMulCropAndResize3x3To2x2Flipped(self):
def graph_fn(image, boxes):
return spatial_ops.matmul_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]]]], dtype=np.float32)
boxes = np.array([[[1, 1, 0, 0],
[.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[[9], [7]], [[3], [1]]],
[[[5], [4]], [[2], [1]]]]]
crop_output = self.execute(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
def testInvalidInputShape(self):
image = tf.constant([[[1], [2]], [[3], [4]]], dtype=tf.float32)
boxes = tf.constant([[-1, -1, 1, 1]], dtype=tf.float32)
crop_size = [4, 4]
with self.assertRaises(ValueError):
spatial_ops.matmul_crop_and_resize(image, boxes, crop_size)
class NativeCropAndResizeTest(test_case.TestCase):
def testBatchCropAndResize3x3To2x2_2Channels(self):
def graph_fn(image, boxes):
return spatial_ops.native_crop_and_resize(image, boxes, crop_size=[2, 2])
image = np.array([[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]],
[[[1, 0], [2, 1], [3, 2]],
[[4, 3], [5, 4], [6, 5]],
[[7, 6], [8, 7], [9, 8]]]], dtype=np.float32)
boxes = np.array([[[0, 0, 1, 1],
[0, 0, .5, .5]],
[[1, 1, 0, 0],
[.5, .5, 0, 0]]], dtype=np.float32)
expected_output = [[[[[1, 0], [3, 2]], [[7, 6], [9, 8]]],
[[[1, 0], [2, 1]], [[4, 3], [5, 4]]]],
[[[[9, 8], [7, 6]], [[3, 2], [1, 0]]],
[[[5, 4], [4, 3]], [[2, 1], [1, 0]]]]]
crop_output = self.execute_cpu(graph_fn, [image, boxes])
self.assertAllClose(crop_output, expected_output)
if __name__ == '__main__':
tf.test.main()
@@ -19,6 +19,21 @@ The rank 4 tensor_shape must be of the form [batch_size, height, width, depth].
"""
def get_dim_as_int(dim):
"""Utility to get v1 or v2 TensorShape dim as an int.
Args:
dim: The TensorShape dimension to get as an int
Returns:
None or an int.
"""
try:
return dim.value
except AttributeError:
return dim
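# A hedged illustration of get_dim_as_int (example shapes assumed): under TF 1.x
# behavior, indexing a TensorShape yields a Dimension whose .value is an int or
# None; under TF 2.x behavior it yields the int (or None) directly. Either way
# get_dim_as_int returns the plain value:
#   get_dim_as_int(tf.TensorShape([None, 32, 32, 3])[1])  # -> 32
#   get_dim_as_int(tf.TensorShape([None, 32, 32, 3])[0])  # -> None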
def get_batch_size(tensor_shape):
"""Returns batch size from the tensor shape.
@@ -29,7 +44,7 @@ def get_batch_size(tensor_shape):
An integer representing the batch size of the tensor.
"""
tensor_shape.assert_has_rank(rank=4)
return tensor_shape[0].value
return get_dim_as_int(tensor_shape[0])
def get_height(tensor_shape):
@@ -42,7 +57,7 @@ def get_height(tensor_shape):
An integer representing the height of the tensor.
"""
tensor_shape.assert_has_rank(rank=4)
return tensor_shape[1].value
return get_dim_as_int(tensor_shape[1])
def get_width(tensor_shape):
@@ -55,7 +70,7 @@ def get_width(tensor_shape):
An integer representing the width of the tensor.
"""
tensor_shape.assert_has_rank(rank=4)
return tensor_shape[2].value
return get_dim_as_int(tensor_shape[2])
def get_depth(tensor_shape):
@@ -68,4 +83,4 @@ def get_depth(tensor_shape):
An integer representing the depth of the tensor.
"""
tensor_shape.assert_has_rank(rank=4)
return tensor_shape[3].value
return get_dim_as_int(tensor_shape[3])
@@ -1005,9 +1005,13 @@ class EvalMetricOpsVisualization(object):
lambda: tf.summary.image(summary_name, image),
lambda: tf.constant(''))
update_op = tf.py_func(self.add_images, [[images[0]]], [])
image_tensors = tf.py_func(
get_images, [], [tf.uint8] * self._max_examples_to_draw)
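# With eager execution enabled, the Python callbacks are invoked directly;
# in graph mode they are wrapped in tf.py_func so they run when the eval
# metric ops are evaluated.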
if tf.executing_eagerly():
update_op = self.add_images([[images[0]]])
image_tensors = get_images()
else:
update_op = tf.py_func(self.add_images, [[images[0]]], [])
image_tensors = tf.py_func(
get_images, [], [tf.uint8] * self._max_examples_to_draw)
eval_metric_ops = {}
for i, image in enumerate(image_tensors):
summary_name = self._summary_name_prefix + '/' + str(i)
......