"git@developer.sourcefind.cn:wangsen/rocm_bandwidth_test.git" did not exist on "945a8bef5912dade2b8d6ba924c60c12dcc663ed"
Commit f82ade47 authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 335128787
parent 5a533fd4
...@@ -61,6 +61,7 @@ class ROISampler(tf.keras.layers.Layer): ...@@ -61,6 +61,7 @@ class ROISampler(tf.keras.layers.Layer):
'background_iou_low_threshold': background_iou_low_threshold, 'background_iou_low_threshold': background_iou_low_threshold,
} }
self._sim_calc = keras_cv.ops.IouSimilarity()
self._box_matcher = keras_cv.ops.BoxMatcher( self._box_matcher = keras_cv.ops.BoxMatcher(
thresholds=[ thresholds=[
background_iou_low_threshold, background_iou_high_threshold, background_iou_low_threshold, background_iou_high_threshold,
...@@ -114,7 +115,12 @@ class ROISampler(tf.keras.layers.Layer): ...@@ -114,7 +115,12 @@ class ROISampler(tf.keras.layers.Layer):
gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype) gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
boxes = tf.concat([boxes, gt_boxes], axis=1) boxes = tf.concat([boxes, gt_boxes], axis=1)
similarity_matrix = box_ops.bbox_overlap(boxes, gt_boxes) boxes_invalid_mask = tf.less(
tf.reduce_max(boxes, axis=-1, keepdims=True), 0.0)
gt_invalid_mask = tf.less(
tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
similarity_matrix = self._sim_calc(boxes, gt_boxes, boxes_invalid_mask,
gt_invalid_mask)
matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix) matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix)
positive_matches = tf.greater_equal(match_indicators, 0) positive_matches = tf.greater_equal(match_indicators, 0)
negative_matches = tf.equal(match_indicators, -1) negative_matches = tf.equal(match_indicators, -1)
......
...@@ -97,12 +97,15 @@ def iou(gt_boxes, boxes): ...@@ -97,12 +97,15 @@ def iou(gt_boxes, boxes):
tf.truediv(intersections, unions)) tf.truediv(intersections, unions))
class IouSimilarity(): class IouSimilarity:
"""Class to compute similarity based on Intersection over Union (IOU) metric. """Class to compute similarity based on Intersection over Union (IOU) metric.
""" """
def __call__(self, groundtruth_boxes, anchors): def __init__(self, mask_val=-1):
self.mask_val = mask_val
def __call__(self, boxes_1, boxes_2, boxes_1_masks=None, boxes_2_masks=None):
"""Compute pairwise IOU similarity between ground truth boxes and anchors. """Compute pairwise IOU similarity between ground truth boxes and anchors.
B: batch_size B: batch_size
...@@ -110,32 +113,52 @@ class IouSimilarity(): ...@@ -110,32 +113,52 @@ class IouSimilarity():
M: Number of anchor boxes. M: Number of anchor boxes.
Args: Args:
groundtruth_boxes: a float Tensor with M boxes. boxes_1: a float Tensor with M or B * M boxes.
anchors: a float Tensor with N boxes. boxes_2: a float Tensor with N or B * N boxes, the rank must be less than
or equal to rank of `boxes_1`.
boxes_1_masks: a boolean Tensor with M or B * M boxes. Optional.
boxes_2_masks: a boolean Tensor with N or B * N boxes. Optional.
Returns: Returns:
A Tensor with shape [M, N] or [B, M, N] representing pairwise A Tensor with shape [M, N] or [B, M, N] representing pairwise
iou scores, anchor per row and groundtruth_box per colulmn. iou scores, anchor per row and groundtruth_box per colulmn.
Input shape: Input shape:
groundtruth_boxes: [N, 4], or [B, N, 4] boxes_1: [N, 4], or [B, N, 4]
anchors: [M, 4], or [B, M, 4] boxes_2: [M, 4], or [B, M, 4]
boxes_1_masks: [N, 1], or [B, N, 1]
boxes_2_masks: [M, 1], or [B, M, 1]
Output shape: Output shape:
[M, N], or [B, M, N] [M, N], or [B, M, N]
""" """
groundtruth_rank = len(groundtruth_boxes.shape) boxes_1_rank = len(boxes_1.shape)
anchor_rank = len(anchors.shape) boxes_2_rank = len(boxes_2.shape)
if groundtruth_rank < 2 or groundtruth_rank > 3: if boxes_1_rank < 2 or boxes_1_rank > 3:
raise ValueError('`groudtruth_boxes` must be rank 2 or 3, got {}'.format( raise ValueError(
groundtruth_rank)) '`groudtruth_boxes` must be rank 2 or 3, got {}'.format(boxes_1_rank))
if anchor_rank < 2 or anchor_rank > 3: if boxes_2_rank < 2 or boxes_2_rank > 3:
raise ValueError('`anchors` must be rank 2 or 3, got {}'.format( raise ValueError(
anchor_rank)) '`anchors` must be rank 2 or 3, got {}'.format(boxes_2_rank))
if groundtruth_rank < anchor_rank: if boxes_1_rank < boxes_2_rank:
raise ValueError('`groundtruth_boxes` is unbatched while `anchors` is ' raise ValueError('`groundtruth_boxes` is unbatched while `anchors` is '
'batched is not a valid use case, got groundtruth_box ' 'batched is not a valid use case, got groundtruth_box '
'rank {}, and anchors rank {}'.format( 'rank {}, and anchors rank {}'.format(
groundtruth_rank, anchor_rank)) boxes_1_rank, boxes_2_rank))
return iou(groundtruth_boxes, anchors) result = iou(boxes_1, boxes_2)
if boxes_1_masks is None and boxes_2_masks is None:
return result
background_mask = None
mask_val_t = tf.cast(self.mask_val, result.dtype) * tf.ones_like(result)
perm = [1, 0] if boxes_2_rank == 2 else [0, 2, 1]
if boxes_1_masks is not None and boxes_2_masks is not None:
background_mask = tf.logical_or(boxes_1_masks,
tf.transpose(boxes_2_masks, perm))
elif boxes_1_masks is not None:
background_mask = boxes_1_masks
else:
background_mask = tf.logical_or(
tf.zeros(tf.shape(boxes_2)[:-1], dtype=tf.bool),
tf.transpose(boxes_2_masks, perm))
return tf.where(background_mask, mask_val_t, result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment