"...git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "3472bc29a0548050dccabbd4c81617c953ef900d"
Commit 95a0f9ac authored by Zhenyu Tan's avatar Zhenyu Tan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 335079988
parent 32c79ea6
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling.layers import box_matcher from official.vision import keras_cv
from official.vision.beta.modeling.layers import box_sampler from official.vision.beta.modeling.layers import box_sampler
from official.vision.beta.ops import box_ops from official.vision.beta.ops import box_ops
...@@ -60,10 +60,15 @@ class ROISampler(tf.keras.layers.Layer): ...@@ -60,10 +60,15 @@ class ROISampler(tf.keras.layers.Layer):
'background_iou_high_threshold': background_iou_high_threshold, 'background_iou_high_threshold': background_iou_high_threshold,
'background_iou_low_threshold': background_iou_low_threshold, 'background_iou_low_threshold': background_iou_low_threshold,
} }
self._matcher = box_matcher.BoxMatcher(
foreground_iou_threshold, self._box_matcher = keras_cv.ops.BoxMatcher(
background_iou_high_threshold, thresholds=[
background_iou_low_threshold) background_iou_low_threshold, background_iou_high_threshold,
foreground_iou_threshold
],
indicators=[-3, -1, -2, 1])
self._anchor_labeler = keras_cv.ops.AnchorLabeler()
self._sampler = box_sampler.BoxSampler( self._sampler = box_sampler.BoxSampler(
num_sampled_rois, foreground_fraction) num_sampled_rois, foreground_fraction)
super(ROISampler, self).__init__(**kwargs) super(ROISampler, self).__init__(**kwargs)
...@@ -109,9 +114,30 @@ class ROISampler(tf.keras.layers.Layer): ...@@ -109,9 +114,30 @@ class ROISampler(tf.keras.layers.Layer):
gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype) gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
boxes = tf.concat([boxes, gt_boxes], axis=1) boxes = tf.concat([boxes, gt_boxes], axis=1)
(matched_gt_boxes, matched_gt_classes, matched_gt_indices, similarity_matrix = box_ops.bbox_overlap(boxes, gt_boxes)
positive_matches, negative_matches, ignored_matches) = ( matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix)
self._matcher(boxes, gt_boxes, gt_classes)) positive_matches = tf.greater_equal(match_indicators, 0)
negative_matches = tf.equal(match_indicators, -1)
ignored_matches = tf.equal(match_indicators, -2)
invalid_matches = tf.equal(match_indicators, -3)
background_mask = tf.expand_dims(
tf.logical_or(negative_matches, invalid_matches), -1)
gt_classes = tf.expand_dims(gt_classes, axis=-1)
matched_gt_classes = self._anchor_labeler(gt_classes, matched_gt_indices,
background_mask)
matched_gt_classes = tf.where(background_mask,
tf.zeros_like(matched_gt_classes),
matched_gt_classes)
matched_gt_classes = tf.squeeze(matched_gt_classes, axis=-1)
matched_gt_boxes = self._anchor_labeler(gt_boxes, matched_gt_indices,
tf.tile(background_mask, [1, 1, 4]))
matched_gt_boxes = tf.where(background_mask,
tf.zeros_like(matched_gt_boxes),
matched_gt_boxes)
matched_gt_indices = tf.where(
tf.squeeze(background_mask, -1), -tf.ones_like(matched_gt_indices),
matched_gt_indices)
sampled_indices = self._sampler( sampled_indices = self._sampler(
positive_matches, negative_matches, ignored_matches) positive_matches, negative_matches, ignored_matches)
......
...@@ -135,8 +135,8 @@ class AnchorLabeler(object): ...@@ -135,8 +135,8 @@ class AnchorLabeler(object):
self.similarity_calc = keras_cv.ops.IouSimilarity() self.similarity_calc = keras_cv.ops.IouSimilarity()
self.anchor_labeler = keras_cv.ops.AnchorLabeler() self.anchor_labeler = keras_cv.ops.AnchorLabeler()
self.matcher = keras_cv.ops.BoxMatcher( self.matcher = keras_cv.ops.BoxMatcher(
positive_threshold=match_threshold, thresholds=[unmatched_threshold, match_threshold],
negative_threshold=unmatched_threshold, indicators=[-1, -2, 1],
force_match_for_each_col=True) force_match_for_each_col=True)
self.box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder() self.box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
......
...@@ -28,60 +28,51 @@ class BoxMatcher: ...@@ -28,60 +28,51 @@ class BoxMatcher:
To support object detection target assignment this class enables setting both To support object detection target assignment this class enables setting both
positive_threshold (upper threshold) and negative_threshold (lower thresholds) positive_threshold (upper threshold) and negative_threshold (lower thresholds)
defining three categories of similarity which define whether examples are defining three categories of similarity which define whether examples are
positive, negative, or ignored: positive, negative, or ignored, for example:
(1) similarity >= positive_threshold: Highest similarity. Matched/Positive! (1) thresholds=[negative_threshold, positive_threshold], and
(2) positive_threshold > similarity >= negative_threshold: Medium similarity. indicators=[negative_value, ignore_value, positive_value]: The similarity
This is Ignored. metrics below negative_threshold will be assigned with negative_value,
(3) negative_threshold > similarity: Lowest similarity for Negative Match. the metrics between negative_threshold and positive_threshold will be
For ignored matches this class sets the values in the Match object to -2. assigned ignore_value, and the metrics above positive_threshold will be
assigned positive_value.
(2) thresholds=[negative_threshold, positive_threshold], and
indicators=[ignore_value, negative_value, positive_value]: The similarity
metric below negative_threshold will be assigned with ignore_value,
the metrics between negative_threshold and positive_threshold will be
assigned negative_value, and the metrics above positive_threshold will be
assigned positive_value.
""" """
def __init__( def __init__(self, thresholds, indicators, force_match_for_each_col=False):
self,
positive_threshold,
negative_threshold=None,
force_match_for_each_col=False,
negative_lower_than_ignore=True,
positive_value=1,
negative_value=-1,
ignore_value=-2):
"""Construct BoxMatcher. """Construct BoxMatcher.
Args: Args:
positive_threshold: Threshold for positive matches. Positive if thresholds: A list of thresholds to classify boxes into
sim >= positive_threshold, where sim is the maximum value of the different buckets. The list needs to be sorted, and will be prepended
similarity matrix for a given column. Set to None for no threshold. with -Inf and appended with +Inf.
negative_threshold: Threshold for negative matches. Negative if indicators: A list of values to assign for each bucket. len(`indicators`)
sim < negative_threshold or must equal to len(`thresholds`) + 1.
positive_threshold > sim >= negative_threshold.
Defaults to positive_threshold when set to None.
force_match_for_each_col: If True, ensures that each column is matched to force_match_for_each_col: If True, ensures that each column is matched to
at least one row (which is not guaranteed otherwise if the at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False. positive_threshold is high). Defaults to False. If True, all force
negative_lower_than_ignore: If True, the threshold is matched row will be assigned to `indicators[-1]`.
positive|ignore|negative, else positive|negative|ignore. Defaults to
True.
positive_value: An integer to fill for positive match labels.
negative_value: An integer to fill for negative match labels.
ignore_value: An integer to fill for ignored match labels.
Raises: Raises:
ValueError: If negative_threshold > positive_threshold. ValueError: If `threshold` not sorted,
or len(indicators) != len(threshold) + 1
""" """
self._positive_threshold = positive_threshold if not all([lo <= hi for (lo, hi) in zip(thresholds[:-1], thresholds[1:])]):
if negative_threshold is None: raise ValueError('`threshold` must be sorted, got {}'.format(thresholds))
self._negative_threshold = positive_threshold self.indicators = indicators
else: if len(indicators) != len(thresholds) + 1:
if negative_threshold > positive_threshold: raise ValueError('len(`indicators`) must be len(`thresholds`) + 1, got '
raise ValueError('negative_threshold needs to be smaller or equal' 'indicators {}, thresholds {}'.format(
'to positive_threshold') indicators, thresholds))
self._negative_threshold = negative_threshold thresholds = thresholds[:]
thresholds.insert(0, -float('inf'))
self._positive_value = positive_value thresholds.append(float('inf'))
self._negative_value = negative_value self.thresholds = thresholds
self._ignore_value = ignore_value
self._force_match_for_each_col = force_match_for_each_col self._force_match_for_each_col = force_match_for_each_col
self._negative_lower_than_ignore = negative_lower_than_ignore
def __call__(self, similarity_matrix): def __call__(self, similarity_matrix):
"""Tries to match each column of the similarity matrix to a row. """Tries to match each column of the similarity matrix to a row.
...@@ -117,8 +108,7 @@ class BoxMatcher: ...@@ -117,8 +108,7 @@ class BoxMatcher:
""" """
with tf.name_scope('empty_gt_boxes'): with tf.name_scope('empty_gt_boxes'):
matches = tf.zeros([batch_size, num_rows], dtype=tf.int32) matches = tf.zeros([batch_size, num_rows], dtype=tf.int32)
match_labels = self._negative_value * tf.ones( match_labels = -tf.ones([batch_size, num_rows], dtype=tf.int32)
[batch_size, num_rows], dtype=tf.int32)
return matches, match_labels return matches, match_labels
def _match_when_rows_are_non_empty(): def _match_when_rows_are_non_empty():
...@@ -133,28 +123,18 @@ class BoxMatcher: ...@@ -133,28 +123,18 @@ class BoxMatcher:
# Get logical indices of ignored and unmatched columns as tf.int64 # Get logical indices of ignored and unmatched columns as tf.int64
matched_vals = tf.reduce_max(similarity_matrix, axis=-1) matched_vals = tf.reduce_max(similarity_matrix, axis=-1)
matched_labels = self._positive_value * tf.ones( matched_indicators = tf.zeros([batch_size, num_rows], tf.int32)
[batch_size, num_rows], tf.int32)
match_dtype = matched_vals.dtype
positive_threshold = tf.cast( for (ind, low, high) in zip(self.indicators, self.thresholds[:-1],
self._positive_threshold, matched_vals.dtype) self.thresholds[1:]):
negative_threshold = tf.cast( low_threshold = tf.cast(low, match_dtype)
self._negative_threshold, matched_vals.dtype) high_threshold = tf.cast(high, match_dtype)
below_negative_threshold = tf.greater(negative_threshold, matched_vals) mask = tf.logical_and(
between_thresholds = tf.logical_and( tf.greater_equal(matched_vals, low_threshold),
tf.greater_equal(matched_vals, negative_threshold), tf.less(matched_vals, high_threshold))
tf.greater(positive_threshold, matched_vals)) matched_indicators = self._set_values_using_indicator(
matched_indicators, mask, ind)
if self._negative_lower_than_ignore:
matched_labels = self._set_values_using_indicator(
matched_labels, below_negative_threshold, self._negative_value)
matched_labels = self._set_values_using_indicator(
matched_labels, between_thresholds, self._ignore_value)
else:
matched_labels = self._set_values_using_indicator(
matched_labels, below_negative_threshold, self._ignore_value)
matched_labels = self._set_values_using_indicator(
matched_labels, between_thresholds, self._negative_value)
if self._force_match_for_each_col: if self._force_match_for_each_col:
# [batch_size, M], for each col (groundtruth_box), find the best # [batch_size, M], for each col (groundtruth_box), find the best
...@@ -175,27 +155,26 @@ class BoxMatcher: ...@@ -175,27 +155,26 @@ class BoxMatcher:
# [batch_size, N] # [batch_size, N]
final_matches = tf.where(force_match_column_mask, force_match_row_ids, final_matches = tf.where(force_match_column_mask, force_match_row_ids,
matches) matches)
final_matched_labels = tf.where( final_matched_indicators = tf.where(
force_match_column_mask, force_match_column_mask, self.indicators[-1] *
self._positive_value * tf.ones( tf.ones([batch_size, num_rows], dtype=tf.int32),
[batch_size, num_rows], dtype=tf.int32), matched_indicators)
matched_labels) return final_matches, final_matched_indicators
return final_matches, final_matched_labels
else: else:
return matches, matched_labels return matches, matched_indicators
num_gt_boxes = similarity_matrix.shape.as_list()[-1] or tf.shape( num_gt_boxes = similarity_matrix.shape.as_list()[-1] or tf.shape(
similarity_matrix)[-1] similarity_matrix)[-1]
result_match, result_match_labels = tf.cond( result_match, result_matched_indicators = tf.cond(
pred=tf.greater(num_gt_boxes, 0), pred=tf.greater(num_gt_boxes, 0),
true_fn=_match_when_rows_are_non_empty, true_fn=_match_when_rows_are_non_empty,
false_fn=_match_when_rows_are_empty) false_fn=_match_when_rows_are_empty)
if squeeze_result: if squeeze_result:
result_match = tf.squeeze(result_match, axis=0) result_match = tf.squeeze(result_match, axis=0)
result_match_labels = tf.squeeze(result_match_labels, axis=0) result_matched_indicators = tf.squeeze(result_matched_indicators, axis=0)
return result_match, result_match_labels return result_match, result_matched_indicators
def _set_values_using_indicator(self, x, indicator, val): def _set_values_using_indicator(self, x, indicator, val):
"""Set the indicated fields of x to val. """Set the indicated fields of x to val.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment