Commit c0bce36e authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 447823991
parent eb6e0ac4
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Box matcher implementation.""" """Box matcher implementation."""
from typing import List, Tuple
import tensorflow as tf import tensorflow as tf
...@@ -43,15 +43,19 @@ class BoxMatcher: ...@@ -43,15 +43,19 @@ class BoxMatcher:
assigned positive_value. assigned positive_value.
""" """
def __init__(self, thresholds, indicators, force_match_for_each_col=False): def __init__(self,
thresholds: List[float],
indicators: List[int],
force_match_for_each_col: bool = False):
"""Construct BoxMatcher. """Construct BoxMatcher.
Args: Args:
thresholds: A list of thresholds to classify boxes into thresholds: A list of thresholds to classify the matches into different
different buckets. The list needs to be sorted, and will be prepended types (e.g. positive or negative or ignored match). The list needs to be
with -Inf and appended with +Inf. sorted, and will be prepended with -Inf and appended with +Inf.
indicators: A list of values to assign for each bucket. len(`indicators`) indicators: A list of values representing match types (e.g. positive or
must equal to len(`thresholds`) + 1. negative or ignored match). len(`indicators`) must equal to
len(`thresholds`) + 1.
force_match_for_each_col: If True, ensures that each column is matched to force_match_for_each_col: If True, ensures that each column is matched to
at least one row (which is not guaranteed otherwise if the at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False. If True, all force positive_threshold is high). Defaults to False. If True, all force
...@@ -74,19 +78,20 @@ class BoxMatcher: ...@@ -74,19 +78,20 @@ class BoxMatcher:
self.thresholds = thresholds self.thresholds = thresholds
self._force_match_for_each_col = force_match_for_each_col self._force_match_for_each_col = force_match_for_each_col
def __call__(self, similarity_matrix): def __call__(self,
similarity_matrix: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Tries to match each column of the similarity matrix to a row. """Tries to match each column of the similarity matrix to a row.
Args: Args:
similarity_matrix: A float tensor of shape [N, M] representing any similarity_matrix: A float tensor of shape [num_rows, num_cols] or
similarity metric. [batch_size, num_rows, num_cols] representing any similarity metric.
Returns: Returns:
A integer tensor of shape [N] with corresponding match indices for each matched_columns: An integer tensor of shape [num_rows] or [batch_size,
of M columns, for positive match, the match result will be the num_rows] storing the index of the matched column for each row.
corresponding row index, for negative match, the match will be match_indicators: An integer tensor of shape [num_rows] or [batch_size,
`negative_value`, for ignored match, the match result will be num_rows] storing the match type indicator (e.g. positive or negative or
`ignore_value`. ignored match).
""" """
squeeze_result = False squeeze_result = False
if len(similarity_matrix.shape) == 2: if len(similarity_matrix.shape) == 2:
...@@ -101,29 +106,37 @@ class BoxMatcher: ...@@ -101,29 +106,37 @@ class BoxMatcher:
"""Performs matching when the rows of similarity matrix are empty. """Performs matching when the rows of similarity matrix are empty.
When the rows are empty, all detections are false positives. So we return When the rows are empty, all detections are false positives. So we return
a tensor of -1's to indicate that the columns do not match to any rows. a tensor of -1's to indicate that the rows do not match to any columns.
Returns: Returns:
matches: int32 tensor indicating the row each column matches to. matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative
or ignored match).
""" """
with tf.name_scope('empty_gt_boxes'): with tf.name_scope('empty_gt_boxes'):
matches = tf.zeros([batch_size, num_rows], dtype=tf.int32) matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32)
match_labels = -tf.ones([batch_size, num_rows], dtype=tf.int32) match_indicators = -tf.ones([batch_size, num_rows], dtype=tf.int32)
return matches, match_labels return matched_columns, match_indicators
def _match_when_rows_are_non_empty(): def _match_when_rows_are_non_empty():
"""Performs matching when the rows of similarity matrix are non empty. """Performs matching when the rows of similarity matrix are non empty.
Returns: Returns:
matches: int32 tensor indicating the row each column matches to. matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative
or ignored match).
""" """
# Matches for each column
with tf.name_scope('non_empty_gt_boxes'): with tf.name_scope('non_empty_gt_boxes'):
matches = tf.argmax(similarity_matrix, axis=-1, output_type=tf.int32) matched_columns = tf.argmax(
similarity_matrix, axis=-1, output_type=tf.int32)
# Get logical indices of ignored and unmatched columns as tf.int64 # Get logical indices of ignored and unmatched columns as tf.int64
matched_vals = tf.reduce_max(similarity_matrix, axis=-1) matched_vals = tf.reduce_max(similarity_matrix, axis=-1)
matched_indicators = tf.zeros([batch_size, num_rows], tf.int32) match_indicators = tf.zeros([batch_size, num_rows], tf.int32)
match_dtype = matched_vals.dtype match_dtype = matched_vals.dtype
for (ind, low, high) in zip(self.indicators, self.thresholds[:-1], for (ind, low, high) in zip(self.indicators, self.thresholds[:-1],
...@@ -133,48 +146,46 @@ class BoxMatcher: ...@@ -133,48 +146,46 @@ class BoxMatcher:
mask = tf.logical_and( mask = tf.logical_and(
tf.greater_equal(matched_vals, low_threshold), tf.greater_equal(matched_vals, low_threshold),
tf.less(matched_vals, high_threshold)) tf.less(matched_vals, high_threshold))
matched_indicators = self._set_values_using_indicator( match_indicators = self._set_values_using_indicator(
matched_indicators, mask, ind) match_indicators, mask, ind)
if self._force_match_for_each_col: if self._force_match_for_each_col:
# [batch_size, M], for each col (groundtruth_box), find the best # [batch_size, num_cols], for each column (groundtruth_box), find the
# matching row (anchor). # best matching row (anchor).
force_match_column_ids = tf.argmax( matching_rows = tf.argmax(
input=similarity_matrix, axis=1, output_type=tf.int32) input=similarity_matrix, axis=1, output_type=tf.int32)
# [batch_size, M, N] # [batch_size, num_cols, num_rows], a transposed 0-1 mapping matrix M,
force_match_column_indicators = tf.one_hot( # where M[j, i] = 1 means column j is matched to row i.
force_match_column_ids, depth=num_rows) column_to_row_match_mapping = tf.one_hot(
# [batch_size, N], for each row (anchor), find the largest column matching_rows, depth=num_rows)
# index for groundtruth box # [batch_size, num_rows], for each row (anchor), find the matched
force_match_row_ids = tf.argmax( # column (groundtruth_box).
input=force_match_column_indicators, axis=1, output_type=tf.int32) force_matched_columns = tf.argmax(
# [batch_size, N] input=column_to_row_match_mapping, axis=1, output_type=tf.int32)
force_match_column_mask = tf.cast( # [batch_size, num_rows]
tf.reduce_max(force_match_column_indicators, axis=1), force_matched_column_mask = tf.cast(
tf.bool) tf.reduce_max(column_to_row_match_mapping, axis=1), tf.bool)
# [batch_size, N] # [batch_size, num_rows]
final_matches = tf.where(force_match_column_mask, force_match_row_ids, matched_columns = tf.where(force_matched_column_mask,
matches) force_matched_columns, matched_columns)
final_matched_indicators = tf.where( match_indicators = tf.where(
force_match_column_mask, self.indicators[-1] * force_matched_column_mask, self.indicators[-1] *
tf.ones([batch_size, num_rows], dtype=tf.int32), tf.ones([batch_size, num_rows], dtype=tf.int32), match_indicators)
matched_indicators)
return final_matches, final_matched_indicators return matched_columns, match_indicators
else:
return matches, matched_indicators
num_gt_boxes = similarity_matrix.shape.as_list()[-1] or tf.shape( num_gt_boxes = similarity_matrix.shape.as_list()[-1] or tf.shape(
similarity_matrix)[-1] similarity_matrix)[-1]
result_match, result_matched_indicators = tf.cond( matched_columns, match_indicators = tf.cond(
pred=tf.greater(num_gt_boxes, 0), pred=tf.greater(num_gt_boxes, 0),
true_fn=_match_when_rows_are_non_empty, true_fn=_match_when_rows_are_non_empty,
false_fn=_match_when_rows_are_empty) false_fn=_match_when_rows_are_empty)
if squeeze_result: if squeeze_result:
result_match = tf.squeeze(result_match, axis=0) matched_columns = tf.squeeze(matched_columns, axis=0)
result_matched_indicators = tf.squeeze(result_matched_indicators, axis=0) match_indicators = tf.squeeze(match_indicators, axis=0)
return result_match, result_matched_indicators return matched_columns, match_indicators
def _set_values_using_indicator(self, x, indicator, val): def _set_values_using_indicator(self, x, indicator, val):
"""Set the indicated fields of x to val. """Set the indicated fields of x to val.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment