nms.py 3.88 KB
Newer Older
mibaumgartner's avatar
core  
mibaumgartner committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
Michael Baumgartner's avatar
Michael Baumgartner committed
18
from loguru import logger
mibaumgartner's avatar
core  
mibaumgartner committed
19
20
21
22
from torch import Tensor
from torch.cuda.amp import autocast
from torchvision.ops.boxes import nms as nms_2d

Michael Baumgartner's avatar
Michael Baumgartner committed
23
24
25
26
27
# Optional import of the `nms` implementation from `nndet._C` under the name
# `nms_gpu`. If the import fails, a warning is logged and `nms_gpu` is set to
# None so that this module can still be imported; code that uses `nms_gpu`
# has to account for the None case.
try:
    from nndet._C import nms as nms_gpu
except ImportError:
    logger.warning("nnDetection was not build with GPU support!")
    nms_gpu = None
mibaumgartner's avatar
mibaumgartner committed
28
from nndet.core.boxes.ops import box_iou
mibaumgartner's avatar
core  
mibaumgartner committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106


def nms_cpu(boxes, scores, thresh):
    """
    Performs non-maximum suppression for 3d boxes on cpu

    Greedy procedure: the remaining box with the highest score is kept and
    every remaining box whose IoU with it is larger than `thresh` is
    discarded; this is repeated until no boxes remain.

    Args:
        boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores (Tensor): score for each box [N]
        thresh (float): IoU threshold; boxes with an IoU larger than this
            value w.r.t. an already kept box are discarded

    Returns:
        keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,
            sorted in decreasing order of scores
    """
    if scores.numel() == 0:
        # nothing to suppress; skip the IoU computation entirely
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # pairwise IoU matrix, indexed below as ious[kept_box][candidate_boxes]
    ious = box_iou(boxes, boxes)
    _, order = torch.sort(scores, descending=True)

    keep = []
    while order.numel() > 0:
        top = order[0]
        keep.append(top)
        # Drop the selected box explicitly instead of relying on its self-IoU
        # exceeding `thresh`: for thresh >= 1 or a self-IoU of 0 (degenerate
        # box) the box would never be removed and the loop would not terminate.
        rest = order[1:]
        # keep only candidates which are not matched by the selected box
        order = rest[ious[top][rest] <= thresh]

    # Build the result directly as int64; a round trip through the (floating
    # point) dtype of `boxes` could round large indices.
    return torch.stack(keep).to(device=boxes.device, dtype=torch.int64)


@autocast(enabled=False)
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression

    Dispatches to the torchvision implementation for 2d boxes, to `nms_gpu`
    for 3d boxes on a cuda device and to `nms_cpu` for 3d boxes on cpu.
    Boxes and scores are converted to float32 before the selected
    implementation is called (autocast is disabled for this function).

    Args:
        boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores (Tensor): score for each box [N]
        iou_threshold (float): threshold when boxes are discarded

    Returns:
        keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,
            sorted in decreasing order of scores
    """
    boxes_f32 = boxes.float()
    scores_f32 = scores.float()

    if boxes.shape[1] == 4:
        # prefer torchvision in 2d because they have c++ cpu version
        return nms_2d(boxes_f32, scores_f32, iou_threshold)

    if not boxes.is_cuda:
        return nms_cpu(boxes_f32, scores_f32, iou_threshold)

    if nms_gpu is not None:
        return nms_gpu(boxes_f32, scores_f32, iou_threshold)

    # `nms_gpu` is None when the compiled extension could not be imported.
    # Calling it would fail with "'NoneType' object is not callable", so run
    # the cpu implementation on cpu copies and move the indices back to the
    # device of the input.
    keep = nms_cpu(boxes_f32.cpu(), scores_f32.cpu(), iou_threshold)
    return keep.to(boxes.device)


def batched_nms(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float):
    """
    Performs non-maximum suppression in a batched fashion.
    Each index value correspond to a category, and NMS
    will not be applied between elements of different categories.

    Args:
        boxes (Tensor): boxes where NMS will be performed. (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores (Tensor): scores for each one of the boxes [N]
        idxs (Tensor): indices of the categories for each one of the boxes. [N]
        iou_threshold (float):  discards all overlapping boxes with IoU > iou_threshold

    Returns:
        keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,
            sorted in decreasing order of scores
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # strategy: in order to perform NMS independently per class.
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap

    # Do the offset arithmetic in float32 (nms converts to float32 anyway):
    # adding large offsets in reduced precision would quantise the shifted
    # coordinates.
    boxes_f32 = boxes.float()

    # Use the full coordinate extent (max - min) as class stride. Using only
    # `max + 1` separates classes just for non-negative coordinates; with
    # negative coordinates boxes of neighbouring classes could still overlap.
    min_coordinate = boxes_f32.min()
    max_coordinate = boxes_f32.max()
    offsets = idxs.to(boxes_f32) * (max_coordinate - min_coordinate + 1)
    boxes_for_nms = boxes_f32 + offsets[:, None]
    return nms(boxes_for_nms, scores, iou_threshold)