Commit f82ade47 authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 335128787
parent 5a533fd4
......@@ -61,6 +61,7 @@ class ROISampler(tf.keras.layers.Layer):
'background_iou_low_threshold': background_iou_low_threshold,
}
self._sim_calc = keras_cv.ops.IouSimilarity()
self._box_matcher = keras_cv.ops.BoxMatcher(
thresholds=[
background_iou_low_threshold, background_iou_high_threshold,
......@@ -114,7 +115,12 @@ class ROISampler(tf.keras.layers.Layer):
gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
boxes = tf.concat([boxes, gt_boxes], axis=1)
similarity_matrix = box_ops.bbox_overlap(boxes, gt_boxes)
boxes_invalid_mask = tf.less(
tf.reduce_max(boxes, axis=-1, keepdims=True), 0.0)
gt_invalid_mask = tf.less(
tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
similarity_matrix = self._sim_calc(boxes, gt_boxes, boxes_invalid_mask,
gt_invalid_mask)
matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix)
positive_matches = tf.greater_equal(match_indicators, 0)
negative_matches = tf.equal(match_indicators, -1)
......
......@@ -97,12 +97,15 @@ def iou(gt_boxes, boxes):
tf.truediv(intersections, unions))
class IouSimilarity():
class IouSimilarity:
"""Class to compute similarity based on Intersection over Union (IOU) metric.
"""
def __call__(self, groundtruth_boxes, anchors):
def __init__(self, mask_val=-1):
self.mask_val = mask_val
def __call__(self, boxes_1, boxes_2, boxes_1_masks=None, boxes_2_masks=None):
"""Compute pairwise IOU similarity between ground truth boxes and anchors.
B: batch_size
......@@ -110,32 +113,52 @@ class IouSimilarity():
M: Number of anchor boxes.
Args:
groundtruth_boxes: a float Tensor with M boxes.
anchors: a float Tensor with N boxes.
boxes_1: a float Tensor with M or B * M boxes.
boxes_2: a float Tensor with N or B * N boxes, the rank must be less than
or equal to rank of `boxes_1`.
boxes_1_masks: a boolean Tensor with M or B * M boxes. Optional.
boxes_2_masks: a boolean Tensor with N or B * N boxes. Optional.
Returns:
A Tensor with shape [M, N] or [B, M, N] representing pairwise
iou scores, anchor per row and groundtruth_box per colulmn.
Input shape:
groundtruth_boxes: [N, 4], or [B, N, 4]
anchors: [M, 4], or [B, M, 4]
boxes_1: [N, 4], or [B, N, 4]
boxes_2: [M, 4], or [B, M, 4]
boxes_1_masks: [N, 1], or [B, N, 1]
boxes_2_masks: [M, 1], or [B, M, 1]
Output shape:
[M, N], or [B, M, N]
"""
groundtruth_rank = len(groundtruth_boxes.shape)
anchor_rank = len(anchors.shape)
if groundtruth_rank < 2 or groundtruth_rank > 3:
raise ValueError('`groudtruth_boxes` must be rank 2 or 3, got {}'.format(
groundtruth_rank))
if anchor_rank < 2 or anchor_rank > 3:
raise ValueError('`anchors` must be rank 2 or 3, got {}'.format(
anchor_rank))
if groundtruth_rank < anchor_rank:
boxes_1_rank = len(boxes_1.shape)
boxes_2_rank = len(boxes_2.shape)
if boxes_1_rank < 2 or boxes_1_rank > 3:
raise ValueError(
'`groudtruth_boxes` must be rank 2 or 3, got {}'.format(boxes_1_rank))
if boxes_2_rank < 2 or boxes_2_rank > 3:
raise ValueError(
'`anchors` must be rank 2 or 3, got {}'.format(boxes_2_rank))
if boxes_1_rank < boxes_2_rank:
raise ValueError('`groundtruth_boxes` is unbatched while `anchors` is '
'batched is not a valid use case, got groundtruth_box '
'rank {}, and anchors rank {}'.format(
groundtruth_rank, anchor_rank))
return iou(groundtruth_boxes, anchors)
boxes_1_rank, boxes_2_rank))
result = iou(boxes_1, boxes_2)
if boxes_1_masks is None and boxes_2_masks is None:
return result
background_mask = None
mask_val_t = tf.cast(self.mask_val, result.dtype) * tf.ones_like(result)
perm = [1, 0] if boxes_2_rank == 2 else [0, 2, 1]
if boxes_1_masks is not None and boxes_2_masks is not None:
background_mask = tf.logical_or(boxes_1_masks,
tf.transpose(boxes_2_masks, perm))
elif boxes_1_masks is not None:
background_mask = boxes_1_masks
else:
background_mask = tf.logical_or(
tf.zeros(tf.shape(boxes_2)[:-1], dtype=tf.bool),
tf.transpose(boxes_2_masks, perm))
return tf.where(background_mask, mask_val_t, result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment