fps.py 1.36 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

if torch.cuda.is_available():
    import fps_cuda


def fps(x, batch=None, ratio=0.5, random_start=True):
    """Sample a subset of points with farthest point sampling (FPS).

    Input validation happens here; the sampling itself is delegated to the
    compiled :obj:`fps_cuda` extension, so only CUDA tensors are supported.

    Args:
        x (Tensor): D-dimensional node features. A one-dimensional tensor is
            treated as a column of scalar features.
        batch (LongTensor, optional): Vector that maps each node to a graph.
            If :obj:`None`, all node features belong to the same graph. If not
            :obj:`None`, nodes of the same graph need to have contiguous memory
            layout and :obj:`batch` needs to be ascending.
            (default: :obj:`None`)
        ratio (float, optional): Sampling ratio, strictly between 0 and 1.
            (default: :obj:`0.5`)
        random_start (bool, optional): Whether the starting node is
            sampled randomly. (default: :obj:`True`)

    Returns:
        The result of :obj:`fps_cuda.fps`, unchanged (presumably the indices
        of the sampled nodes -- confirm against the extension).

    Raises:
        AssertionError: If :obj:`x` is not a CUDA tensor, if the shapes of
            :obj:`x` and :obj:`batch` are inconsistent, or if :obj:`ratio`
            lies outside the open interval (0, 1).

    Examples::

        >>> x = torch.tensor([[-1., -1.], [-1., 1.], [1., -1.], [1., 1.]])
        >>> batch = torch.tensor([0, 0, 0, 0])
        >>> sample = fps(x.cuda(), batch.cuda(), ratio=0.5)
    """
    if batch is None:
        # A single graph: every node gets graph index 0 (same device as x).
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    # These checks guard a native kernel, so they are raised explicitly
    # instead of using ``assert`` (which ``python -O`` strips). The exception
    # type stays ``AssertionError`` for backward compatibility.
    if not x.is_cuda:
        raise AssertionError('fps requires `x` to be a CUDA tensor')
    if x.dim() > 2 or batch.dim() != 1:
        raise AssertionError(
            '`x` must have at most 2 dimensions and `batch` exactly 1')
    if x.size(0) != batch.size(0):
        raise AssertionError(
            '`x` and `batch` must contain the same number of nodes')
    if not (ratio > 0 and ratio < 1):
        raise AssertionError('`ratio` must lie strictly between 0 and 1')

    # Promote a flat feature vector to shape (N, 1).
    if x.dim() == 1:
        x = x.view(-1, 1)

    return fps_cuda.fps(x, batch, ratio, random_start)