Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
f0686bcb
Commit
f0686bcb
authored
May 31, 2022
by
Baumgartner, Michael
Browse files
split matcher file
parent
938687e7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
419 additions
and
0 deletions
+419
-0
nndet/core/boxes/matcher/__init__.py
nndet/core/boxes/matcher/__init__.py
+3
-0
nndet/core/boxes/matcher/atss.py
nndet/core/boxes/matcher/atss.py
+148
-0
nndet/core/boxes/matcher/base.py
nndet/core/boxes/matcher/base.py
+104
-0
nndet/core/boxes/matcher/iou.py
nndet/core/boxes/matcher/iou.py
+164
-0
No files found.
nndet/core/boxes/matcher/__init__.py
0 → 100644
View file @
f0686bcb
from
nndet.core.boxes.matcher.base
import
Matcher
,
MatcherType
from
nndet.core.boxes.matcher.iou
import
IoUMatcher
from
nndet.core.boxes.matcher.atss
import
ATSSMatcher
nndet/core/boxes/matcher/atss.py
0 → 100644
View file @
f0686bcb
"""
Modifications licensed under
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Parts of this code is adapted from mmdetection and thus licensed under
Copyright 2018-2023 OpenMMLab.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from
typing
import
Sequence
,
Callable
,
Tuple
import
torch
from
torch
import
Tensor
from
loguru
import
logger
from
nndet.core.boxes.ops
import
box_iou
,
box_center_dist
,
center_in_boxes
from
nndet.core.boxes.matcher.base
import
Matcher
INF = 100  # not really infinity, but sufficiently large for the matching logic here
class ATSSMatcher(Matcher):
    def __init__(self,
                 num_candidates: int,
                 similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
                 center_in_gt: bool = True,
                 ):
        """
        Compute matching based on ATSS
        https://arxiv.org/abs/1912.02424
        `Bridging the Gap Between Anchor-based and Anchor-free Detection
        via Adaptive Training Sample Selection`

        Args:
            num_candidates: number of positions to select candidates from
                (per feature pyramid level, scaled by the number of anchors
                per location)
            similarity_fn: function for similarity computation between
                boxes and anchors
            center_in_gt: If disabled, matched anchor center points do not need
                to lie within the ground truth box.
        """
        super().__init__(similarity_fn=similarity_fn)
        self.num_candidates = num_candidates
        # eps used when testing whether an anchor center lies inside a box
        self.min_dist = 0.01
        self.center_in_gt = center_in_gt
        logger.info(f"Running ATSS Matching with num_candidates={self.num_candidates} "
                    f"and center_in_gt {self.center_in_gt}.")

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches according to ATSS for a single image

        Adapted from
        https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
        https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py

        Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        num_gt = boxes.shape[0]
        num_anchors = anchors.shape[0]

        distances, _, anchors_center = box_center_dist(boxes, anchors)  # num_boxes x anchors

        # select candidates based on center distance; candidates are picked
        # independently per feature pyramid level
        candidate_idx = []
        start_idx = 0
        for level, apl in enumerate(num_anchors_per_level):
            end_idx = start_idx + apl

            selectable_k = min(self.num_candidates * num_anchors_per_loc, apl)
            # closest anchors to each box center within this level
            _, idx = distances[:, start_idx:end_idx].topk(selectable_k, dim=1, largest=False)
            # idx shape [num_boxes x selectable_k]
            # shift level-local indices back into the global anchor index space
            candidate_idx.append(idx + start_idx)

            start_idx = end_idx
        # [num_boxes x num_candidates] (index of candidate anchors)
        candidate_idx = torch.cat(candidate_idx, dim=1)

        match_quality_matrix = self.similarity_fn(boxes, anchors)  # [num_boxes x anchors]
        candidate_overlaps = match_quality_matrix.gather(1, candidate_idx)  # [num_boxes, n_candidates]

        # compute adaptive iou threshold: mean + std of the candidate overlaps per box
        overlaps_mean_per_gt = candidate_overlaps.mean(dim=1)  # [num_boxes]
        overlaps_std_per_gt = candidate_overlaps.std(dim=1)  # [num_boxes]
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt  # [num_boxes]
        is_pos = candidate_overlaps >= overlaps_thr_per_gt[:, None]  # [num_boxes x n_candidates]

        if self.center_in_gt:  # can discard all candidates in case of very small objects :/
            # center point of selected anchors needs to lie within the ground truth
            boxes_idx = torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None] \
                .expand_as(candidate_idx).contiguous()  # [num_boxes x n_candidates]
            is_in_gt = center_in_boxes(
                anchors_center[candidate_idx.view(-1)],
                boxes[boxes_idx.view(-1)],
                eps=self.min_dist,
            )
            is_pos = is_pos & is_in_gt.view_as(is_pos)  # [num_boxes x n_candidates]

        # in case an anchor is assigned to multiple boxes, use the box with highest IoU
        # offset each row of candidate_idx so indices address the flattened
        # [num_boxes * num_anchors] similarity matrix
        for ng in range(num_gt):
            candidate_idx[ng, :] += ng * num_anchors
        overlaps_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
        index = candidate_idx.view(-1)[is_pos.view(-1)]
        overlaps_inf[index] = match_quality_matrix.view(-1)[index]
        overlaps_inf = overlaps_inf.view_as(match_quality_matrix)

        # per anchor, pick the box with the highest (non -INF) overlap
        matched_vals, matches = overlaps_inf.max(dim=0)
        matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
        # print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {overlaps_thr_per_gt}")
        return match_quality_matrix, matches
nndet/core/boxes/matcher/base.py
0 → 100644
View file @
f0686bcb
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from
typing
import
Sequence
,
Callable
,
Tuple
,
TypeVar
from
abc
import
ABC
import
torch
from
torch
import
Tensor
from
nndet.core.boxes.ops
import
box_iou
class Matcher(ABC):
    """
    Base class for matchers which assign ground truth boxes to anchors.

    The returned match vector uses non-negative values as box indices and
    the negative sentinels below for anchors without a box.
    """
    BELOW_LOW_THRESHOLD: int = -1  # anchor is background (matched to no box)
    BETWEEN_THRESHOLDS: int = -2  # anchor should be ignored during training

    def __init__(self,
                 similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
                 ):
        """
        Matches boxes and anchors to each other

        Args:
            similarity_fn: function for similarity computation between
                boxes and anchors
        """
        self.similarity_fn = similarity_fn

    def __call__(self,
                 boxes: torch.Tensor,
                 anchors: torch.Tensor,
                 num_anchors_per_level: Sequence[int],
                 num_anchors_per_loc: int,
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches for a single image

        Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        if boxes.numel() == 0:
            # no ground truth: every anchor is background
            num_anchors = anchors.shape[0]
            match_quality_matrix = torch.tensor([]).to(anchors)
            # allocate `matches` on the anchors' device; previously it was
            # created on the default (CPU) device, which caused a device
            # mismatch downstream when the anchors live on the GPU
            matches = torch.empty(
                num_anchors, dtype=torch.int64, device=anchors.device,
            ).fill_(self.BELOW_LOW_THRESHOLD)
            return match_quality_matrix, matches
        else:
            # at least one ground truth: delegate to the subclass implementation
            return self.compute_matches(
                boxes=boxes,
                anchors=anchors,
                num_anchors_per_level=num_anchors_per_level,
                num_anchors_per_loc=num_anchors_per_loc,
            )

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int,
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches (to be implemented by subclasses; only called when
        at least one ground truth box is present)

        Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        raise NotImplementedError
# Type variable bound to Matcher, for annotating code that accepts or returns
# any concrete matcher subclass.
MatcherType = TypeVar('MatcherType', bound=Matcher)
nndet/core/boxes/matcher.py
→
nndet/core/boxes/matcher
/iou
.py
View file @
f0686bcb
"""
Modifications licensed under
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Parts of this code are from torchvision and thus licensed under
BSD 3-Clause License
...
...
@@ -30,96 +47,16 @@ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from
typing
import
Sequence
,
Callable
,
Tuple
,
TypeVar
from
abc
import
ABC
from
typing
import
Callable
,
Tuple
import
torch
from
torch
import
Tensor
from
loguru
import
logger
from
nndet.core.boxes.ops
import
box_iou
,
box_center_dist
,
center_in_boxes
INF = 100  # not really infinity, but sufficiently large for the matching logic here
class Matcher(ABC):
    """
    Base class for matchers which assign ground truth boxes to anchors.

    The returned match vector uses non-negative values as box indices and
    the negative sentinels below for anchors without a box.
    """
    BELOW_LOW_THRESHOLD: int = -1  # anchor is background (matched to no box)
    BETWEEN_THRESHOLDS: int = -2  # anchor should be ignored during training

    def __init__(self,
                 similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
                 ):
        """
        Matches boxes and anchors to each other

        Args:
            similarity_fn: function for similarity computation between
                boxes and anchors
        """
        self.similarity_fn = similarity_fn

    def __call__(self,
                 boxes: torch.Tensor,
                 anchors: torch.Tensor,
                 num_anchors_per_level: Sequence[int],
                 num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches for a single image

        Args:
            boxes: anchors are matches to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        if boxes.numel() == 0:
            # no ground truth: every anchor is background
            num_anchors = anchors.shape[0]
            match_quality_matrix = torch.tensor([]).to(anchors)
            # NOTE(review): `matches` is allocated on the default (CPU) device
            # while `match_quality_matrix` follows `anchors` — confirm callers
            # handle a possible device mismatch when anchors live on the GPU
            matches = torch.empty(num_anchors, dtype=torch.int64).fill_(self.BELOW_LOW_THRESHOLD)
            return match_quality_matrix, matches
        else:
            # at least one ground truth: delegate to the subclass implementation
            return self.compute_matches(
                boxes=boxes,
                anchors=anchors,
                num_anchors_per_level=num_anchors_per_level,
                num_anchors_per_loc=num_anchors_per_loc,
            )

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches (to be implemented by subclasses; only called when
        at least one ground truth box is present)

        Args:
            boxes: anchors are matches to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        raise NotImplementedError
from
nndet.core.boxes.ops
import
box_iou
from
nndet.core.boxes.matcher.base
import
Matcher
class
IoUMatcher
(
Matcher
):
...
...
@@ -225,110 +162,3 @@ class IoUMatcher(Matcher):
logger
.
info
(
f
"Inbetween IoU ranging from
{
match_bet_min
}
to
{
match_bet_max
}
"
)
logger
.
info
(
f
"Max background IoU:
{
matched_vals
[
below_low_threshold
].
max
()
}
"
)
logger
.
info
(
"#################################"
)
class ATSSMatcher(Matcher):
    def __init__(self,
                 num_candidates: int,
                 similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
                 center_in_gt: bool = True,
                 ):
        """
        Compute matching based on ATSS
        https://arxiv.org/abs/1912.02424
        `Bridging the Gap Between Anchor-based and Anchor-free Detection
        via Adaptive Training Sample Selection`

        Args:
            num_candidates: number of positions to select candidates from
                (per feature pyramid level, scaled by the number of anchors
                per location)
            similarity_fn: function for similarity computation between
                boxes and anchors
            center_in_gt: If disabled, matched anchor center points do not need
                to lie within the ground truth box.
        """
        super().__init__(similarity_fn=similarity_fn)
        self.num_candidates = num_candidates
        # eps used when testing whether an anchor center lies inside a box
        self.min_dist = 0.01
        self.center_in_gt = center_in_gt
        logger.info(f"Running ATSS Matching with num_candidates={self.num_candidates} "
                    f"and center_in_gt {self.center_in_gt}.")

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches according to ATSS for a single image

        Adapted from
        (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss
        /loss.py#L180-L184)

        Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        num_gt = boxes.shape[0]
        num_anchors = anchors.shape[0]

        distances, boxes_center, anchors_center = box_center_dist(boxes, anchors)  # num_boxes x anchors

        # select candidates based on center distance; candidates are picked
        # independently per feature pyramid level
        candidate_idx = []
        start_idx = 0
        for level, apl in enumerate(num_anchors_per_level):
            end_idx = start_idx + apl

            topk = min(self.num_candidates * num_anchors_per_loc, apl)
            # closest anchors to each box center within this level
            _, idx = distances[:, start_idx:end_idx].topk(topk, dim=1, largest=False)
            # idx shape [num_boxes x topk]
            # shift level-local indices back into the global anchor index space
            candidate_idx.append(idx + start_idx)

            start_idx = end_idx
        # [num_boxes x num_candidates] (index of candidate anchors)
        candidate_idx = torch.cat(candidate_idx, dim=1)

        match_quality_matrix = self.similarity_fn(boxes, anchors)  # [num_boxes x anchors]
        candidate_ious = match_quality_matrix.gather(1, candidate_idx)  # [num_boxes, n_candidates]

        # compute adaptive iou threshold: mean + std of the candidate IoUs per box
        iou_mean_per_gt = candidate_ious.mean(dim=1)  # [num_boxes]
        iou_std_per_gt = candidate_ious.std(dim=1)  # [num_boxes]
        iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt  # [num_boxes]
        is_pos = candidate_ious >= iou_thresh_per_gt[:, None]  # [num_boxes x n_candidates]

        if self.center_in_gt:  # can discard all candidates in case of very small objects :/
            # center point of selected anchors needs to lie within the ground truth
            boxes_idx = torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None] \
                .expand_as(candidate_idx).contiguous()  # [num_boxes x n_candidates]
            is_in_gt = center_in_boxes(
                anchors_center[candidate_idx.view(-1)],
                boxes[boxes_idx.view(-1)],
                eps=self.min_dist,
            )
            is_pos = is_pos & is_in_gt.view_as(is_pos)  # [num_boxes x n_candidates]

        # in case an anchor is assigned to multiple boxes, use the box with highest IoU
        # TODO: think about a better way to do this
        # offset each row of candidate_idx so indices address the flattened
        # [num_boxes * num_anchors] similarity matrix
        for ng in range(num_gt):
            candidate_idx[ng, :] += ng * num_anchors
        ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
        index = candidate_idx.view(-1)[is_pos.view(-1)]
        ious_inf[index] = match_quality_matrix.view(-1)[index]
        ious_inf = ious_inf.view_as(match_quality_matrix)

        # per anchor, pick the box with the highest (non -INF) IoU
        matched_vals, matches = ious_inf.max(dim=0)
        matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
        # print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {iou_thresh_per_gt}")
        return match_quality_matrix, matches
# Type variable bound to Matcher, for annotating code that accepts or returns
# any concrete matcher subclass.
MatcherType = TypeVar('MatcherType', bound=Matcher)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment