customize.py 1.95 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Customized functions"""

import torch
from torch.autograd import Variable, Function
Hang Zhang's avatar
Hang Zhang committed
14
15
16
17

from encoding import cpu
if torch.cuda.device_count() > 0:
    from encoding import gpu
Hang Zhang's avatar
Hang Zhang committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

# Public API: the only name exported by a star-import of this module.
__all__ = ['NonMaxSuppression']

def NonMaxSuppression(boxes, scores, threshold):
    r"""Non-Maximum Suppression

    The algorithm begins by storing the highest-scoring bounding
    box, and eliminating any box whose intersection-over-union (IoU)
    with it is too great. The procedure repeats on the surviving
    boxes, and so on until there are no boxes left.
    The stored boxes are returned.

    NB: The function returns a tuple (mask, indices), where
    indices index into the input boxes and are sorted
    according to score, from highest to lowest.
    indices[i][mask[i]] gives the indices of the surviving
    boxes from the ith batch, sorted by score.

    Args:
      - boxes :math:`(N, n_boxes, 4)`
      - scores :math:`(N, n_boxes)`
      - threshold (float): IoU above which to eliminate boxes

    Outputs:
      - mask: :math:`(N, n_boxes)`
      - indices: :math:`(N, n_boxes)`

    Examples::

    >>> boxes = torch.Tensor([[[10., 20., 20., 15.],
    >>>                       [24., 22., 50., 54.],
    >>>                       [10., 21., 20., 14.5]]])
    >>> scores = torch.abs(torch.randn([1, 3]))
    >>> mask, indices = NonMaxSuppression(boxes, scores, 0.7)
    >>> #indices are SORTED according to score.
    >>> surviving_box_indices = indices[mask]
    """
    # Dispatch on the device that holds ``boxes``. The ``gpu`` backend is
    # only imported at module load when a CUDA device is visible, so it
    # must be referenced exclusively on the CUDA branch.
    if boxes.is_cuda:
        return gpu.non_max_suppression(boxes, scores, threshold)
    else:
        return cpu.non_max_suppression(boxes, scores, threshold)