Commit f0686bcb authored by Michael Baumgartner

split matcher file

parent 938687e7
from nndet.core.boxes.matcher.base import Matcher, MatcherType
from nndet.core.boxes.matcher.iou import IoUMatcher
from nndet.core.boxes.matcher.atss import ATSSMatcher
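# Re-exporting the split modules here keeps existing imports such as
# `from nndet.core.boxes.matcher import ATSSMatcher` working unchanged.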
"""
Modifications licensed under
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Parts of this code are adapted from mmdetection and thus licensed under
Copyright 2018-2023 OpenMMLab.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence, Callable, Tuple
import torch
from torch import Tensor
from loguru import logger
from nndet.core.boxes.ops import box_iou, box_center_dist, center_in_boxes
from nndet.core.boxes.matcher.base import Matcher
INF = 100  # not really infinity, but sufficient here
class ATSSMatcher(Matcher):
def __init__(self,
num_candidates: int,
similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
center_in_gt: bool = True,
):
"""
Compute matching based on ATSS
https://arxiv.org/abs/1912.02424
`Bridging the Gap Between Anchor-based and Anchor-free Detection
via Adaptive Training Sample Selection`
Args:
num_candidates: number of positions to select candidates from
similarity_fn: function for similarity computation between
boxes and anchors
            center_in_gt: If disabled, matched anchor center points do not need
                to lie within the ground truth box.
"""
super().__init__(similarity_fn=similarity_fn)
self.num_candidates = num_candidates
self.min_dist = 0.01
self.center_in_gt = center_in_gt
logger.info(f"Running ATSS Matching with num_candidates={self.num_candidates} "
f"and center_in_gt {self.center_in_gt}.")
def compute_matches(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        Compute matches according to ATSS for a single image
Adapted from
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
            Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
num_gt = boxes.shape[0]
num_anchors = anchors.shape[0]
distances, _, anchors_center = box_center_dist(boxes, anchors) # num_boxes x anchors
# select candidates based on center distance
candidate_idx = []
start_idx = 0
for level, apl in enumerate(num_anchors_per_level):
end_idx = start_idx + apl
selectable_k = min(self.num_candidates * num_anchors_per_loc, apl)
_, idx = distances[:, start_idx: end_idx].topk(selectable_k, dim=1, largest=False)
# idx shape [num_boxes x selectable_k]
candidate_idx.append(idx + start_idx)
start_idx = end_idx
# [num_boxes x num_candidates] (index of candidate anchors)
candidate_idx = torch.cat(candidate_idx, dim=1)
match_quality_matrix = self.similarity_fn(boxes, anchors) # [num_boxes x anchors]
candidate_overlaps = match_quality_matrix.gather(1, candidate_idx) # [num_boxes, n_candidates]
# compute adaptive iou threshold
overlaps_mean_per_gt = candidate_overlaps.mean(dim=1) # [num_boxes]
overlaps_std_per_gt = candidate_overlaps.std(dim=1) # [num_boxes]
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt # [num_boxes]
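        # per the ATSS paper, candidates with IoU >= mean + std (per ground truth box) become positives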
is_pos = candidate_overlaps >= overlaps_thr_per_gt[:, None] # [num_boxes x n_candidates]
if self.center_in_gt: # can discard all candidates in case of very small objects :/
# center point of selected anchors needs to lie within the ground truth
boxes_idx = torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None]\
.expand_as(candidate_idx).contiguous() # [num_boxes x n_candidates]
is_in_gt = center_in_boxes(
anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist)
is_pos = is_pos & is_in_gt.view_as(is_pos) # [num_boxes x n_candidates]
        # if an anchor is assigned to multiple boxes, keep the box with the highest IoU
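        # (the per-row offset below turns candidate_idx into indices into the
        # flattened [num_gt * num_anchors] similarity matrix)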
for ng in range(num_gt):
candidate_idx[ng, :] += ng * num_anchors
overlaps_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
index = candidate_idx.view(-1)[is_pos.view(-1)]
overlaps_inf[index] = match_quality_matrix.view(-1)[index]
overlaps_inf = overlaps_inf.view_as(match_quality_matrix)
matched_vals, matches = overlaps_inf.max(dim=0)
matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
# print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {overlaps_thr_per_gt}")
return match_quality_matrix, matches
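
# A minimal usage sketch, assuming 2D boxes in (x1, y1, x2, y2) format on a
# single feature pyramid level; the anchor layout and values below are
# illustrative assumptions, not part of the module.
if __name__ == "__main__":
    matcher = ATSSMatcher(num_candidates=4, center_in_gt=True)
    gt_boxes = torch.tensor([[10., 10., 50., 50.]])      # [N=1, 4] ground truth box
    corners = torch.rand(100, 2) * 60                    # random top-left corners
    anchors = torch.cat([corners, corners + 10], dim=1)  # [M=100, 4] valid boxes (x2 > x1, y2 > y1)
    similarity, matches = matcher(
        gt_boxes,
        anchors,
        num_anchors_per_level=[100],  # all anchors on one level (assumption)
        num_anchors_per_loc=1,        # one anchor per position (assumption)
    )
    # matches[i] >= 0 is the index of the matched box;
    # BELOW_LOW_THRESHOLD (-1) marks background anchors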
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Sequence, Callable, Tuple, TypeVar
from abc import ABC
import torch
from torch import Tensor
from nndet.core.boxes.ops import box_iou
class Matcher(ABC):
BELOW_LOW_THRESHOLD: int = -1
BETWEEN_THRESHOLDS: int = -2
def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):
"""
Matches boxes and anchors to each other
Args:
similarity_fn: function for similarity computation between
boxes and anchors
"""
self.similarity_fn = similarity_fn
def __call__(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute matches for a single image
Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
            Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
if boxes.numel() == 0:
# no ground truth
num_anchors = anchors.shape[0]
match_quality_matrix = torch.tensor([]).to(anchors)
matches = torch.empty(num_anchors, dtype=torch.int64).fill_(self.BELOW_LOW_THRESHOLD)
return match_quality_matrix, matches
else:
# at least one ground truth
return self.compute_matches(
boxes=boxes, anchors=anchors,
num_anchors_per_level=num_anchors_per_level,
num_anchors_per_loc=num_anchors_per_loc,
)
def compute_matches(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute matches
Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
            Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
raise NotImplementedError
MatcherType = TypeVar('MatcherType', bound=Matcher)
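
# A minimal sketch of a custom matcher built on the interface above; the class
# name and the 0.5 cutoff are illustrative assumptions, not part of nndet.
class GreedyIoUMatcher(Matcher):
    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int,
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
        # subclasses only implement this hook; the empty-ground-truth case is
        # already handled by Matcher.__call__
        match_quality_matrix = self.similarity_fn(boxes, anchors)   # [N, M]
        matched_vals, matches = match_quality_matrix.max(dim=0)    # best box per anchor
        matches[matched_vals < 0.5] = self.BELOW_LOW_THRESHOLD     # below cutoff -> background
        return match_quality_matrix, matches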
"""
Modifications licensed under
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Parts of this code are from torchvision and thus licensed under
BSD 3-Clause License
...
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from typing import Sequence, Callable, Tuple, TypeVar
from abc import ABC
import torch
from torch import Tensor
from loguru import logger
from nndet.core.boxes.ops import box_iou, box_center_dist, center_in_boxes
INF = 100  # not really infinity, but sufficient here
class Matcher(ABC):
BELOW_LOW_THRESHOLD: int = -1
BETWEEN_THRESHOLDS: int = -2
def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):
"""
Matches boxes and anchors to each other
Args:
similarity_fn: function for similarity computation between
boxes and anchors
"""
self.similarity_fn = similarity_fn
def __call__(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute matches for a single image
Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
            Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
if boxes.numel() == 0:
# no ground truth
num_anchors = anchors.shape[0]
match_quality_matrix = torch.tensor([]).to(anchors)
matches = torch.empty(num_anchors, dtype=torch.int64).fill_(self.BELOW_LOW_THRESHOLD)
return match_quality_matrix, matches
else:
# at least one ground truth
return self.compute_matches(
boxes=boxes, anchors=anchors,
num_anchors_per_level=num_anchors_per_level,
num_anchors_per_loc=num_anchors_per_loc,
)
def compute_matches(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute matches
Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
            Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
raise NotImplementedError
from nndet.core.boxes.ops import box_iou
from nndet.core.boxes.matcher.base import Matcher
class IoUMatcher(Matcher):
    ...
logger.info(f"Inbetween IoU ranging from {match_bet_min} to {match_bet_max}")
logger.info(f"Max background IoU: {matched_vals[below_low_threshold].max()}")
logger.info("#################################")
class ATSSMatcher(Matcher):
def __init__(self,
num_candidates: int,
similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
center_in_gt: bool = True,
):
"""
Compute matching based on ATSS
https://arxiv.org/abs/1912.02424
`Bridging the Gap Between Anchor-based and Anchor-free Detection
via Adaptive Training Sample Selection`
Args:
num_candidates: number of positions to select candidates from
similarity_fn: function for similarity computation between
boxes and anchors
            center_in_gt: If disabled, matched anchor center points do not need
                to lie within the ground truth box.
"""
super().__init__(similarity_fn=similarity_fn)
self.num_candidates = num_candidates
self.min_dist = 0.01
self.center_in_gt = center_in_gt
logger.info(f"Running ATSS Matching with num_candidates={self.num_candidates} "
f"and center_in_gt {self.center_in_gt}.")
def compute_matches(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        Compute matches according to ATSS for a single image
Adapted from
        (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss/loss.py#L180-L184)
Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
            Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
num_gt = boxes.shape[0]
num_anchors = anchors.shape[0]
distances, boxes_center, anchors_center = box_center_dist(boxes, anchors) # num_boxes x anchors
# select candidates based on center distance
candidate_idx = []
start_idx = 0
for level, apl in enumerate(num_anchors_per_level):
end_idx = start_idx + apl
topk = min(self.num_candidates * num_anchors_per_loc, apl)
_, idx = distances[:, start_idx: end_idx].topk(topk, dim=1, largest=False)
# idx shape [num_boxes x topk]
candidate_idx.append(idx + start_idx)
start_idx = end_idx
# [num_boxes x num_candidates] (index of candidate anchors)
candidate_idx = torch.cat(candidate_idx, dim=1)
match_quality_matrix = self.similarity_fn(boxes, anchors) # [num_boxes x anchors]
candidate_ious = match_quality_matrix.gather(1, candidate_idx) # [num_boxes, n_candidates]
# compute adaptive iou threshold
iou_mean_per_gt = candidate_ious.mean(dim=1) # [num_boxes]
iou_std_per_gt = candidate_ious.std(dim=1) # [num_boxes]
iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt # [num_boxes]
is_pos = candidate_ious >= iou_thresh_per_gt[:, None] # [num_boxes x n_candidates]
if self.center_in_gt: # can discard all candidates in case of very small objects :/
# center point of selected anchors needs to lie within the ground truth
boxes_idx = torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None]\
.expand_as(candidate_idx).contiguous() # [num_boxes x n_candidates]
is_in_gt = center_in_boxes(
anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist)
is_pos = is_pos & is_in_gt.view_as(is_pos) # [num_boxes x n_candidates]
        # if an anchor is assigned to multiple boxes, keep the box with the highest IoU
# TODO: think about a better way to do this
for ng in range(num_gt):
candidate_idx[ng, :] += ng * num_anchors
ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
index = candidate_idx.view(-1)[is_pos.view(-1)]
ious_inf[index] = match_quality_matrix.view(-1)[index]
ious_inf = ious_inf.view_as(match_quality_matrix)
matched_vals, matches = ious_inf.max(dim=0)
matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
# print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {iou_thresh_per_gt}")
return match_quality_matrix, matches
MatcherType = TypeVar('MatcherType', bound=Matcher)