Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c0bce36e
Commit
c0bce36e
authored
May 10, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 447823991
parent
eb6e0ac4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
54 deletions
+65
-54
official/vision/ops/box_matcher.py
official/vision/ops/box_matcher.py
+65
-54
No files found.
official/vision/ops/box_matcher.py
View file @
c0bce36e
...
...
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Box matcher implementation."""
from
typing
import
List
,
Tuple
import
tensorflow
as
tf
...
...
@@ -43,15 +43,19 @@ class BoxMatcher:
assigned positive_value.
"""
def
__init__
(
self
,
thresholds
,
indicators
,
force_match_for_each_col
=
False
):
def
__init__
(
self
,
thresholds
:
List
[
float
],
indicators
:
List
[
int
],
force_match_for_each_col
:
bool
=
False
):
"""Construct BoxMatcher.
Args:
thresholds: A list of thresholds to classify boxes into
different buckets. The list needs to be sorted, and will be prepended
with -Inf and appended with +Inf.
indicators: A list of values to assign for each bucket. len(`indicators`)
must equal to len(`thresholds`) + 1.
thresholds: A list of thresholds to classify the matches into different
types (e.g. positive or negative or ignored match). The list needs to be
sorted, and will be prepended with -Inf and appended with +Inf.
indicators: A list of values representing match types (e.g. positive or
negative or ignored match). len(`indicators`) must equal to
len(`thresholds`) + 1.
force_match_for_each_col: If True, ensures that each column is matched to
at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False. If True, all force
...
...
@@ -74,19 +78,20 @@ class BoxMatcher:
self
.
thresholds
=
thresholds
self
.
_force_match_for_each_col
=
force_match_for_each_col
def
__call__
(
self
,
similarity_matrix
):
def
__call__
(
self
,
similarity_matrix
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Tries to match each column of the similarity matrix to a row.
Args:
similarity_matrix: A float tensor of shape [
N, M] representing any
similarity metric.
similarity_matrix: A float tensor of shape [
num_rows, num_cols] or
[batch_size, num_rows, num_cols] representing any
similarity metric.
Returns:
A
integer tensor of shape [
N] with corresponding match indices for each
of M columns, for positive match, the match result will be the
corresponding row index, for negative match, the match will be
`negative_value`, for ignored
match
,
t
he match result will be
`
ignore
_value`
.
matched_columns: An
integer tensor of shape [
num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the
match t
ype indicator (e.g. positive or negative or
ignore
d match)
.
"""
squeeze_result
=
False
if
len
(
similarity_matrix
.
shape
)
==
2
:
...
...
@@ -101,29 +106,37 @@ class BoxMatcher:
"""Performs matching when the rows of similarity matrix are empty.
When the rows are empty, all detections are false positives. So we return
a tensor of -1's to indicate that the
column
s do not match to any
row
s.
a tensor of -1's to indicate that the
row
s do not match to any
column
s.
Returns:
matches: int32 tensor indicating the row each column matches to.
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with
tf
.
name_scope
(
'empty_gt_boxes'
):
matches
=
tf
.
zeros
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
match_
label
s
=
-
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
return
matches
,
match_
label
s
matche
d_column
s
=
tf
.
zeros
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
match_
indicator
s
=
-
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
return
matche
d_column
s
,
match_
indicator
s
def
_match_when_rows_are_non_empty
():
"""Performs matching when the rows of similarity matrix are non empty.
Returns:
matches: int32 tensor indicating the row each column matches to.
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative
or ignored match).
"""
# Matches for each column
with
tf
.
name_scope
(
'non_empty_gt_boxes'
):
matches
=
tf
.
argmax
(
similarity_matrix
,
axis
=-
1
,
output_type
=
tf
.
int32
)
matched_columns
=
tf
.
argmax
(
similarity_matrix
,
axis
=-
1
,
output_type
=
tf
.
int32
)
# Get logical indices of ignored and unmatched columns as tf.int64
matched_vals
=
tf
.
reduce_max
(
similarity_matrix
,
axis
=-
1
)
match
ed
_indicators
=
tf
.
zeros
([
batch_size
,
num_rows
],
tf
.
int32
)
match_indicators
=
tf
.
zeros
([
batch_size
,
num_rows
],
tf
.
int32
)
match_dtype
=
matched_vals
.
dtype
for
(
ind
,
low
,
high
)
in
zip
(
self
.
indicators
,
self
.
thresholds
[:
-
1
],
...
...
@@ -133,48 +146,46 @@ class BoxMatcher:
mask
=
tf
.
logical_and
(
tf
.
greater_equal
(
matched_vals
,
low_threshold
),
tf
.
less
(
matched_vals
,
high_threshold
))
match
ed
_indicators
=
self
.
_set_values_using_indicator
(
match
ed
_indicators
,
mask
,
ind
)
match_indicators
=
self
.
_set_values_using_indicator
(
match_indicators
,
mask
,
ind
)
if
self
.
_force_match_for_each_col
:
# [batch_size,
M
], for each col (groundtruth_box), find the
best
# matching row (anchor).
force_match_column_id
s
=
tf
.
argmax
(
# [batch_size,
num_cols
], for each col
umn
(groundtruth_box), find the
#
best
matching row (anchor).
matching_row
s
=
tf
.
argmax
(
input
=
similarity_matrix
,
axis
=
1
,
output_type
=
tf
.
int32
)
# [batch_size, M, N]
force_match_column_indicators
=
tf
.
one_hot
(
force_match_column_ids
,
depth
=
num_rows
)
# [batch_size, N], for each row (anchor), find the largest column
# index for groundtruth box
force_match_row_ids
=
tf
.
argmax
(
input
=
force_match_column_indicators
,
axis
=
1
,
output_type
=
tf
.
int32
)
# [batch_size, N]
force_match_column_mask
=
tf
.
cast
(
tf
.
reduce_max
(
force_match_column_indicators
,
axis
=
1
),
tf
.
bool
)
# [batch_size, N]
final_matches
=
tf
.
where
(
force_match_column_mask
,
force_match_row_ids
,
matches
)
final_matched_indicators
=
tf
.
where
(
force_match_column_mask
,
self
.
indicators
[
-
1
]
*
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
),
matched_indicators
)
return
final_matches
,
final_matched_indicators
else
:
return
matches
,
matched_indicators
# [batch_size, num_cols, num_rows], a transposed 0-1 mapping matrix M,
# where M[j, i] = 1 means column j is matched to row i.
column_to_row_match_mapping
=
tf
.
one_hot
(
matching_rows
,
depth
=
num_rows
)
# [batch_size, num_rows], for each row (anchor), find the matched
# column (groundtruth_box).
force_matched_columns
=
tf
.
argmax
(
input
=
column_to_row_match_mapping
,
axis
=
1
,
output_type
=
tf
.
int32
)
# [batch_size, num_rows]
force_matched_column_mask
=
tf
.
cast
(
tf
.
reduce_max
(
column_to_row_match_mapping
,
axis
=
1
),
tf
.
bool
)
# [batch_size, num_rows]
matched_columns
=
tf
.
where
(
force_matched_column_mask
,
force_matched_columns
,
matched_columns
)
match_indicators
=
tf
.
where
(
force_matched_column_mask
,
self
.
indicators
[
-
1
]
*
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
),
match_indicators
)
return
matched_columns
,
match_indicators
num_gt_boxes
=
similarity_matrix
.
shape
.
as_list
()[
-
1
]
or
tf
.
shape
(
similarity_matrix
)[
-
1
]
result_match
,
result_
match
ed
_indicators
=
tf
.
cond
(
matched_columns
,
match_indicators
=
tf
.
cond
(
pred
=
tf
.
greater
(
num_gt_boxes
,
0
),
true_fn
=
_match_when_rows_are_non_empty
,
false_fn
=
_match_when_rows_are_empty
)
if
squeeze_result
:
result_match
=
tf
.
squeeze
(
result_match
,
axis
=
0
)
result_
match
ed
_indicators
=
tf
.
squeeze
(
result_
match
ed
_indicators
,
axis
=
0
)
matched_columns
=
tf
.
squeeze
(
matched_columns
,
axis
=
0
)
match_indicators
=
tf
.
squeeze
(
match_indicators
,
axis
=
0
)
return
result_match
,
result_
match
ed
_indicators
return
matched_columns
,
match_indicators
def
_set_values_using_indicator
(
self
,
x
,
indicator
,
val
):
"""Set the indicated fields of x to val.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment