import torch
from torch.jit.annotations import Tuple
from torch import Tensor
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
import torchvision


def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than iou_threshold with another (higher scoring)
    box.

    If multiple boxes have the exact same score and satisfy the IoU
    criterion with respect to a reference box, the selected box is
    not guaranteed to be the same between CPU and GPU. This is similar
    to the behavior of argsort in PyTorch when repeated values are present.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes to perform NMS on. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    iou_threshold : float
        discards all overlapping
        boxes with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices
        of the elements that have been kept
        by NMS, sorted in decreasing order of scores
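
    Examples
    --------
    A minimal usage sketch (illustrative values; needs the compiled
    torchvision ops):

    >>> boxes = torch.tensor([[0., 0., 10., 10.],
    ...                       [1., 1., 11., 11.],
    ...                       [20., 20., 30., 30.]])
    >>> scores = torch.tensor([0.9, 0.8, 0.7])
    >>> nms(boxes, scores, iou_threshold=0.5)  # box 1 overlaps box 0 with IoU ~0.68
    tensor([0, 2])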
    """
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)


@torch.jit._script_if_tracing
def batched_nms(
    boxes: Tensor,
    scores: Tensor,
    idxs: Tensor,
    iou_threshold: float,
) -> Tensor:
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes where NMS will be performed. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    idxs : Tensor[N]
        indices of the categories for each one of the boxes.
    iou_threshold : float
        discards all overlapping boxes
        with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of
        the elements that have been kept by NMS, sorted
        in decreasing order of scores
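
    Examples
    --------
    A minimal sketch (illustrative values); the two boxes coincide but
    carry different category indices, so neither suppresses the other:

    >>> boxes = torch.tensor([[0., 0., 10., 10.],
    ...                       [0., 0., 10., 10.]])
    >>> scores = torch.tensor([0.9, 0.8])
    >>> idxs = torch.tensor([0, 1])
    >>> batched_nms(boxes, scores, idxs, iou_threshold=0.5)
    tensor([0, 1])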
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap
    else:
        max_coordinate = boxes.max()
        offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
        boxes_for_nms = boxes + offsets[:, None]
        keep = nms(boxes_for_nms, scores, iou_threshold)
        return keep


def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Remove boxes which contain at least one side smaller than min_size.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        min_size (float): minimum size

    Returns:
        keep (Tensor[K]): indices of the boxes that have both sides
            larger than min_size
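
    Example:
        A minimal sketch (illustrative values):

        >>> boxes = torch.tensor([[0., 0., 5., 5.], [0., 0., 1., 1.]])
        >>> remove_small_boxes(boxes, min_size=2.0)  # only box 0 has both sides >= 2
        tensor([0])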
    """
    ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
    keep = (ws >= min_size) & (hs >= min_size)
    keep = torch.where(keep)[0]
    return keep


def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor:
    """
    Clip boxes so that they lie inside an image of size `size`.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image

    Returns:
        clipped_boxes (Tensor[N, 4])
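
    Example:
        A minimal sketch (illustrative values):

        >>> boxes = torch.tensor([[-5., -5., 15., 15.]])
        >>> clip_boxes_to_image(boxes, size=(10, 12))  # height 10, width 12
        tensor([[ 0.,  0., 12., 10.]])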
    """
    dim = boxes.dim()
    boxes_x = boxes[..., 0::2]
    boxes_y = boxes[..., 1::2]
    height, width = size

    if torchvision._is_tracing():
        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
    else:
        boxes_x = boxes_x.clamp(min=0, max=width)
        boxes_y = boxes_y.clamp(min=0, max=height)

    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
    return clipped_boxes.reshape(boxes.shape)


def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
    """
    Converts boxes from given in_fmt to out_fmt.
    Supported in_fmt and out_fmt are:

    'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right.

    'xywh' : boxes are represented via corner, width and height, x1, y1 being top left, w, h being width and height.

    'cxcywh' : boxes are represented via centre, width and height, cx, cy being center of box, w, h
    being width and height.

    Arguments:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh'].
        out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh']

    Returns:
        boxes (Tensor[N, 4]): Boxes in the converted format.
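
    Example:
        A minimal sketch (illustrative values):

        >>> boxes = torch.tensor([[10., 10., 20., 30.]])
        >>> box_convert(boxes, in_fmt='xyxy', out_fmt='cxcywh')
        tensor([[15., 20., 10., 20.]])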
    """
    allowed_fmts = ("xyxy", "xywh", "cxcywh")
    assert in_fmt in allowed_fmts
    assert out_fmt in allowed_fmts

    if in_fmt == out_fmt:
        boxes_converted = boxes.clone()
        return boxes_converted

    if in_fmt != 'xyxy' and out_fmt != 'xyxy':
        # neither format is xyxy: convert in_fmt to xyxy as an
        # intermediate format, then convert from xyxy to out_fmt
        if in_fmt == "xywh":
            boxes_xyxy = _box_xywh_to_xyxy(boxes)
            if out_fmt == "cxcywh":
                boxes_converted = _box_xyxy_to_cxcywh(boxes_xyxy)

        elif in_fmt == "cxcywh":
            boxes_xyxy = _box_cxcywh_to_xyxy(boxes)
            if out_fmt == "xywh":
                boxes_converted = _box_xyxy_to_xywh(boxes_xyxy)
    else:
        if in_fmt == "xyxy":
            if out_fmt == "xywh":
                boxes_converted = _box_xyxy_to_xywh(boxes)
            elif out_fmt == "cxcywh":
                boxes_converted = _box_xyxy_to_cxcywh(boxes)
        elif out_fmt == "xyxy":
            if in_fmt == "xywh":
                boxes_converted = _box_xywh_to_xyxy(boxes)
            elif in_fmt == "cxcywh":
                boxes_converted = _box_cxcywh_to_xyxy(boxes)

    return boxes_converted


def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.

    Arguments:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format

    Returns:
        area (Tensor[N]): area for each box
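
    Example:
        A minimal sketch (illustrative values):

        >>> box_area(torch.tensor([[0., 0., 4., 5.]]))  # width 4, height 5
        tensor([20.])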
    """
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
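
    Example:
        A minimal sketch (illustrative values):

        >>> a = torch.tensor([[0., 0., 10., 10.]])
        >>> b = torch.tensor([[5., 5., 15., 15.]])
        >>> box_iou(a, b)  # intersection 25, union 100 + 100 - 25 = 175
        tensor([[0.1429]])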
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    iou = inter / (area1[:, None] + area2 - inter)
    return iou


# Implementation adapted from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
    """
    Return generalized intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        generalized_iou (Tensor[N, M]): the NxM matrix containing the pairwise generalized_IoU values
        for every element in boxes1 and boxes2
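
    Example:
        A minimal sketch (illustrative values):

        >>> a = torch.tensor([[0., 0., 10., 10.]])
        >>> b = torch.tensor([[5., 5., 15., 15.]])
        >>> generalized_box_iou(a, b)  # IoU 25/175 minus (225 - 175)/225
        tensor([[-0.0794]])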
    """

    # degenerate boxes give inf / nan results,
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()

    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union

    lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    whi = (rbi - lti).clamp(min=0)  # [N,M,2]
    areai = whi[:, :, 0] * whi[:, :, 1]

    return iou - (areai - union) / areai