nms_wrapper.py 1.87 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
7
8
import numpy as np
import torch

from .gpu_nms import gpu_nms
from .cpu_nms import cpu_nms
from .cpu_soft_nms import cpu_soft_nms


9
def nms(dets, iou_thr, device_id=None):
    """Run non-maximum suppression, picking the CPU or GPU backend.

    The GPU kernel is used when ``dets`` is a CUDA tensor (its own device
    index then overrides ``device_id``) or when ``device_id`` is given
    explicitly; otherwise the CPU kernel is used.

    Args:
        dets (torch.Tensor or np.ndarray): 2-D detections, one per row.
            Presumably laid out as ``(x1, y1, x2, y2, score)`` -- the column
            layout is defined by the compiled kernels, not by this wrapper.
        iou_thr (float): IoU threshold forwarded unchanged to the kernel.
        device_id (int, optional): GPU index to run on. ``None`` selects the
            CPU kernel unless ``dets`` is a CUDA tensor.

    Returns:
        tuple: ``(dets[inds, :], inds)``, i.e. the kept rows and their
        indices. Both have the same container type as ``dets``; the indices
        are ``torch.long`` for tensors and ``np.int64`` for arrays.

    Raises:
        TypeError: If ``dets`` is neither a Tensor nor a numpy array.
    """
    from_tensor = isinstance(dets, torch.Tensor)
    if not from_tensor and not isinstance(dets, np.ndarray):
        raise TypeError(
            'dets must be either a Tensor or numpy array, but got {}'.format(
                type(dets)))

    if from_tensor:
        # A CUDA tensor dictates the device, overriding the caller's value.
        if dets.is_cuda:
            device_id = dets.get_device()
        # Both kernels are handed a host-side numpy array.
        boxes = dets.detach().cpu().numpy()
    else:
        boxes = dets

    if boxes.shape[0] == 0:
        # Nothing to suppress; the kernels are never called on empty input.
        keep = []
    elif device_id is None:
        keep = cpu_nms(boxes, iou_thr)
    else:
        keep = gpu_nms(boxes, iou_thr, device_id=device_id)

    # Hand the indices back in the same container type (and device) as dets.
    if from_tensor:
        keep = dets.new_tensor(keep, dtype=torch.long)
    else:
        keep = np.array(keep, dtype=np.int64)
    return dets[keep, :], keep
Kai Chen's avatar
Kai Chen committed
35
36


37
def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3):
    """Run Soft-NMS through the CPU kernel on a Tensor or numpy array.

    Tensor inputs are detached and copied to host memory before the kernel
    is called, and the results are converted back with ``dets.new_tensor``.

    Args:
        dets (torch.Tensor or np.ndarray): Detections passed to
            ``cpu_soft_nms``. Presumably one ``(x1, y1, x2, y2, score)`` row
            per box -- the layout is defined by the kernel, not here.
        iou_thr (float): IoU threshold forwarded to the kernel.
        method (str): Either ``'linear'`` or ``'gaussian'``; mapped to the
            integer codes 1 and 2 that are passed to the kernel.
        sigma (float): Forwarded unchanged to the kernel.
        min_score (float): Forwarded unchanged to the kernel.

    Returns:
        tuple: ``(new_dets, inds)`` as produced by the kernel. For tensor
        input both are tensors created via ``dets.new_tensor`` (indices as
        ``torch.long``); for array input they are cast to ``np.float32`` and
        ``np.int64`` respectively.

    Raises:
        TypeError: If ``dets`` is neither a Tensor nor a numpy array.
        ValueError: If ``method`` is not a supported Soft-NMS method.
    """
    if isinstance(dets, torch.Tensor):
        from_tensor, boxes = True, dets.detach().cpu().numpy()
    elif isinstance(dets, np.ndarray):
        from_tensor, boxes = False, dets
    else:
        raise TypeError(
            'dets must be either a Tensor or numpy array, but got {}'.format(
                type(dets)))

    # The kernel identifies the re-scoring scheme by an integer code.
    method_codes = {'linear': 1, 'gaussian': 2}
    if method not in method_codes:
        raise ValueError('Invalid method for SoftNMS: {}'.format(method))

    kernel_kwargs = dict(
        method=method_codes[method], sigma=sigma, min_score=min_score)
    new_dets, inds = cpu_soft_nms(boxes, iou_thr, **kernel_kwargs)

    if not from_tensor:
        return new_dets.astype(np.float32), inds.astype(np.int64)
    # new_tensor keeps the dtype/device of the input unless overridden.
    return dets.new_tensor(new_dets), dets.new_tensor(inds, dtype=torch.long)