Commit 95a0f9ac authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 335079988
parent 32c79ea6
......@@ -17,7 +17,7 @@
# Import libraries
import tensorflow as tf
from official.vision.beta.modeling.layers import box_matcher
from official.vision import keras_cv
from official.vision.beta.modeling.layers import box_sampler
from official.vision.beta.ops import box_ops
......@@ -60,10 +60,15 @@ class ROISampler(tf.keras.layers.Layer):
'background_iou_high_threshold': background_iou_high_threshold,
'background_iou_low_threshold': background_iou_low_threshold,
}
self._matcher = box_matcher.BoxMatcher(
foreground_iou_threshold,
background_iou_high_threshold,
background_iou_low_threshold)
self._box_matcher = keras_cv.ops.BoxMatcher(
thresholds=[
background_iou_low_threshold, background_iou_high_threshold,
foreground_iou_threshold
],
indicators=[-3, -1, -2, 1])
self._anchor_labeler = keras_cv.ops.AnchorLabeler()
self._sampler = box_sampler.BoxSampler(
num_sampled_rois, foreground_fraction)
super(ROISampler, self).__init__(**kwargs)
......@@ -109,9 +114,30 @@ class ROISampler(tf.keras.layers.Layer):
gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
boxes = tf.concat([boxes, gt_boxes], axis=1)
(matched_gt_boxes, matched_gt_classes, matched_gt_indices,
positive_matches, negative_matches, ignored_matches) = (
self._matcher(boxes, gt_boxes, gt_classes))
similarity_matrix = box_ops.bbox_overlap(boxes, gt_boxes)
matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix)
positive_matches = tf.greater_equal(match_indicators, 0)
negative_matches = tf.equal(match_indicators, -1)
ignored_matches = tf.equal(match_indicators, -2)
invalid_matches = tf.equal(match_indicators, -3)
background_mask = tf.expand_dims(
tf.logical_or(negative_matches, invalid_matches), -1)
gt_classes = tf.expand_dims(gt_classes, axis=-1)
matched_gt_classes = self._anchor_labeler(gt_classes, matched_gt_indices,
background_mask)
matched_gt_classes = tf.where(background_mask,
tf.zeros_like(matched_gt_classes),
matched_gt_classes)
matched_gt_classes = tf.squeeze(matched_gt_classes, axis=-1)
matched_gt_boxes = self._anchor_labeler(gt_boxes, matched_gt_indices,
tf.tile(background_mask, [1, 1, 4]))
matched_gt_boxes = tf.where(background_mask,
tf.zeros_like(matched_gt_boxes),
matched_gt_boxes)
matched_gt_indices = tf.where(
tf.squeeze(background_mask, -1), -tf.ones_like(matched_gt_indices),
matched_gt_indices)
sampled_indices = self._sampler(
positive_matches, negative_matches, ignored_matches)
......
......@@ -135,8 +135,8 @@ class AnchorLabeler(object):
self.similarity_calc = keras_cv.ops.IouSimilarity()
self.anchor_labeler = keras_cv.ops.AnchorLabeler()
self.matcher = keras_cv.ops.BoxMatcher(
positive_threshold=match_threshold,
negative_threshold=unmatched_threshold,
thresholds=[unmatched_threshold, match_threshold],
indicators=[-1, -2, 1],
force_match_for_each_col=True)
self.box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
......
......@@ -28,60 +28,51 @@ class BoxMatcher:
To support object detection target assignment this class enables setting both
positive_threshold (upper threshold) and negative_threshold (lower thresholds)
defining three categories of similarity which define whether examples are
positive, negative, or ignored:
(1) similarity >= positive_threshold: Highest similarity. Matched/Positive!
(2) positive_threshold > similarity >= negative_threshold: Medium similarity.
This is Ignored.
(3) negative_threshold > similarity: Lowest similarity for Negative Match.
For ignored matches this class sets the values in the Match object to -2.
positive, negative, or ignored, for example:
(1) thresholds=[negative_threshold, positive_threshold], and
indicators=[negative_value, ignore_value, positive_value]: The similarity
metrics below negative_threshold will be assigned with negative_value,
the metrics between negative_threshold and positive_threshold will be
assigned ignore_value, and the metrics above positive_threshold will be
assigned positive_value.
(2) thresholds=[negative_threshold, positive_threshold], and
indicators=[ignore_value, negative_value, positive_value]: The similarity
metric below negative_threshold will be assigned with ignore_value,
the metrics between negative_threshold and positive_threshold will be
assigned negative_value, and the metrics above positive_threshold will be
assigned positive_value.
"""
def __init__(
self,
positive_threshold,
negative_threshold=None,
force_match_for_each_col=False,
negative_lower_than_ignore=True,
positive_value=1,
negative_value=-1,
ignore_value=-2):
def __init__(self, thresholds, indicators, force_match_for_each_col=False):
"""Construct BoxMatcher.
Args:
positive_threshold: Threshold for positive matches. Positive if
sim >= positive_threshold, where sim is the maximum value of the
similarity matrix for a given column. Set to None for no threshold.
negative_threshold: Threshold for negative matches. Negative if
sim < negative_threshold or
positive_threshold > sim >= negative_threshold.
Defaults to positive_threshold when set to None.
thresholds: A list of thresholds to classify boxes into
different buckets. The list needs to be sorted, and will be prepended
with -Inf and appended with +Inf.
indicators: A list of values to assign for each bucket. len(`indicators`)
must equal to len(`thresholds`) + 1.
force_match_for_each_col: If True, ensures that each column is matched to
at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False.
negative_lower_than_ignore: If True, the threshold is
positive|ignore|negative, else positive|negative|ignore. Defaults to
True.
positive_value: An integer to fill for positive match labels.
negative_value: An integer to fill for negative match labels.
ignore_value: An integer to fill for ignored match labels.
positive_threshold is high). Defaults to False. If True, all force
matched row will be assigned to `indicators[-1]`.
Raises:
ValueError: If negative_threshold > positive_threshold.
ValueError: If `threshold` not sorted,
or len(indicators) != len(threshold) + 1
"""
self._positive_threshold = positive_threshold
if negative_threshold is None:
self._negative_threshold = positive_threshold
else:
if negative_threshold > positive_threshold:
raise ValueError('negative_threshold needs to be smaller or equal'
'to positive_threshold')
self._negative_threshold = negative_threshold
self._positive_value = positive_value
self._negative_value = negative_value
self._ignore_value = ignore_value
if not all([lo <= hi for (lo, hi) in zip(thresholds[:-1], thresholds[1:])]):
raise ValueError('`threshold` must be sorted, got {}'.format(thresholds))
self.indicators = indicators
if len(indicators) != len(thresholds) + 1:
raise ValueError('len(`indicators`) must be len(`thresholds`) + 1, got '
'indicators {}, thresholds {}'.format(
indicators, thresholds))
thresholds = thresholds[:]
thresholds.insert(0, -float('inf'))
thresholds.append(float('inf'))
self.thresholds = thresholds
self._force_match_for_each_col = force_match_for_each_col
self._negative_lower_than_ignore = negative_lower_than_ignore
def __call__(self, similarity_matrix):
"""Tries to match each column of the similarity matrix to a row.
......@@ -117,8 +108,7 @@ class BoxMatcher:
"""
with tf.name_scope('empty_gt_boxes'):
matches = tf.zeros([batch_size, num_rows], dtype=tf.int32)
match_labels = self._negative_value * tf.ones(
[batch_size, num_rows], dtype=tf.int32)
match_labels = -tf.ones([batch_size, num_rows], dtype=tf.int32)
return matches, match_labels
def _match_when_rows_are_non_empty():
......@@ -133,28 +123,18 @@ class BoxMatcher:
# Get logical indices of ignored and unmatched columns as tf.int64
matched_vals = tf.reduce_max(similarity_matrix, axis=-1)
matched_labels = self._positive_value * tf.ones(
[batch_size, num_rows], tf.int32)
positive_threshold = tf.cast(
self._positive_threshold, matched_vals.dtype)
negative_threshold = tf.cast(
self._negative_threshold, matched_vals.dtype)
below_negative_threshold = tf.greater(negative_threshold, matched_vals)
between_thresholds = tf.logical_and(
tf.greater_equal(matched_vals, negative_threshold),
tf.greater(positive_threshold, matched_vals))
if self._negative_lower_than_ignore:
matched_labels = self._set_values_using_indicator(
matched_labels, below_negative_threshold, self._negative_value)
matched_labels = self._set_values_using_indicator(
matched_labels, between_thresholds, self._ignore_value)
else:
matched_labels = self._set_values_using_indicator(
matched_labels, below_negative_threshold, self._ignore_value)
matched_labels = self._set_values_using_indicator(
matched_labels, between_thresholds, self._negative_value)
matched_indicators = tf.zeros([batch_size, num_rows], tf.int32)
match_dtype = matched_vals.dtype
for (ind, low, high) in zip(self.indicators, self.thresholds[:-1],
self.thresholds[1:]):
low_threshold = tf.cast(low, match_dtype)
high_threshold = tf.cast(high, match_dtype)
mask = tf.logical_and(
tf.greater_equal(matched_vals, low_threshold),
tf.less(matched_vals, high_threshold))
matched_indicators = self._set_values_using_indicator(
matched_indicators, mask, ind)
if self._force_match_for_each_col:
# [batch_size, M], for each col (groundtruth_box), find the best
......@@ -175,27 +155,26 @@ class BoxMatcher:
# [batch_size, N]
final_matches = tf.where(force_match_column_mask, force_match_row_ids,
matches)
final_matched_labels = tf.where(
force_match_column_mask,
self._positive_value * tf.ones(
[batch_size, num_rows], dtype=tf.int32),
matched_labels)
return final_matches, final_matched_labels
final_matched_indicators = tf.where(
force_match_column_mask, self.indicators[-1] *
tf.ones([batch_size, num_rows], dtype=tf.int32),
matched_indicators)
return final_matches, final_matched_indicators
else:
return matches, matched_labels
return matches, matched_indicators
num_gt_boxes = similarity_matrix.shape.as_list()[-1] or tf.shape(
similarity_matrix)[-1]
result_match, result_match_labels = tf.cond(
result_match, result_matched_indicators = tf.cond(
pred=tf.greater(num_gt_boxes, 0),
true_fn=_match_when_rows_are_non_empty,
false_fn=_match_when_rows_are_empty)
if squeeze_result:
result_match = tf.squeeze(result_match, axis=0)
result_match_labels = tf.squeeze(result_match_labels, axis=0)
result_matched_indicators = tf.squeeze(result_matched_indicators, axis=0)
return result_match, result_match_labels
return result_match, result_matched_indicators
def _set_values_using_indicator(self, x, indicator, val):
"""Set the indicated fields of x to val.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment