nms_wrapper.py 3.58 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
import numpy as np
import torch

4
from . import nms_cpu, nms_cuda
5
from .soft_nms_cpu import soft_nms_cpu
Kai Chen's avatar
Kai Chen committed
6
7


8
def nms(dets, iou_thr, device_id=None):
    """Dispatch to either CPU or GPU NMS implementations.

    The input can be either a torch tensor or a numpy array.

    * A tensor is processed on its own device: GPU NMS is used if it is a
      CUDA tensor, otherwise CPU NMS. ``device_id`` is ignored for tensors.
    * A numpy array is processed with CPU NMS when ``device_id`` is None.
      Otherwise it is moved to ``cuda:{device_id}`` and GPU NMS is used.

    The returned type (tensor or numpy array) is always the same as the
    input type.

    Arguments:
        dets (torch.Tensor or np.ndarray): bboxes with scores, one box per
            row (the example below uses ``(n, 5)`` rows with the score last).
        iou_thr (float): IoU threshold for NMS.
        device_id (int, optional): only consulted when ``dets`` is a numpy
            array. If None, CPU NMS is used; otherwise GPU NMS runs on
            ``cuda:{device_id}``.

    Returns:
        tuple: ``(kept_dets, inds)`` -- the kept rows of ``dets`` and their
            indices, both of the same data type as the input.

    Raises:
        TypeError: if ``dets`` is neither a tensor nor a numpy array.

    Example:
        >>> dets = np.array([[49.1, 32.4, 51.0, 35.9, 0.9],
        ...                  [49.3, 32.9, 51.0, 35.3, 0.9],
        ...                  [49.2, 31.8, 51.0, 35.4, 0.5],
        ...                  [35.1, 11.5, 39.1, 15.7, 0.5],
        ...                  [35.6, 11.8, 39.3, 14.2, 0.5],
        ...                  [35.3, 11.5, 39.9, 14.5, 0.4],
        ...                  [35.2, 11.7, 39.7, 15.7, 0.3]], dtype=np.float32)
        >>> iou_thr = 0.7
        >>> kept_dets, inds = nms(dets, iou_thr)
        >>> assert len(inds) == len(kept_dets) == 3
    """
    # convert dets (tensor or numpy array) to tensor
    if isinstance(dets, torch.Tensor):
        is_numpy = False
        dets_th = dets
    elif isinstance(dets, np.ndarray):
        is_numpy = True
        device = 'cpu' if device_id is None else 'cuda:{}'.format(device_id)
        # torch.from_numpy rejects arrays with negative strides (e.g. a
        # reversed view such as ``dets[::-1]``); ascontiguousarray returns the
        # array unchanged when it is already C-contiguous.
        dets_th = torch.from_numpy(np.ascontiguousarray(dets)).to(device)
    else:
        raise TypeError(
            'dets must be either a Tensor or numpy array, but got {}'.format(
                type(dets)))

    # execute cpu or cuda nms; the extensions are skipped for empty input
    if dets_th.shape[0] == 0:
        inds = dets_th.new_zeros(0, dtype=torch.long)
    else:
        if dets_th.is_cuda:
            inds = nms_cuda.nms(dets_th, iou_thr)
        else:
            inds = nms_cpu.nms(dets_th, iou_thr)

    # numpy callers get numpy indices back, so indexing below stays in numpy
    if is_numpy:
        inds = inds.cpu().numpy()
    return dets[inds, :], inds
Kai Chen's avatar
Kai Chen committed
62
63


64
def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
    """Run Soft-NMS through the ``soft_nms_cpu`` extension.

    Arguments:
        dets (torch.Tensor or np.ndarray): bboxes with scores. Tensors are
            detached and copied to host memory as a numpy array before being
            handed to ``soft_nms_cpu``.
        iou_thr (float): IoU threshold forwarded to ``soft_nms_cpu``.
        method (str): score decay method, ``'linear'`` or ``'gaussian'``.
            It is translated to the integer code 1 or 2 expected by the
            extension.
        sigma (float): forwarded to ``soft_nms_cpu``.
        min_score (float): forwarded to ``soft_nms_cpu``.

    Returns:
        tuple: ``(new_dets, inds)``.

            * Tensor input: both values are built with ``dets.new_tensor``,
              so they share the input's device; ``inds`` is ``torch.long``.
            * Numpy input: ``new_dets`` is cast to float32 and ``inds`` to
              int64.

    Raises:
        TypeError: if ``dets`` is neither a tensor nor a numpy array (checked
            before ``method``).
        ValueError: if ``method`` is not one of the supported names.

    Example:
        >>> dets = np.array([[4., 3., 5., 3., 0.9],
        ...                  [4., 3., 5., 4., 0.9],
        ...                  [3., 1., 3., 1., 0.5],
        ...                  [3., 1., 3., 1., 0.5],
        ...                  [3., 1., 3., 1., 0.4],
        ...                  [3., 1., 3., 1., 0.0]], dtype=np.float32)
        >>> iou_thr = 0.7
        >>> new_dets, inds = soft_nms(dets, iou_thr, sigma=0.5)
        >>> assert len(inds) == len(new_dets) == 3
    """
    from_tensor = isinstance(dets, torch.Tensor)
    if from_tensor:
        # the extension consumes numpy arrays, so leave the autograd graph
        # and the device first
        host_dets = dets.detach().cpu().numpy()
    elif isinstance(dets, np.ndarray):
        host_dets = dets
    else:
        raise TypeError(
            'dets must be either a Tensor or numpy array, but got {}'.format(
                type(dets)))

    # name -> integer code understood by soft_nms_cpu
    codes = {'linear': 1, 'gaussian': 2}
    method_code = codes.get(method)
    if method_code is None:
        raise ValueError('Invalid method for SoftNMS: {}'.format(method))

    rescored, keep = soft_nms_cpu(
        host_dets,
        iou_thr,
        method=method_code,
        sigma=sigma,
        min_score=min_score)

    if not from_tensor:
        return rescored.astype(np.float32), keep.astype(np.int64)

    rescored_t = dets.new_tensor(rescored)
    keep_t = dets.new_tensor(keep, dtype=torch.long)
    return rescored_t, keep_t