boxes.py 5.22 KB
Newer Older
1
import torch
eellison's avatar
eellison committed
2
3
from torch.jit.annotations import Tuple
from torch import Tensor
4
import torchvision
5
6
7


def nms(boxes, scores, iou_threshold):
    # type: (Tensor, Tensor, float) -> Tensor
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than iou_threshold with another (higher scoring)
    box.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes to perform NMS on. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    iou_threshold : float
        discards all overlapping
        boxes with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices
        of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    # The computation itself is delegated to the ``torchvision::nms``
    # operator exposed through ``torch.ops``.
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)


def batched_nms(boxes, scores, idxs, iou_threshold):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes where NMS will be performed. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    idxs : Tensor[N]
        indices of the categories for each one of the boxes.
    iou_threshold : float
        discards all overlapping boxes
        with IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of
        the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # strategy: in order to perform NMS independently per class,
    # we add an offset to all the boxes. The offset depends only on the
    # class idx, and is large enough so that boxes from different classes
    # do not overlap.
    # Scaling by the full coordinate span (max - min + 1, always >= 1)
    # rather than (max + 1) keeps classes disjoint even when some
    # coordinates are negative. With (max + 1), a negative minimum, or
    # max == -1, lets boxes of different classes overlap. The translation
    # does not change IoU between boxes of the same class.
    coordinate_span = boxes.max() - boxes.min() + 1
    offsets = idxs.to(boxes) * coordinate_span
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep


def remove_small_boxes(boxes, min_size):
    # type: (Tensor, float) -> Tensor
    """
    Remove boxes which contain at least one side smaller than min_size.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        min_size (float): minimum size

    Returns:
        keep (Tensor[K]): indices of the boxes that have both sides
            greater than or equal to min_size
    """
    ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
    # Inclusive comparison: a side exactly equal to min_size is kept.
    keep = (ws >= min_size) & (hs >= min_size)
    # nonzero() returns a [K, 1] index tensor; squeeze it to a flat [K].
    keep = keep.nonzero().squeeze(1)
    return keep


def clip_boxes_to_image(boxes, size):
    # type: (Tensor, Tuple[int, int]) -> Tensor
    """
    Clip boxes so that they lie inside an image of size `size`.

    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image

    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    dim = boxes.dim()
    # Even columns hold the x coordinates (x1, x2); odd columns hold y (y1, y2).
    boxes_x = boxes[..., 0::2]
    boxes_y = boxes[..., 1::2]
    height, width = size

    if torchvision._is_tracing():
        # NOTE(review): while tracing, the bounds are materialised as tensors
        # and applied with torch.max/torch.min instead of clamp() with Python
        # scalars. Presumably this keeps the traced graph (e.g. for export)
        # free of baked-in scalar clamp bounds; confirm.
        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
    else:
        boxes_x = boxes_x.clamp(min=0, max=width)
        boxes_y = boxes_y.clamp(min=0, max=height)

    # Stacking x and y along a new trailing axis gives [..., 2, 2] laid out as
    # (x1, y1), (x2, y2). Reshaping restores the original (x1, y1, x2, y2) layout.
    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
    return clipped_boxes.reshape(boxes.shape)


def box_area(boxes):
    """
    Compute the area of every box in a set of axis-aligned boxes.

    Arguments:
        boxes (Tensor[N, 4]): boxes whose area will be computed. They
            are expected to be in (x1, y1, x2, y2) format

    Returns:
        area (Tensor[N]): area for each box
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def box_iou(boxes1, boxes2):
    """
    Return the pairwise intersection-over-union (Jaccard index) of two box sets.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    # Per-box areas, (x2 - x1) * (y2 - y1), computed inline.
    areas1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])  # [N]
    areas2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])  # [M]

    # Broadcasting boxes1 as [N, 1, 2] against boxes2 as [M, 2] yields the
    # corners of every pairwise intersection rectangle.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Pairs that do not overlap along an axis get a negative extent there;
    # clamp it to zero so their intersection area is zero.
    extent = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    intersection = extent[:, :, 0] * extent[:, :, 1]  # [N,M]

    union = areas1[:, None] + areas2 - intersection
    return intersection / union