fps.py 1.4 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch

if torch.cuda.is_available():
    import fps_cuda
rusty1s's avatar
docs  
rusty1s committed
5
    """    """
rusty1s's avatar
rusty1s committed
6
7
8


def fps(x, batch=None, ratio=0.5, random_start=True):
    """Iteratively samples the most distant point (in metric distance) with
    regard to the remaining points.

    Args:
        x (Tensor): D-dimensional point features. Must live on a CUDA device.
            A 1-dimensional tensor is treated as :obj:`N` one-dimensional
            points.
        batch (LongTensor, optional): Vector that maps each point to its
            example identifier. If :obj:`None`, all points belong to the same
            example. If not :obj:`None`, points in the same example need to
            have contiguous memory layout and :obj:`batch` needs to be
            ascending. (default: :obj:`None`)
        ratio (float, optional): Sampling ratio. Must lie in the open
            interval :math:`(0, 1)`. (default: :obj:`0.5`)
        random_start (bool, optional): Whether the starting node is
            sampled randomly. (default: :obj:`True`)

    :rtype: :class:`LongTensor`

    Raises:
        AssertionError: If :obj:`x` is not a CUDA tensor, if :obj:`x` /
            :obj:`batch` do not have the expected number of dimensions or
            number of points, or if :obj:`ratio` lies outside :math:`(0, 1)`.

    Examples::

        >>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]).cuda()
        >>> batch = torch.tensor([0, 0, 0, 0]).cuda()
        >>> sample = fps(x, batch, ratio=0.5)
    """

    if batch is None:
        # Every point belongs to a single example with identifier 0.
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    # Promote a flat vector of scalars to an (N, 1) point matrix.
    x = x.view(-1, 1) if x.dim() == 1 else x

    # Explicit raises instead of ``assert`` statements: ``python -O`` strips
    # asserts, which previously let invalid input reach the CUDA kernel (or a
    # ``None`` callable). AssertionError is kept so callers see the same type.
    if not x.is_cuda:
        raise AssertionError('fps is only implemented for CUDA tensors')
    if x.dim() != 2 or batch.dim() != 1:
        raise AssertionError(
            'expected a 2-dimensional x and a 1-dimensional batch, got '
            'x.dim()={} and batch.dim()={}'.format(x.dim(), batch.dim()))
    if x.size(0) != batch.size(0):
        raise AssertionError(
            'x and batch must contain the same number of points, got '
            '{} and {}'.format(x.size(0), batch.size(0)))
    if not 0 < ratio < 1:
        raise AssertionError(
            'ratio must lie in the open interval (0, 1), got {}'.format(ratio))

    return fps_cuda.fps(x, batch, ratio, random_start)