Commit 05b6ea2d authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 334647537
parent ff2a2408
...@@ -178,8 +178,7 @@ class Parser(parser.Parser): ...@@ -178,8 +178,7 @@ class Parser(parser.Parser):
self._unmatched_threshold) self._unmatched_threshold)
(cls_targets, box_targets, cls_weights, (cls_targets, box_targets, cls_weights,
box_weights) = anchor_labeler.label_anchors( box_weights) = anchor_labeler.label_anchors(
anchor_boxes, boxes, anchor_boxes, boxes, tf.expand_dims(classes, axis=1))
tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
# If bfloat16 is used, casts input image to tf.bfloat16. # If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16: if self._use_bfloat16:
...@@ -244,8 +243,7 @@ class Parser(parser.Parser): ...@@ -244,8 +243,7 @@ class Parser(parser.Parser):
self._unmatched_threshold) self._unmatched_threshold)
(cls_targets, box_targets, cls_weights, (cls_targets, box_targets, cls_weights,
box_weights) = anchor_labeler.label_anchors( box_weights) = anchor_labeler.label_anchors(
anchor_boxes, boxes, anchor_boxes, boxes, tf.expand_dims(classes, axis=1))
tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
# If bfloat16 is used, casts input image to tf.bfloat16. # If bfloat16 is used, casts input image to tf.bfloat16.
if self._use_bfloat16: if self._use_bfloat16:
......
...@@ -137,7 +137,7 @@ class AnchorLabeler(object): ...@@ -137,7 +137,7 @@ class AnchorLabeler(object):
self.matcher = keras_cv.ops.BoxMatcher( self.matcher = keras_cv.ops.BoxMatcher(
positive_threshold=match_threshold, positive_threshold=match_threshold,
negative_threshold=unmatched_threshold, negative_threshold=unmatched_threshold,
force_match_for_each_row=True) force_match_for_each_col=True)
self.box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder() self.box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
def label_anchors(self, anchor_boxes, gt_boxes, gt_labels): def label_anchors(self, anchor_boxes, gt_boxes, gt_labels):
...@@ -173,10 +173,17 @@ class AnchorLabeler(object): ...@@ -173,10 +173,17 @@ class AnchorLabeler(object):
for anchors in anchor_boxes.values(): for anchors in anchor_boxes.values():
flattened_anchor_boxes.append(tf.reshape(anchors, [-1, 4])) flattened_anchor_boxes.append(tf.reshape(anchors, [-1, 4]))
flattened_anchor_boxes = tf.concat(flattened_anchor_boxes, axis=0) flattened_anchor_boxes = tf.concat(flattened_anchor_boxes, axis=0)
similarity_matrix = self.similarity_calc(gt_boxes, flattened_anchor_boxes) similarity_matrix = self.similarity_calc(flattened_anchor_boxes, gt_boxes)
match_results = self.matcher(similarity_matrix) match_indices, match_indicators = self.matcher(similarity_matrix)
cls_targets, box_targets, cls_weights, box_weights = self.anchor_labeler( mask = tf.less_equal(match_indicators, 0)
gt_boxes, gt_labels, match_results) cls_mask = tf.expand_dims(mask, -1)
cls_targets = self.anchor_labeler(gt_labels, match_indices, cls_mask, -1)
box_mask = tf.tile(cls_mask, [1, 4])
box_targets = self.anchor_labeler(gt_boxes, match_indices, box_mask)
weights = tf.squeeze(tf.ones_like(gt_labels, dtype=tf.float32), -1)
box_weights = self.anchor_labeler(weights, match_indices, mask)
ignore_mask = tf.equal(match_indicators, -2)
cls_weights = self.anchor_labeler(weights, match_indices, ignore_mask)
box_targets_list = box_list.BoxList(box_targets) box_targets_list = box_list.BoxList(box_targets)
anchor_box_list = box_list.BoxList(flattened_anchor_boxes) anchor_box_list = box_list.BoxList(flattened_anchor_boxes)
box_targets = self.box_coder.encode(box_targets_list, anchor_box_list) box_targets = self.box_coder.encode(box_targets_list, anchor_box_list)
...@@ -268,19 +275,19 @@ class RpnAnchorLabeler(AnchorLabeler): ...@@ -268,19 +275,19 @@ class RpnAnchorLabeler(AnchorLabeler):
for anchors in anchor_boxes.values(): for anchors in anchor_boxes.values():
flattened_anchor_boxes.append(tf.reshape(anchors, [-1, 4])) flattened_anchor_boxes.append(tf.reshape(anchors, [-1, 4]))
flattened_anchor_boxes = tf.concat(flattened_anchor_boxes, axis=0) flattened_anchor_boxes = tf.concat(flattened_anchor_boxes, axis=0)
similarity_matrix = self.similarity_calc(gt_boxes, flattened_anchor_boxes) similarity_matrix = self.similarity_calc(flattened_anchor_boxes, gt_boxes)
match_results = self.matcher(similarity_matrix) match_indices, match_indicators = self.matcher(similarity_matrix)
# cls_targets, cls_weights, box_weights are not used. box_mask = tf.tile(tf.expand_dims(tf.less_equal(match_indicators, 0), -1),
_, box_targets, _, _ = self.anchor_labeler( [1, 4])
gt_boxes, gt_labels, match_results) box_targets = self.anchor_labeler(gt_boxes, match_indices, box_mask)
box_targets_list = box_list.BoxList(box_targets) box_targets_list = box_list.BoxList(box_targets)
anchor_box_list = box_list.BoxList(flattened_anchor_boxes) anchor_box_list = box_list.BoxList(flattened_anchor_boxes)
box_targets = self.box_coder.encode(box_targets_list, anchor_box_list) box_targets = self.box_coder.encode(box_targets_list, anchor_box_list)
# Zero out the unmatched and ignored regression targets. # Zero out the unmatched and ignored regression targets.
num_matches = match_results.shape.as_list()[0] or tf.shape(match_results)[0] num_matches = match_indices.shape.as_list()[0] or tf.shape(match_indices)[0]
unmatched_ignored_box_targets = tf.zeros([num_matches, 4], dtype=tf.float32) unmatched_ignored_box_targets = tf.zeros([num_matches, 4], dtype=tf.float32)
matched_anchors_mask = tf.greater_equal(match_results, 0) matched_anchors_mask = tf.greater_equal(match_indicators, 0)
# To broadcast matched_anchors_mask to the same shape as # To broadcast matched_anchors_mask to the same shape as
# matched_reg_targets. # matched_reg_targets.
matched_anchors_mask = tf.tile( matched_anchors_mask = tf.tile(
...@@ -290,7 +297,7 @@ class RpnAnchorLabeler(AnchorLabeler): ...@@ -290,7 +297,7 @@ class RpnAnchorLabeler(AnchorLabeler):
unmatched_ignored_box_targets) unmatched_ignored_box_targets)
# score_targets contains the subsampled positive and negative anchors. # score_targets contains the subsampled positive and negative anchors.
score_targets, _, _ = self._get_rpn_samples(match_results) score_targets, _, _ = self._get_rpn_samples(match_indicators)
# Unpacks labels. # Unpacks labels.
score_targets_dict = unpack_targets(score_targets, anchor_boxes) score_targets_dict = unpack_targets(score_targets, anchor_boxes)
......
...@@ -20,55 +20,20 @@ import tensorflow as tf ...@@ -20,55 +20,20 @@ import tensorflow as tf
class AnchorLabeler: class AnchorLabeler:
"""Labeler for dense object detector.""" """Labeler for dense object detector."""
def __init__( def __call__(self, labels, match_indices, mask, mask_val=0.0):
self,
positive_class_weight=1.0,
positive_regression_weight=1.0,
negative_class_weight=1.0,
negative_regression_weight=0.0,
negative_class_label=-1,
ignore_class_label=-2,
negative_regression_label=0.,
ignore_regression_label=0.):
"""Constructs Anchor Labeler.
Args:
positive_class_weight: classification weight to be associated to positive
matched anchor. Defaults to 1.0.
positive_regression_weight: regression weight to be associated to positive
matched anchor. Defaults to 1.0.
negative_class_weight: classification weight to be associated to negative
matched anchor. Default to 1.0
negative_regression_weight: classification weight to be associated to
negative matched anchor. Default to 0.0.
negative_class_label: An integer for classification label to be associated
for negative matched anchor. Defaults to -1.
ignore_class_label: An integer for classification label to be associated
for ignored anchor. Defaults to -2.
negative_regression_label: A float for regression label to be associated
for negative matched anchor. Defaults to 0.
ignore_regression_label: A float for regression label to be associated
for ignored anchor. Defaults to 0.
"""
self.positive_class_weight = positive_class_weight
self.positive_regression_weight = positive_regression_weight
self.negative_class_weight = negative_class_weight
self.negative_regression_weight = negative_regression_weight
self.negative_class_label = negative_class_label
self.ignore_class_label = ignore_class_label
self.negative_regression_label = negative_regression_label
self.ignore_regression_label = ignore_regression_label
def __call__(self, boxes, labels, matches):
"""Labels anchors with ground truth inputs. """Labels anchors with ground truth inputs.
B: batch_size
N: number of groundtruth boxes.
Args: Args:
boxes: A float tensor with shape [N, 4] representing groundtruth boxes. labels: An integer tensor with shape [N, 1] or [B, N, 1] representing
For each row, it stores [y0, x0, y1, x1] for four corners of a box. groundtruth labels.
labels: An integer tensor with shape [N, 1] representing groundtruth match_indices: An integer tensor with shape [N] or [B, N] representing
classes. match label index.
matches: An integer tensor with shape [N] representing match results, must mask: An integer tensor with shape [N] or [B, N] representing match
be -1 for negative matched anchor, and -2 for ignored anchor. labels, e.g., 1 for positive, -1 for negative, -2 for ignore.
mask_val: An integer to fill in for mask.
Returns: Returns:
class_targets: A integer Tensor with shape [num_anchors]. class_targets: A integer Tensor with shape [num_anchors].
...@@ -82,65 +47,43 @@ class AnchorLabeler: ...@@ -82,65 +47,43 @@ class AnchorLabeler:
1.0 for positive matched anchors, and 0.0 for negative and ignored 1.0 for positive matched anchors, and 0.0 for negative and ignored
anchors. anchors.
""" """
if len(labels.shape) <= 2:
class_targets = self._gather_based_on_match( return self._gather_unbatched(labels, match_indices, mask, mask_val)
matches, tf.cast(labels, tf.int32), elif len(labels.shape) == 3:
negative_value=tf.constant([self.negative_class_label], tf.int32), return self._gather_batched(labels, match_indices, mask, mask_val)
ignored_value=tf.constant([self.ignore_class_label], tf.int32))
def _gather_unbatched(self, labels, match_indices, mask, mask_val):
negative_reg_value = tf.constant( """Gather based on unbatched labels and boxes."""
[self.negative_regression_label] * 4, dtype=tf.float32) num_gt_boxes = tf.shape(labels)[0]
ignore_reg_value = tf.constant( masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like(
[self.ignore_regression_label] * 4, dtype=tf.float32) mask, dtype=labels.dtype)
reg_targets = self._gather_based_on_match(
matches, boxes, negative_reg_value, ignore_reg_value) def _assign_when_rows_empty():
return masked_targets
num_gt_boxes = boxes.shape.as_list()[0] or tf.shape(boxes)[0]
def _assign_when_rows_not_empty():
groundtruth_class_weights = self.positive_class_weight * tf.ones( targets = tf.gather(labels, match_indices)
[num_gt_boxes], dtype=tf.float32) return tf.where(mask, masked_targets, targets)
class_weights = self._gather_based_on_match(
matches, groundtruth_class_weights, return tf.cond(tf.greater(num_gt_boxes, 0),
negative_value=self.negative_class_weight, _assign_when_rows_not_empty,
ignored_value=0.) _assign_when_rows_empty)
groundtruth_reg_weights = self.positive_regression_weight * tf.ones( def _gather_batched(self, labels, match_indices, mask, mask_val):
[num_gt_boxes], dtype=tf.float32) """Gather based on batched labels."""
reg_weights = self._gather_based_on_match( batch_size = labels.shape[0]
matches, groundtruth_reg_weights, if batch_size == 1:
negative_value=self.negative_regression_weight, ignored_value=0.) result = self._gather_unbatched(
tf.squeeze(labels, axis=0), tf.squeeze(match_indices, axis=0),
return class_targets, reg_targets, class_weights, reg_weights tf.squeeze(mask, axis=0), mask_val)
return tf.expand_dims(result, axis=0)
def _gather_based_on_match( else:
self, matches, inputs, negative_value, ignored_value): indices_shape = tf.shape(match_indices)
"""Gathers elements from `input_tensor` based on match results. indices_dtype = match_indices.dtype
batch_indices = (tf.expand_dims(
For columns that are matched to a row, gathered_tensor[col] is set to tf.range(indices_shape[0], dtype=indices_dtype), axis=-1) *
input_tensor[match[col]]. For columns that are unmatched, tf.ones([1, indices_shape[-1]], dtype=indices_dtype))
gathered_tensor[col] is set to negative_value. Finally, for columns that gather_nd_indices = tf.stack(
are ignored gathered_tensor[col] is set to ignored_value. [batch_indices, match_indices], axis=-1)
targets = tf.gather_nd(labels, gather_nd_indices)
Note that the input_tensor.shape[1:] must match with unmatched_value.shape return targets
and ignored_value.shape
Args:
matches: A integer tensor with shape [N] representing the
matching results of anchors. (1) match_results[i]>=0,
meaning that column i is matched with row match_results[i].
(2) match_results[i]=-1, meaning that column i is not matched.
(3) match_results[i]=-2, meaning that column i is ignored.
inputs: Tensor to gather values from.
negative_value: Constant tensor value for unmatched columns.
ignored_value: Constant tensor value for ignored columns.
Returns:
gathered_tensor: A tensor containing values gathered from input_tensor.
The shape of the gathered tensor is [match.shape[0]] +
input_tensor.shape[1:].
"""
inputs = tf.concat(
[tf.stack([ignored_value, negative_value]), inputs], axis=0)
gather_indices = tf.maximum(matches + 2, 0)
gathered_tensor = tf.gather(inputs, gather_indices)
return gathered_tensor
...@@ -40,7 +40,9 @@ class BoxMatcher: ...@@ -40,7 +40,9 @@ class BoxMatcher:
self, self,
positive_threshold, positive_threshold,
negative_threshold=None, negative_threshold=None,
force_match_for_each_row=False, force_match_for_each_col=False,
negative_lower_than_ignore=True,
positive_value=1,
negative_value=-1, negative_value=-1,
ignore_value=-2): ignore_value=-2):
"""Construct BoxMatcher. """Construct BoxMatcher.
...@@ -53,11 +55,15 @@ class BoxMatcher: ...@@ -53,11 +55,15 @@ class BoxMatcher:
sim < negative_threshold or sim < negative_threshold or
positive_threshold > sim >= negative_threshold. positive_threshold > sim >= negative_threshold.
Defaults to positive_threshold when set to None. Defaults to positive_threshold when set to None.
force_match_for_each_row: If True, ensures that each row is matched to force_match_for_each_col: If True, ensures that each column is matched to
at least one column (which is not guaranteed otherwise if the at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False. positive_threshold is high). Defaults to False.
negative_value: An integer to fill for negative matches. negative_lower_than_ignore: If True, the threshold is
ignore_value: An integer to fill for ignored matches. positive|ignore|negative, else positive|negative|ignore. Defaults to
True.
positive_value: An integer to fill for positive match labels.
negative_value: An integer to fill for negative match labels.
ignore_value: An integer to fill for ignored match labels.
Raises: Raises:
ValueError: If negative_threshold > positive_threshold. ValueError: If negative_threshold > positive_threshold.
...@@ -71,9 +77,11 @@ class BoxMatcher: ...@@ -71,9 +77,11 @@ class BoxMatcher:
'to positive_threshold') 'to positive_threshold')
self._negative_threshold = negative_threshold self._negative_threshold = negative_threshold
self._positive_value = positive_value
self._negative_value = negative_value self._negative_value = negative_value
self._ignore_value = ignore_value self._ignore_value = ignore_value
self._force_match_for_each_row = force_match_for_each_row self._force_match_for_each_col = force_match_for_each_col
self._negative_lower_than_ignore = negative_lower_than_ignore
def __call__(self, similarity_matrix): def __call__(self, similarity_matrix):
"""Tries to match each column of the similarity matrix to a row. """Tries to match each column of the similarity matrix to a row.
...@@ -83,11 +91,20 @@ class BoxMatcher: ...@@ -83,11 +91,20 @@ class BoxMatcher:
similarity metric. similarity metric.
Returns: Returns:
A integer tensor with corresponding match indices for each of M columns, A integer tensor of shape [N] with corresponding match indices for each
for positive match, the match result will be the corresponding row index, of M columns, for positive match, the match result will be the
for negative match, the match will be `negative_value`, for ignored match, corresponding row index, for negative match, the match will be
the match result will be `ignore_value`. `negative_value`, for ignored match, the match result will be
`ignore_value`.
""" """
squeeze_result = False
if len(similarity_matrix.shape) == 2:
squeeze_result = True
similarity_matrix = tf.expand_dims(similarity_matrix, axis=0)
static_shape = similarity_matrix.shape.as_list()
num_rows = static_shape[1] or tf.shape(similarity_matrix)[1]
batch_size = static_shape[0] or tf.shape(similarity_matrix)[0]
def _match_when_rows_are_empty(): def _match_when_rows_are_empty():
"""Performs matching when the rows of similarity matrix are empty. """Performs matching when the rows of similarity matrix are empty.
...@@ -98,9 +115,11 @@ class BoxMatcher: ...@@ -98,9 +115,11 @@ class BoxMatcher:
Returns: Returns:
matches: int32 tensor indicating the row each column matches to. matches: int32 tensor indicating the row each column matches to.
""" """
static_shape = similarity_matrix.shape.as_list() with tf.name_scope('empty_gt_boxes'):
num_cols = static_shape[1] or tf.shape(similarity_matrix)[1] matches = tf.zeros([batch_size, num_rows], dtype=tf.int32)
return -1 * tf.ones([num_cols], dtype=tf.int32) match_labels = self._negative_value * tf.ones(
[batch_size, num_rows], dtype=tf.int32)
return matches, match_labels
def _match_when_rows_are_non_empty(): def _match_when_rows_are_non_empty():
"""Performs matching when the rows of similarity matrix are non empty. """Performs matching when the rows of similarity matrix are non empty.
...@@ -109,50 +128,75 @@ class BoxMatcher: ...@@ -109,50 +128,75 @@ class BoxMatcher:
matches: int32 tensor indicating the row each column matches to. matches: int32 tensor indicating the row each column matches to.
""" """
# Matches for each column # Matches for each column
matches = tf.argmax(input=similarity_matrix, axis=0, output_type=tf.int32) with tf.name_scope('non_empty_gt_boxes'):
matches = tf.argmax(similarity_matrix, axis=-1, output_type=tf.int32)
# Deal with matched and unmatched threshold
if self._positive_threshold is not None:
# Get logical indices of ignored and unmatched columns as tf.int64 # Get logical indices of ignored and unmatched columns as tf.int64
matched_vals = tf.reduce_max(similarity_matrix, axis=0) matched_vals = tf.reduce_max(similarity_matrix, axis=-1)
below_negative_threshold = tf.greater(self._negative_threshold, matched_labels = self._positive_value * tf.ones(
matched_vals) [batch_size, num_rows], tf.int32)
positive_threshold = tf.cast(
self._positive_threshold, matched_vals.dtype)
negative_threshold = tf.cast(
self._negative_threshold, matched_vals.dtype)
below_negative_threshold = tf.greater(negative_threshold, matched_vals)
between_thresholds = tf.logical_and( between_thresholds = tf.logical_and(
tf.greater_equal(matched_vals, self._negative_threshold), tf.greater_equal(matched_vals, negative_threshold),
tf.greater(self._positive_threshold, matched_vals)) tf.greater(positive_threshold, matched_vals))
matches = self._set_values_using_indicator(matches, if self._negative_lower_than_ignore:
below_negative_threshold, matched_labels = self._set_values_using_indicator(
self._negative_value) matched_labels, below_negative_threshold, self._negative_value)
matches = self._set_values_using_indicator(matches, matched_labels = self._set_values_using_indicator(
between_thresholds, matched_labels, between_thresholds, self._ignore_value)
self._ignore_value) else:
matched_labels = self._set_values_using_indicator(
if self._force_match_for_each_row: matched_labels, below_negative_threshold, self._ignore_value)
num_gt_boxes = similarity_matrix.shape.as_list()[1] or tf.shape( matched_labels = self._set_values_using_indicator(
similarity_matrix)[1] matched_labels, between_thresholds, self._negative_value)
if self._force_match_for_each_col:
# [batch_size, M], for each col (groundtruth_box), find the best
# matching row (anchor).
force_match_column_ids = tf.argmax( force_match_column_ids = tf.argmax(
input=similarity_matrix, axis=1, output_type=tf.int32) input=similarity_matrix, axis=1, output_type=tf.int32)
# [batch_size, M, N]
force_match_column_indicators = tf.one_hot( force_match_column_indicators = tf.one_hot(
force_match_column_ids, depth=num_gt_boxes) force_match_column_ids, depth=num_rows)
# [batch_size, N], for each row (anchor), find the largest column
# index for groundtruth box
force_match_row_ids = tf.argmax( force_match_row_ids = tf.argmax(
input=force_match_column_indicators, axis=0, output_type=tf.int32) input=force_match_column_indicators, axis=1, output_type=tf.int32)
# [batch_size, N]
force_match_column_mask = tf.cast( force_match_column_mask = tf.cast(
tf.reduce_max(force_match_column_indicators, axis=0), tf.reduce_max(force_match_column_indicators, axis=1),
tf.bool) tf.bool)
# [batch_size, N]
final_matches = tf.where(force_match_column_mask, force_match_row_ids, final_matches = tf.where(force_match_column_mask, force_match_row_ids,
matches) matches)
return final_matches final_matched_labels = tf.where(
force_match_column_mask,
self._positive_value * tf.ones(
[batch_size, num_rows], dtype=tf.int32),
matched_labels)
return final_matches, final_matched_labels
else: else:
return matches return matches, matched_labels
num_gt_boxes = similarity_matrix.shape.as_list()[0] or tf.shape( num_gt_boxes = similarity_matrix.shape.as_list()[-1] or tf.shape(
similarity_matrix)[0] similarity_matrix)[-1]
return tf.cond( result_match, result_match_labels = tf.cond(
pred=tf.greater(num_gt_boxes, 0), pred=tf.greater(num_gt_boxes, 0),
true_fn=_match_when_rows_are_non_empty, true_fn=_match_when_rows_are_non_empty,
false_fn=_match_when_rows_are_empty) false_fn=_match_when_rows_are_empty)
if squeeze_result:
result_match = tf.squeeze(result_match, axis=0)
result_match_labels = tf.squeeze(result_match_labels, axis=0)
return result_match, result_match_labels
def _set_values_using_indicator(self, x, indicator, val): def _set_values_using_indicator(self, x, indicator, val):
"""Set the indicated fields of x to val. """Set the indicated fields of x to val.
......
...@@ -20,57 +20,78 @@ import tensorflow as tf ...@@ -20,57 +20,78 @@ import tensorflow as tf
def area(box): def area(box):
"""Computes area of boxes. """Computes area of boxes.
B: batch_size
N: number of boxes
Args: Args:
box: a float Tensor with [N, 4]. box: a float Tensor with [N, 4], or [B, N, 4].
Returns: Returns:
a float tensor with [N]. a float Tensor with [N], or [B, N]
""" """
with tf.name_scope('Area'): with tf.name_scope('Area'):
y_min, x_min, y_max, x_max = tf.split( y_min, x_min, y_max, x_max = tf.split(
value=box, num_or_size_splits=4, axis=1) value=box, num_or_size_splits=4, axis=-1)
return tf.squeeze((y_max - y_min) * (x_max - x_min), [1]) return tf.squeeze((y_max - y_min) * (x_max - x_min), axis=-1)
def intersection(box1, box2): def intersection(gt_boxes, boxes):
"""Compute pairwise intersection areas between boxes. """Compute pairwise intersection areas between boxes.
B: batch_size
N: number of groundtruth boxes.
M: number of anchor boxes.
Args: Args:
box1: a float Tensor with [N, 4]. gt_boxes: a float Tensor with [N, 4], or [B, N, 4]
box2: a float Tensor with [M, 4]. boxes: a float Tensor with [M, 4], or [B, M, 4]
Returns: Returns:
a float tensor with shape [N, M] representing pairwise intersections a float Tensor with shape [N, M] or [B, N, M] representing pairwise
intersections.
""" """
with tf.name_scope('Intersection'): with tf.name_scope('Intersection'):
y_min1, x_min1, y_max1, x_max1 = tf.split( y_min1, x_min1, y_max1, x_max1 = tf.split(
value=box1, num_or_size_splits=4, axis=1) value=gt_boxes, num_or_size_splits=4, axis=-1)
y_min2, x_min2, y_max2, x_max2 = tf.split( y_min2, x_min2, y_max2, x_max2 = tf.split(
value=box2, num_or_size_splits=4, axis=1) value=boxes, num_or_size_splits=4, axis=-1)
y_min_max = tf.minimum(y_max1, tf.transpose(a=y_max2))
y_max_min = tf.maximum(y_min1, tf.transpose(a=y_min2)) boxes_rank = len(boxes.shape)
intersect_heights = tf.maximum(0.0, y_min_max - y_max_min) perm = [1, 0] if boxes_rank == 2 else [0, 2, 1]
x_min_max = tf.minimum(x_max1, tf.transpose(a=x_max2)) # [N, M] or [B, N, M]
x_max_min = tf.maximum(x_min1, tf.transpose(a=x_min2)) y_min_max = tf.minimum(y_max1, tf.transpose(y_max2, perm))
intersect_widths = tf.maximum(0.0, x_min_max - x_max_min) y_max_min = tf.maximum(y_min1, tf.transpose(y_min2, perm))
x_min_max = tf.minimum(x_max1, tf.transpose(x_max2, perm))
x_max_min = tf.maximum(x_min1, tf.transpose(x_min2, perm))
intersect_heights = y_min_max - y_max_min
intersect_widths = x_min_max - x_max_min
zeros_t = tf.cast(0, intersect_heights.dtype)
intersect_heights = tf.maximum(zeros_t, intersect_heights)
intersect_widths = tf.maximum(zeros_t, intersect_widths)
return intersect_heights * intersect_widths return intersect_heights * intersect_widths
def iou(box1, box2): def iou(gt_boxes, boxes):
"""Computes pairwise intersection-over-union between box collections. """Computes pairwise intersection-over-union between box collections.
Args: Args:
box1: a float Tensor with [N, 4]. gt_boxes: a float Tensor with [N, 4].
box2: a float Tensor with [M, 4]. boxes: a float Tensor with [M, 4].
Returns: Returns:
a tensor with shape [N, M] representing pairwise iou scores. a Tensor with shape [N, M] representing pairwise iou scores.
""" """
intersections = intersection(box1, box2) with tf.name_scope('IOU'):
areas1 = area(box1) intersections = intersection(gt_boxes, boxes)
areas2 = area(box2) gt_boxes_areas = area(gt_boxes)
unions = ( boxes_areas = area(boxes)
tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections) boxes_rank = len(boxes_areas.shape)
boxes_axis = 1 if (boxes_rank == 2) else 0
gt_boxes_areas = tf.expand_dims(gt_boxes_areas, -1)
boxes_areas = tf.expand_dims(boxes_areas, boxes_axis)
unions = gt_boxes_areas + boxes_areas
unions = unions - intersections
return tf.where( return tf.where(
tf.equal(intersections, 0.0), tf.zeros_like(intersections), tf.equal(intersections, 0.0), tf.zeros_like(intersections),
tf.truediv(intersections, unions)) tf.truediv(intersections, unions))
...@@ -84,19 +105,37 @@ class IouSimilarity(): ...@@ -84,19 +105,37 @@ class IouSimilarity():
def __call__(self, groundtruth_boxes, anchors): def __call__(self, groundtruth_boxes, anchors):
"""Compute pairwise IOU similarity between ground truth boxes and anchors. """Compute pairwise IOU similarity between ground truth boxes and anchors.
B: batch_size
N: Number of groundtruth boxes.
M: Number of anchor boxes.
Args: Args:
groundtruth_boxes: a float Tensor with N boxes. groundtruth_boxes: a float Tensor with M boxes.
anchors: a float Tensor with M boxes. anchors: a float Tensor with N boxes.
Returns: Returns:
A tensor with shape [N, M] representing pairwise iou scores. A Tensor with shape [M, N] or [B, M, N] representing pairwise
iou scores, anchor per row and groundtruth_box per colulmn.
Input shape: Input shape:
groundtruth_boxes: [N, 4] groundtruth_boxes: [N, 4], or [B, N, 4]
anchors: [M, 4] anchors: [M, 4], or [B, M, 4]
Output shape: Output shape:
[N, M] [M, N], or [B, M, N]
""" """
with tf.name_scope('IOU'): groundtruth_rank = len(groundtruth_boxes.shape)
anchor_rank = len(anchors.shape)
if groundtruth_rank < 2 or groundtruth_rank > 3:
raise ValueError('`groudtruth_boxes` must be rank 2 or 3, got {}'.format(
groundtruth_rank))
if anchor_rank < 2 or anchor_rank > 3:
raise ValueError('`anchors` must be rank 2 or 3, got {}'.format(
anchor_rank))
if groundtruth_rank < anchor_rank:
raise ValueError('`groundtruth_boxes` is unbatched while `anchors` is '
'batched is not a valid use case, got groundtruth_box '
'rank {}, and anchors rank {}'.format(
groundtruth_rank, anchor_rank))
return iou(groundtruth_boxes, anchors) return iou(groundtruth_boxes, anchors)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment