boxes.py 5.63 KB
Newer Older
1
import torch
eellison's avatar
eellison committed
2
3
from torch.jit.annotations import Tuple
from torch import Tensor
4
import torchvision
5
6


7
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than iou_threshold with another (higher scoring)
    box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is
    not guaranteed to be the same between CPU and GPU. This is similar
    to the behavior of argsort in PyTorch when repeated values are present.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes to perform NMS on. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    iou_threshold : float
        discards all overlapping
        boxes with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices
        of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    # The actual suppression is done by the compiled operator registered
    # under the torch.ops.torchvision namespace; this is a thin typed wrapper.
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
40
41


42
@torch.jit._script_if_tracing
def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes where NMS will be performed. They
        are expected to be in (x1, y1, x2, y2) format.
        Coordinates are not required to be non-negative.
    scores : Tensor[N]
        scores for each one of the boxes
    idxs : Tensor[N]
        indices of the categories for each one of the boxes.
    iou_threshold : float
        discards all overlapping boxes
        with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of
        the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)

    # Strategy: to perform NMS independently per class, shift every box by
    # an offset that depends only on its class idx. The step between
    # consecutive classes is (coordinate range + 1), so the shifted boxes of
    # different classes land in disjoint regions and cannot overlap.
    # Using (max - min) rather than max alone keeps the classes separated
    # even when some coordinates are negative; with max alone, boxes lying
    # at negative coordinates could still overlap across classes.
    # NOTE(review): separation assumes idxs are integer-valued category ids.
    max_coordinate = boxes.max()
    min_coordinate = boxes.min()
    offsets = idxs.to(boxes) * (max_coordinate - min_coordinate + torch.tensor(1).to(boxes))
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep
87
88


89
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Remove boxes which contain at least one side smaller than min_size.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        min_size (float): minimum size

    Returns:
        keep (Tensor[K]): int64 indices of the boxes whose width and height
            are both greater than or equal to min_size
    """
    ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
    # A box is kept when both sides reach min_size (the comparison is
    # inclusive, so a side exactly equal to min_size survives).
    keep = (ws >= min_size) & (hs >= min_size)
    # Single-argument torch.where returns one index tensor per dimension;
    # the mask is 1-D, so the first entry holds the kept indices. This
    # avoids the deprecated bare nonzero() overload and yields the same result.
    keep = torch.where(keep)[0]
    return keep


107
def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
    """
    Clamp box coordinates to the extent of an image of size `size`.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image

    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    ndim = boxes.dim()
    height, width = size
    # Even positions of the last axis are x coordinates, odd ones are y.
    xs = boxes[..., 0::2]
    ys = boxes[..., 1::2]

    if not torchvision._is_tracing():
        xs = xs.clamp(min=0, max=width)
        ys = ys.clamp(min=0, max=height)
    else:
        # While tracing, the bounds are expressed as tensors and applied via
        # torch.max / torch.min instead of clamp — presumably so the trace
        # records them as tensor ops; TODO confirm the exact motivation.
        zero = torch.tensor(0, dtype=boxes.dtype, device=boxes.device)
        xs = torch.max(xs, zero)
        xs = torch.min(xs, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        ys = torch.max(ys, zero)
        ys = torch.min(ys, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))

    # Stacking on a new trailing axis interleaves x and y again, and the
    # reshape restores the original (x1, y1, x2, y2) layout.
    return torch.stack((xs, ys), dim=ndim).reshape(boxes.shape)


136
def box_area(boxes: Tensor) -> Tensor:
    """
    Compute the area of every box in a set of bounding boxes given by their
    (x1, y1, x2, y2) coordinates.

    Arguments:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format

    Returns:
        area (Tensor[N]): (x2 - x1) * (y2 - y1) for each box
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
153
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two
    sets of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    # Per-box areas, (x2 - x1) * (y2 - y1).
    areas1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # [N]
    areas2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # [M]

    # Inserting an axis into boxes1 broadcasts it against every box in boxes2.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # A negative extent means the pair does not overlap on that axis.
    overlap_wh = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    intersection = overlap_wh[:, :, 0] * overlap_wh[:, :, 1]  # [N,M]

    union = areas1[:, None] + areas2 - intersection  # [N,M]
    return intersection / union