Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c0bce36e
Commit
c0bce36e
authored
May 10, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 447823991
parent
eb6e0ac4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
65 additions
and
54 deletions
+65
-54
official/vision/ops/box_matcher.py
official/vision/ops/box_matcher.py
+65
-54
No files found.
official/vision/ops/box_matcher.py
View file @
c0bce36e
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Box matcher implementation."""
"""Box matcher implementation."""
from
typing
import
List
,
Tuple
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -43,15 +43,19 @@ class BoxMatcher:
...
@@ -43,15 +43,19 @@ class BoxMatcher:
assigned positive_value.
assigned positive_value.
"""
"""
def
__init__
(
self
,
thresholds
,
indicators
,
force_match_for_each_col
=
False
):
def
__init__
(
self
,
thresholds
:
List
[
float
],
indicators
:
List
[
int
],
force_match_for_each_col
:
bool
=
False
):
"""Construct BoxMatcher.
"""Construct BoxMatcher.
Args:
Args:
thresholds: A list of thresholds to classify boxes into
thresholds: A list of thresholds to classify the matches into different
different buckets. The list needs to be sorted, and will be prepended
types (e.g. positive or negative or ignored match). The list needs to be
with -Inf and appended with +Inf.
sorted, and will be prepended with -Inf and appended with +Inf.
indicators: A list of values to assign for each bucket. len(`indicators`)
indicators: A list of values representing match types (e.g. positive or
must equal to len(`thresholds`) + 1.
negative or ignored match). len(`indicators`) must equal to
len(`thresholds`) + 1.
force_match_for_each_col: If True, ensures that each column is matched to
force_match_for_each_col: If True, ensures that each column is matched to
at least one row (which is not guaranteed otherwise if the
at least one row (which is not guaranteed otherwise if the
positive_threshold is high). Defaults to False. If True, all force
positive_threshold is high). Defaults to False. If True, all force
...
@@ -74,19 +78,20 @@ class BoxMatcher:
...
@@ -74,19 +78,20 @@ class BoxMatcher:
self
.
thresholds
=
thresholds
self
.
thresholds
=
thresholds
self
.
_force_match_for_each_col
=
force_match_for_each_col
self
.
_force_match_for_each_col
=
force_match_for_each_col
def
__call__
(
self
,
similarity_matrix
):
def
__call__
(
self
,
similarity_matrix
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Tries to match each column of the similarity matrix to a row.
"""Tries to match each column of the similarity matrix to a row.
Args:
Args:
similarity_matrix: A float tensor of shape [
N, M] representing any
similarity_matrix: A float tensor of shape [
num_rows, num_cols] or
similarity metric.
[batch_size, num_rows, num_cols] representing any
similarity metric.
Returns:
Returns:
A
integer tensor of shape [
N] with corresponding match indices for each
matched_columns: An
integer tensor of shape [
num_rows] or [batch_size,
of M columns, for positive match, the match result will be the
num_rows] storing the index of the matched column for each row.
corresponding row index, for negative match, the match will be
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
`negative_value`, for ignored
match
,
t
he match result will be
num_rows] storing the
match t
ype indicator (e.g. positive or negative or
`
ignore
_value`
.
ignore
d match)
.
"""
"""
squeeze_result
=
False
squeeze_result
=
False
if
len
(
similarity_matrix
.
shape
)
==
2
:
if
len
(
similarity_matrix
.
shape
)
==
2
:
...
@@ -101,29 +106,37 @@ class BoxMatcher:
...
@@ -101,29 +106,37 @@ class BoxMatcher:
"""Performs matching when the rows of similarity matrix are empty.
"""Performs matching when the rows of similarity matrix are empty.
When the rows are empty, all detections are false positives. So we return
When the rows are empty, all detections are false positives. So we return
a tensor of -1's to indicate that the
column
s do not match to any
row
s.
a tensor of -1's to indicate that the
row
s do not match to any
column
s.
Returns:
Returns:
matches: int32 tensor indicating the row each column matches to.
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative
or ignored match).
"""
"""
with
tf
.
name_scope
(
'empty_gt_boxes'
):
with
tf
.
name_scope
(
'empty_gt_boxes'
):
matches
=
tf
.
zeros
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
matche
d_column
s
=
tf
.
zeros
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
match_
label
s
=
-
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
match_
indicator
s
=
-
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
)
return
matches
,
match_
label
s
return
matche
d_column
s
,
match_
indicator
s
def
_match_when_rows_are_non_empty
():
def
_match_when_rows_are_non_empty
():
"""Performs matching when the rows of similarity matrix are non empty.
"""Performs matching when the rows of similarity matrix are non empty.
Returns:
Returns:
matches: int32 tensor indicating the row each column matches to.
matched_columns: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the index of the matched column for each row.
match_indicators: An integer tensor of shape [num_rows] or [batch_size,
num_rows] storing the match type indicator (e.g. positive or negative
or ignored match).
"""
"""
# Matches for each column
with
tf
.
name_scope
(
'non_empty_gt_boxes'
):
with
tf
.
name_scope
(
'non_empty_gt_boxes'
):
matches
=
tf
.
argmax
(
similarity_matrix
,
axis
=-
1
,
output_type
=
tf
.
int32
)
matched_columns
=
tf
.
argmax
(
similarity_matrix
,
axis
=-
1
,
output_type
=
tf
.
int32
)
# Get logical indices of ignored and unmatched columns as tf.int64
# Get logical indices of ignored and unmatched columns as tf.int64
matched_vals
=
tf
.
reduce_max
(
similarity_matrix
,
axis
=-
1
)
matched_vals
=
tf
.
reduce_max
(
similarity_matrix
,
axis
=-
1
)
match
ed
_indicators
=
tf
.
zeros
([
batch_size
,
num_rows
],
tf
.
int32
)
match_indicators
=
tf
.
zeros
([
batch_size
,
num_rows
],
tf
.
int32
)
match_dtype
=
matched_vals
.
dtype
match_dtype
=
matched_vals
.
dtype
for
(
ind
,
low
,
high
)
in
zip
(
self
.
indicators
,
self
.
thresholds
[:
-
1
],
for
(
ind
,
low
,
high
)
in
zip
(
self
.
indicators
,
self
.
thresholds
[:
-
1
],
...
@@ -133,48 +146,46 @@ class BoxMatcher:
...
@@ -133,48 +146,46 @@ class BoxMatcher:
mask
=
tf
.
logical_and
(
mask
=
tf
.
logical_and
(
tf
.
greater_equal
(
matched_vals
,
low_threshold
),
tf
.
greater_equal
(
matched_vals
,
low_threshold
),
tf
.
less
(
matched_vals
,
high_threshold
))
tf
.
less
(
matched_vals
,
high_threshold
))
match
ed
_indicators
=
self
.
_set_values_using_indicator
(
match_indicators
=
self
.
_set_values_using_indicator
(
match
ed
_indicators
,
mask
,
ind
)
match_indicators
,
mask
,
ind
)
if
self
.
_force_match_for_each_col
:
if
self
.
_force_match_for_each_col
:
# [batch_size,
M
], for each col (groundtruth_box), find the
best
# [batch_size,
num_cols
], for each col
umn
(groundtruth_box), find the
# matching row (anchor).
#
best
matching row (anchor).
force_match_column_id
s
=
tf
.
argmax
(
matching_row
s
=
tf
.
argmax
(
input
=
similarity_matrix
,
axis
=
1
,
output_type
=
tf
.
int32
)
input
=
similarity_matrix
,
axis
=
1
,
output_type
=
tf
.
int32
)
# [batch_size, M, N]
# [batch_size, num_cols, num_rows], a transposed 0-1 mapping matrix M,
force_match_column_indicators
=
tf
.
one_hot
(
# where M[j, i] = 1 means column j is matched to row i.
force_match_column_ids
,
depth
=
num_rows
)
column_to_row_match_mapping
=
tf
.
one_hot
(
# [batch_size, N], for each row (anchor), find the largest column
matching_rows
,
depth
=
num_rows
)
# index for groundtruth box
# [batch_size, num_rows], for each row (anchor), find the matched
force_match_row_ids
=
tf
.
argmax
(
# column (groundtruth_box).
input
=
force_match_column_indicators
,
axis
=
1
,
output_type
=
tf
.
int32
)
force_matched_columns
=
tf
.
argmax
(
# [batch_size, N]
input
=
column_to_row_match_mapping
,
axis
=
1
,
output_type
=
tf
.
int32
)
force_match_column_mask
=
tf
.
cast
(
# [batch_size, num_rows]
tf
.
reduce_max
(
force_match_column_indicators
,
axis
=
1
),
force_matched_column_mask
=
tf
.
cast
(
tf
.
bool
)
tf
.
reduce_max
(
column_to_row_match_mapping
,
axis
=
1
),
tf
.
bool
)
# [batch_size, N]
# [batch_size, num_rows]
final_matches
=
tf
.
where
(
force_match_column_mask
,
force_match_row_ids
,
matched_columns
=
tf
.
where
(
force_matched_column_mask
,
matches
)
force_matched_columns
,
matched_columns
)
final_matched_indicators
=
tf
.
where
(
match_indicators
=
tf
.
where
(
force_match_column_mask
,
self
.
indicators
[
-
1
]
*
force_matched_column_mask
,
self
.
indicators
[
-
1
]
*
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
),
tf
.
ones
([
batch_size
,
num_rows
],
dtype
=
tf
.
int32
),
match_indicators
)
matched_indicators
)
return
final_matches
,
final_matched_indicators
return
matched_columns
,
match_indicators
else
:
return
matches
,
matched_indicators
num_gt_boxes
=
similarity_matrix
.
shape
.
as_list
()[
-
1
]
or
tf
.
shape
(
num_gt_boxes
=
similarity_matrix
.
shape
.
as_list
()[
-
1
]
or
tf
.
shape
(
similarity_matrix
)[
-
1
]
similarity_matrix
)[
-
1
]
result_match
,
result_
match
ed
_indicators
=
tf
.
cond
(
matched_columns
,
match_indicators
=
tf
.
cond
(
pred
=
tf
.
greater
(
num_gt_boxes
,
0
),
pred
=
tf
.
greater
(
num_gt_boxes
,
0
),
true_fn
=
_match_when_rows_are_non_empty
,
true_fn
=
_match_when_rows_are_non_empty
,
false_fn
=
_match_when_rows_are_empty
)
false_fn
=
_match_when_rows_are_empty
)
if
squeeze_result
:
if
squeeze_result
:
result_match
=
tf
.
squeeze
(
result_match
,
axis
=
0
)
matched_columns
=
tf
.
squeeze
(
matched_columns
,
axis
=
0
)
result_
match
ed
_indicators
=
tf
.
squeeze
(
result_
match
ed
_indicators
,
axis
=
0
)
match_indicators
=
tf
.
squeeze
(
match_indicators
,
axis
=
0
)
return
result_match
,
result_
match
ed
_indicators
return
matched_columns
,
match_indicators
def
_set_values_using_indicator
(
self
,
x
,
indicator
,
val
):
def
_set_values_using_indicator
(
self
,
x
,
indicator
,
val
):
"""Set the indicated fields of x to val.
"""Set the indicated fields of x to val.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment