Commit 05b6ea2d authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 334647537
parent ff2a2408
......@@ -178,8 +178,7 @@ class Parser(parser.Parser):
self._unmatched_threshold)
(cls_targets, box_targets, cls_weights,
box_weights) = anchor_labeler.label_anchors(
anchor_boxes, boxes,
tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
anchor_boxes, boxes, tf.expand_dims(classes, axis=1))
# If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16:
......@@ -244,8 +243,7 @@ class Parser(parser.Parser):
self._unmatched_threshold)
(cls_targets, box_targets, cls_weights,
box_weights) = anchor_labeler.label_anchors(
anchor_boxes, boxes,
tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
anchor_boxes, boxes, tf.expand_dims(classes, axis=1))
# If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16:
......
......@@ -137,7 +137,7 @@ class AnchorLabeler(object):
self.matcher = keras_cv.ops.BoxMatcher(
positive_threshold=match_threshold,
negative_threshold=unmatched_threshold,
force_match_for_each_row=True)
force_match_for_each_col=True)
self.box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
def label_anchors(self, anchor_boxes, gt_boxes, gt_labels):
......@@ -173,10 +173,17 @@ class AnchorLabeler(object):
for anchors in anchor_boxes.values():
flattened_anchor_boxes.append(tf.reshape(anchors, [-1, 4]))
flattened_anchor_boxes = tf.concat(flattened_anchor_boxes, axis=0)
similarity_matrix = self.similarity_calc(gt_boxes, flattened_anchor_boxes)
match_results = self.matcher(similarity_matrix)
cls_targets, box_targets, cls_weights, box_weights = self.anchor_labeler(
gt_boxes, gt_labels, match_results)
similarity_matrix = self.similarity_calc(flattened_anchor_boxes, gt_boxes)
match_indices, match_indicators = self.matcher(similarity_matrix)
mask = tf.less_equal(match_indicators, 0)
cls_mask = tf.expand_dims(mask, -1)
cls_targets = self.anchor_labeler(gt_labels, match_indices, cls_mask, -1)
box_mask = tf.tile(cls_mask, [1, 4])
box_targets = self.anchor_labeler(gt_boxes, match_indices, box_mask)
weights = tf.squeeze(tf.ones_like(gt_labels, dtype=tf.float32), -1)
box_weights = self.anchor_labeler(weights, match_indices, mask)
ignore_mask = tf.equal(match_indicators, -2)
cls_weights = self.anchor_labeler(weights, match_indices, ignore_mask)
box_targets_list = box_list.BoxList(box_targets)
anchor_box_list = box_list.BoxList(flattened_anchor_boxes)
box_targets = self.box_coder.encode(box_targets_list, anchor_box_list)
......@@ -268,19 +275,19 @@ class RpnAnchorLabeler(AnchorLabeler):
for anchors in anchor_boxes.values():
flattened_anchor_boxes.append(tf.reshape(anchors, [-1, 4]))
flattened_anchor_boxes = tf.concat(flattened_anchor_boxes, axis=0)
similarity_matrix = self.similarity_calc(gt_boxes, flattened_anchor_boxes)
match_results = self.matcher(similarity_matrix)
# cls_targets, cls_weights, box_weights are not used.
_, box_targets, _, _ = self.anchor_labeler(
gt_boxes, gt_labels, match_results)
similarity_matrix = self.similarity_calc(flattened_anchor_boxes, gt_boxes)
match_indices, match_indicators = self.matcher(similarity_matrix)
box_mask = tf.tile(tf.expand_dims(tf.less_equal(match_indicators, 0), -1),
[1, 4])
box_targets = self.anchor_labeler(gt_boxes, match_indices, box_mask)
box_targets_list = box_list.BoxList(box_targets)
anchor_box_list = box_list.BoxList(flattened_anchor_boxes)
box_targets = self.box_coder.encode(box_targets_list, anchor_box_list)
# Zero out the unmatched and ignored regression targets.
num_matches = match_results.shape.as_list()[0] or tf.shape(match_results)[0]
num_matches = match_indices.shape.as_list()[0] or tf.shape(match_indices)[0]
unmatched_ignored_box_targets = tf.zeros([num_matches, 4], dtype=tf.float32)
matched_anchors_mask = tf.greater_equal(match_results, 0)
matched_anchors_mask = tf.greater_equal(match_indicators, 0)
# To broadcast matched_anchors_mask to the same shape as
# matched_reg_targets.
matched_anchors_mask = tf.tile(
......@@ -290,7 +297,7 @@ class RpnAnchorLabeler(AnchorLabeler):
unmatched_ignored_box_targets)
# score_targets contains the subsampled positive and negative anchors.
score_targets, _, _ = self._get_rpn_samples(match_results)
score_targets, _, _ = self._get_rpn_samples(match_indicators)
# Unpacks labels.
score_targets_dict = unpack_targets(score_targets, anchor_boxes)
......
......@@ -20,55 +20,20 @@ import tensorflow as tf
class AnchorLabeler:
"""Labeler for dense object detector."""
def __init__(
self,
positive_class_weight=1.0,
positive_regression_weight=1.0,
negative_class_weight=1.0,
negative_regression_weight=0.0,
negative_class_label=-1,
ignore_class_label=-2,
negative_regression_label=0.,
ignore_regression_label=0.):
"""Constructs Anchor Labeler.
Args:
positive_class_weight: classification weight to be associated to positive
matched anchor. Defaults to 1.0.
positive_regression_weight: regression weight to be associated to positive
matched anchor. Defaults to 1.0.
negative_class_weight: classification weight to be associated to negative
matched anchor. Default to 1.0
negative_regression_weight: classification weight to be associated to
negative matched anchor. Default to 0.0.
negative_class_label: An integer for classification label to be associated
for negative matched anchor. Defaults to -1.
ignore_class_label: An integer for classification label to be associated
for ignored anchor. Defaults to -2.
negative_regression_label: A float for regression label to be associated
for negative matched anchor. Defaults to 0.
ignore_regression_label: A float for regression label to be associated
for ignored anchor. Defaults to 0.
"""
self.positive_class_weight = positive_class_weight
self.positive_regression_weight = positive_regression_weight
self.negative_class_weight = negative_class_weight
self.negative_regression_weight = negative_regression_weight
self.negative_class_label = negative_class_label
self.ignore_class_label = ignore_class_label
self.negative_regression_label = negative_regression_label
self.ignore_regression_label = ignore_regression_label
def __call__(self, boxes, labels, matches):
def __call__(self, labels, match_indices, mask, mask_val=0.0):
"""Labels anchors with ground truth inputs.
B: batch_size
N: number of groundtruth boxes.
Args:
boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
For each row, it stores [y0, x0, y1, x1] for four corners of a box.
labels: An integer tensor with shape [N, 1] representing groundtruth
classes.
matches: An integer tensor with shape [N] representing match results, must
be -1 for negative matched anchor, and -2 for ignored anchor.
labels: An integer tensor with shape [N, 1] or [B, N, 1] representing
groundtruth labels.
match_indices: An integer tensor with shape [N] or [B, N] representing
match label index.
mask: An integer tensor with shape [N] or [B, N] representing match
labels, e.g., 1 for positive, -1 for negative, -2 for ignore.
mask_val: An integer to fill in for mask.
Returns:
class_targets: A integer Tensor with shape [num_anchors].
......@@ -82,65 +47,43 @@ class AnchorLabeler:
1.0 for positive matched anchors, and 0.0 for negative and ignored
anchors.
"""
class_targets = self._gather_based_on_match(
matches, tf.cast(labels, tf.int32),
negative_value=tf.constant([self.negative_class_label], tf.int32),
ignored_value=tf.constant([self.ignore_class_label], tf.int32))
negative_reg_value = tf.constant(
[self.negative_regression_label] * 4, dtype=tf.float32)
ignore_reg_value = tf.constant(
[self.ignore_regression_label] * 4, dtype=tf.float32)
reg_targets = self._gather_based_on_match(
matches, boxes, negative_reg_value, ignore_reg_value)
num_gt_boxes = boxes.shape.as_list()[0] or tf.shape(boxes)[0]
groundtruth_class_weights = self.positive_class_weight * tf.ones(
[num_gt_boxes], dtype=tf.float32)
class_weights = self._gather_based_on_match(
matches, groundtruth_class_weights,
negative_value=self.negative_class_weight,
ignored_value=0.)
groundtruth_reg_weights = self.positive_regression_weight * tf.ones(
[num_gt_boxes], dtype=tf.float32)
reg_weights = self._gather_based_on_match(
matches, groundtruth_reg_weights,
negative_value=self.negative_regression_weight, ignored_value=0.)
return class_targets, reg_targets, class_weights, reg_weights
def _gather_based_on_match(
self, matches, inputs, negative_value, ignored_value):
"""Gathers elements from `input_tensor` based on match results.
For columns that are matched to a row, gathered_tensor[col] is set to
input_tensor[match[col]]. For columns that are unmatched,
gathered_tensor[col] is set to negative_value. Finally, for columns that
are ignored gathered_tensor[col] is set to ignored_value.
Note that the input_tensor.shape[1:] must match with unmatched_value.shape
and ignored_value.shape
Args:
matches: A integer tensor with shape [N] representing the
matching results of anchors. (1) match_results[i]>=0,
meaning that column i is matched with row match_results[i].
(2) match_results[i]=-1, meaning that column i is not matched.
(3) match_results[i]=-2, meaning that column i is ignored.
inputs: Tensor to gather values from.
negative_value: Constant tensor value for unmatched columns.
ignored_value: Constant tensor value for ignored columns.
Returns:
gathered_tensor: A tensor containing values gathered from input_tensor.
The shape of the gathered tensor is [match.shape[0]] +
input_tensor.shape[1:].
"""
inputs = tf.concat(
[tf.stack([ignored_value, negative_value]), inputs], axis=0)
gather_indices = tf.maximum(matches + 2, 0)
gathered_tensor = tf.gather(inputs, gather_indices)
return gathered_tensor
if len(labels.shape) <= 2:
return self._gather_unbatched(labels, match_indices, mask, mask_val)
elif len(labels.shape) == 3:
return self._gather_batched(labels, match_indices, mask, mask_val)
def _gather_unbatched(self, labels, match_indices, mask, mask_val):
"""Gather based on unbatched labels and boxes."""
num_gt_boxes = tf.shape(labels)[0]
masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
mask, dtype=labels.dtype)
def _assign_when_rows_empty():
return masked_targets
def _assign_when_rows_not_empty():
targets = tf.gather(labels, match_indices)
return tf.where(mask, masked_targets, targets)
return tf.cond(tf.greater(num_gt_boxes, 0),
_assign_when_rows_not_empty,
_assign_when_rows_empty)
def _gather_batched(self, labels, match_indices, mask, mask_val):
"""Gather based on batched labels."""
batch_size = labels.shape[0]
if batch_size == 1:
result = self._gather_unbatched(
tf.squeeze(labels, axis=0), tf.squeeze(match_indices, axis=0),
tf.squeeze(mask, axis=0), mask_val)
return tf.expand_dims(result, axis=0)
else:
indices_shape = tf.shape(match_indices)
indices_dtype = match_indices.dtype
batch_indices = (tf.expand_dims(
tf.range(indices_shape[0], dtype=indices_dtype), axis=-1) *
tf.ones([1, indices_shape[-1]], dtype=indices_dtype))
gather_nd_indices = tf.stack(
[batch_indices, match_indices], axis=-1)
targets = tf.gather_nd(labels, gather_nd_indices)
return targets
......@@ -40,7 +40,9 @@ class BoxMatcher:
self,
positive_threshold,
negative_threshold=None,
force_match_for_each_row=False,
force_match_for_each_col=False,
negative_lower_than_ignore=True,
positive_value=1,
negative_value=-1,
ignore_value=-2):
"""Construct BoxMatcher.
......@@ -53,11 +55,15 @@ class BoxMatcher:
sim < negative_threshold or
positive_threshold > sim >= negative_threshold.
Defaults to positive_threshold when set to None.
force_match_for_each_row: If True, ensures that each row is matched to
at least one column (which is not guaranteed otherwise if the
force_match_for_each_col: If True, ensures that each column is matched to
at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False.
negative_value: An integer to fill for negative matches.
ignore_value: An integer to fill for ignored matches.
negative_lower_than_ignore: If True, the threshold is
positive|ignore|negative, else positive|negative|ignore. Defaults to
True.
positive_value: An integer to fill for positive match labels.
negative_value: An integer to fill for negative match labels.
ignore_value: An integer to fill for ignored match labels.
Raises:
ValueError: If negative_threshold > positive_threshold.
......@@ -71,9 +77,11 @@ class BoxMatcher:
'to positive_threshold')
self._negative_threshold = negative_threshold
self._positive_value = positive_value
self._negative_value = negative_value
self._ignore_value = ignore_value
self._force_match_for_each_row = force_match_for_each_row
self._force_match_for_each_col = force_match_for_each_col
self._negative_lower_than_ignore = negative_lower_than_ignore
def __call__(self, similarity_matrix):
"""Tries to match each column of the similarity matrix to a row.
......@@ -83,11 +91,20 @@ class BoxMatcher:
similarity metric.
Returns:
A integer tensor with corresponding match indices for each of M columns,
for positive match, the match result will be the corresponding row index,
for negative match, the match will be `negative_value`, for ignored match,
the match result will be `ignore_value`.
A integer tensor of shape [N] with corresponding match indices for each
of M columns, for positive match, the match result will be the
corresponding row index, for negative match, the match will be
`negative_value`, for ignored match, the match result will be
`ignore_value`.
"""
squeeze_result = False
if len(similarity_matrix.shape) == 2:
squeeze_result = True
similarity_matrix = tf.expand_dims(similarity_matrix, axis=0)
static_shape = similarity_matrix.shape.as_list()
num_rows = static_shape[1] or tf.shape(similarity_matrix)[1]
batch_size = static_shape[0] or tf.shape(similarity_matrix)[0]
def _match_when_rows_are_empty():
"""Performs matching when the rows of similarity matrix are empty.
......@@ -98,9 +115,11 @@ class BoxMatcher:
Returns:
matches: int32 tensor indicating the row each column matches to.
"""
static_shape = similarity_matrix.shape.as_list()
num_cols = static_shape[1] or tf.shape(similarity_matrix)[1]
return -1 * tf.ones([num_cols], dtype=tf.int32)
with tf.name_scope('empty_gt_boxes'):
matches = tf.zeros([batch_size, num_rows], dtype=tf.int32)
match_labels = self._negative_value * tf.ones(
[batch_size, num_rows], dtype=tf.int32)
return matches, match_labels
def _match_when_rows_are_non_empty():
"""Performs matching when the rows of similarity matrix are non empty.
......@@ -109,50 +128,75 @@ class BoxMatcher:
matches: int32 tensor indicating the row each column matches to.
"""
# Matches for each column
matches = tf.argmax(input=similarity_matrix, axis=0, output_type=tf.int32)
with tf.name_scope('non_empty_gt_boxes'):
matches = tf.argmax(similarity_matrix, axis=-1, output_type=tf.int32)
# Deal with matched and unmatched threshold
if self._positive_threshold is not None:
# Get logical indices of ignored and unmatched columns as tf.int64
matched_vals = tf.reduce_max(similarity_matrix, axis=0)
below_negative_threshold = tf.greater(self._negative_threshold,
matched_vals)
matched_vals = tf.reduce_max(similarity_matrix, axis=-1)
matched_labels = self._positive_value * tf.ones(
[batch_size, num_rows], tf.int32)
positive_threshold = tf.cast(
self._positive_threshold, matched_vals.dtype)
negative_threshold = tf.cast(
self._negative_threshold, matched_vals.dtype)
below_negative_threshold = tf.greater(negative_threshold, matched_vals)
between_thresholds = tf.logical_and(
tf.greater_equal(matched_vals, self._negative_threshold),
tf.greater(self._positive_threshold, matched_vals))
matches = self._set_values_using_indicator(matches,
below_negative_threshold,
self._negative_value)
matches = self._set_values_using_indicator(matches,
between_thresholds,
self._ignore_value)
if self._force_match_for_each_row:
num_gt_boxes = similarity_matrix.shape.as_list()[1] or tf.shape(
similarity_matrix)[1]
force_match_column_ids = tf.argmax(
input=similarity_matrix, axis=1, output_type=tf.int32)
force_match_column_indicators = tf.one_hot(
force_match_column_ids, depth=num_gt_boxes)
force_match_row_ids = tf.argmax(
input=force_match_column_indicators, axis=0, output_type=tf.int32)
force_match_column_mask = tf.cast(
tf.reduce_max(force_match_column_indicators, axis=0),
tf.bool)
final_matches = tf.where(force_match_column_mask, force_match_row_ids,
matches)
return final_matches
else:
return matches
num_gt_boxes = similarity_matrix.shape.as_list()[0] or tf.shape(
similarity_matrix)[0]
return tf.cond(
tf.greater_equal(matched_vals, negative_threshold),
tf.greater(positive_threshold, matched_vals))
if self._negative_lower_than_ignore:
matched_labels = self._set_values_using_indicator(
matched_labels, below_negative_threshold, self._negative_value)
matched_labels = self._set_values_using_indicator(
matched_labels, between_thresholds, self._ignore_value)
else:
matched_labels = self._set_values_using_indicator(
matched_labels, below_negative_threshold, self._ignore_value)
matched_labels = self._set_values_using_indicator(
matched_labels, between_thresholds, self._negative_value)
if self._force_match_for_each_col:
# [batch_size, M], for each col (groundtruth_box), find the best
# matching row (anchor).
force_match_column_ids = tf.argmax(
input=similarity_matrix, axis=1, output_type=tf.int32)
# [batch_size, M, N]
force_match_column_indicators = tf.one_hot(
force_match_column_ids, depth=num_rows)
# [batch_size, N], for each row (anchor), find the largest column
# index for groundtruth box
force_match_row_ids = tf.argmax(
input=force_match_column_indicators, axis=1, output_type=tf.int32)
# [batch_size, N]
force_match_column_mask = tf.cast(
tf.reduce_max(force_match_column_indicators, axis=1),
tf.bool)
# [batch_size, N]
final_matches = tf.where(force_match_column_mask, force_match_row_ids,
matches)
final_matched_labels = tf.where(
force_match_column_mask,
self._positive_value * tf.ones(
[batch_size, num_rows], dtype=tf.int32),
matched_labels)
return final_matches, final_matched_labels
else:
return matches, matched_labels
num_gt_boxes = similarity_matrix.shape.as_list()[-1] or tf.shape(
similarity_matrix)[-1]
result_match, result_match_labels = tf.cond(
pred=tf.greater(num_gt_boxes, 0),
true_fn=_match_when_rows_are_non_empty,
false_fn=_match_when_rows_are_empty)
if squeeze_result:
result_match = tf.squeeze(result_match, axis=0)
result_match_labels = tf.squeeze(result_match_labels, axis=0)
return result_match, result_match_labels
def _set_values_using_indicator(self, x, indicator, val):
"""Set the indicated fields of x to val.
......
......@@ -20,60 +20,81 @@ import tensorflow as tf
def area(box):
"""Computes area of boxes.
B: batch_size
N: number of boxes
Args:
box: a float Tensor with [N, 4].
box: a float Tensor with [N, 4], or [B, N, 4].
Returns:
a float tensor with [N].
a float Tensor with [N], or [B, N]
"""
with tf.name_scope('Area'):
y_min, x_min, y_max, x_max = tf.split(
value=box, num_or_size_splits=4, axis=1)
return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
value=box, num_or_size_splits=4, axis=-1)
return tf.squeeze((y_max - y_min) * (x_max - x_min), axis=-1)
def intersection(box1, box2):
def intersection(gt_boxes, boxes):
"""Compute pairwise intersection areas between boxes.
B: batch_size
N: number of groundtruth boxes.
M: number of anchor boxes.
Args:
box1: a float Tensor with [N, 4].
box2: a float Tensor with [M, 4].
gt_boxes: a float Tensor with [N, 4], or [B, N, 4]
boxes: a float Tensor with [M, 4], or [B, M, 4]
Returns:
a float tensor with shape [N, M] representing pairwise intersections
a float Tensor with shape [N, M] or [B, N, M] representing pairwise
intersections.
"""
with tf.name_scope('Intersection'):
y_min1, x_min1, y_max1, x_max1 = tf.split(
value=box1, num_or_size_splits=4, axis=1)
value=gt_boxes, num_or_size_splits=4, axis=-1)
y_min2, x_min2, y_max2, x_max2 = tf.split(
value=box2, num_or_size_splits=4, axis=1)
y_min_max = tf.minimum(y_max1, tf.transpose(a=y_max2))
y_max_min = tf.maximum(y_min1, tf.transpose(a=y_min2))
intersect_heights = tf.maximum(0.0, y_min_max - y_max_min)
x_min_max = tf.minimum(x_max1, tf.transpose(a=x_max2))
x_max_min = tf.maximum(x_min1, tf.transpose(a=x_min2))
intersect_widths = tf.maximum(0.0, x_min_max - x_max_min)
value=boxes, num_or_size_splits=4, axis=-1)
boxes_rank = len(boxes.shape)
perm = [1, 0] if boxes_rank == 2 else [0, 2, 1]
# [N, M] or [B, N, M]
y_min_max = tf.minimum(y_max1, tf.transpose(y_max2, perm))
y_max_min = tf.maximum(y_min1, tf.transpose(y_min2, perm))
x_min_max = tf.minimum(x_max1, tf.transpose(x_max2, perm))
x_max_min = tf.maximum(x_min1, tf.transpose(x_min2, perm))
intersect_heights = y_min_max - y_max_min
intersect_widths = x_min_max - x_max_min
zeros_t = tf.cast(0, intersect_heights.dtype)
intersect_heights = tf.maximum(zeros_t, intersect_heights)
intersect_widths = tf.maximum(zeros_t, intersect_widths)
return intersect_heights * intersect_widths
def iou(box1, box2):
def iou(gt_boxes, boxes):
"""Computes pairwise intersection-over-union between box collections.
Args:
box1: a float Tensor with [N, 4].
box2: a float Tensor with [M, 4].
gt_boxes: a float Tensor with [N, 4].
boxes: a float Tensor with [M, 4].
Returns:
a tensor with shape [N, M] representing pairwise iou scores.
a Tensor with shape [N, M] representing pairwise iou scores.
"""
intersections = intersection(box1, box2)
areas1 = area(box1)
areas2 = area(box2)
unions = (
tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections)
return tf.where(
tf.equal(intersections, 0.0), tf.zeros_like(intersections),
tf.truediv(intersections, unions))
with tf.name_scope('IOU'):
intersections = intersection(gt_boxes, boxes)
gt_boxes_areas = area(gt_boxes)
boxes_areas = area(boxes)
boxes_rank = len(boxes_areas.shape)
boxes_axis = 1 if (boxes_rank == 2) else 0
gt_boxes_areas = tf.expand_dims(gt_boxes_areas, -1)
boxes_areas = tf.expand_dims(boxes_areas, boxes_axis)
unions = gt_boxes_areas + boxes_areas
unions = unions - intersections
return tf.where(
tf.equal(intersections, 0.0), tf.zeros_like(intersections),
tf.truediv(intersections, unions))
class IouSimilarity():
......@@ -84,19 +105,37 @@ class IouSimilarity():
def __call__(self, groundtruth_boxes, anchors):
"""Compute pairwise IOU similarity between ground truth boxes and anchors.
B: batch_size
N: Number of groundtruth boxes.
M: Number of anchor boxes.
Args:
groundtruth_boxes: a float Tensor with N boxes.
anchors: a float Tensor with M boxes.
groundtruth_boxes: a float Tensor with M boxes.
anchors: a float Tensor with N boxes.
Returns:
A tensor with shape [N, M] representing pairwise iou scores.
A Tensor with shape [M, N] or [B, M, N] representing pairwise
iou scores, anchor per row and groundtruth_box per colulmn.
Input shape:
groundtruth_boxes: [N, 4]
anchors: [M, 4]
groundtruth_boxes: [N, 4], or [B, N, 4]
anchors: [M, 4], or [B, M, 4]
Output shape:
[N, M]
[M, N], or [B, M, N]
"""
with tf.name_scope('IOU'):
return iou(groundtruth_boxes, anchors)
groundtruth_rank = len(groundtruth_boxes.shape)
anchor_rank = len(anchors.shape)
if groundtruth_rank < 2 or groundtruth_rank > 3:
raise ValueError('`groudtruth_boxes` must be rank 2 or 3, got {}'.format(
groundtruth_rank))
if anchor_rank < 2 or anchor_rank > 3:
raise ValueError('`anchors` must be rank 2 or 3, got {}'.format(
anchor_rank))
if groundtruth_rank < anchor_rank:
raise ValueError('`groundtruth_boxes` is unbatched while `anchors` is '
'batched is not a valid use case, got groundtruth_box '
'rank {}, and anchors rank {}'.format(
groundtruth_rank, anchor_rank))
return iou(groundtruth_boxes, anchors)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment