# fps.py — farthest point sampling (author: rusty1s)
import torch

if torch.cuda.is_available():
    import fps_cuda


def fps(x, batch=None, ratio=0.5, random_start=True):
    r"""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
    Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
    most distant point with regard to the remaining points.

    Only CUDA tensors are supported.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. A one-dimensional
            tensor is treated as an :math:`N \times 1` matrix. Must be a CUDA
            tensor.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. If :obj:`None`, all nodes are assigned
            to a single example. (default: :obj:`None`)
        ratio (float, optional): Sampling ratio, must lie in the open interval
            :math:`(0, 1)`. (default: :obj:`0.5`)
        random_start (bool, optional): If set to :obj:`False`, use the first
            node in :math:`\mathbf{X}` as starting node. (default: :obj:`True`)

    Raises:
        AssertionError: If :obj:`x` is not a CUDA tensor, if :obj:`x` (after
            reshaping) is not two-dimensional or :obj:`batch` is not
            one-dimensional, if their first dimensions differ, or if
            :obj:`ratio` is not in :math:`(0, 1)`.

    :rtype: :class:`LongTensor`

    .. testsetup::

        import torch
        from torch_cluster import fps

    .. testcode::

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]).cuda()
        batch = torch.tensor([0, 0, 0, 0]).cuda()
        index = fps(x, batch, ratio=0.5)
    """
    # Without an explicit batch vector every node belongs to example 0; the
    # vector is allocated on the same device as ``x``.
    if batch is None:
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    # Promote a flat vector of scalars to an (N, 1) feature matrix.
    x = x.view(-1, 1) if x.dim() == 1 else x

    # Explicit checks rather than ``assert`` statements so validation still
    # happens under ``python -O``. AssertionError is kept as the exception
    # type so callers of the previous implementation see the same error class.
    if not x.is_cuda:
        raise AssertionError('fps only supports CUDA tensors')
    if x.dim() != 2 or batch.dim() != 1:
        raise AssertionError(
            'expected x to be 1-D or 2-D and batch to be 1-D, got '
            'x.dim()={} and batch.dim()={}'.format(x.dim(), batch.dim()))
    if x.size(0) != batch.size(0):
        raise AssertionError(
            'x and batch must have the same number of nodes, got '
            '{} and {}'.format(x.size(0), batch.size(0)))
    if not 0 < ratio < 1:
        raise AssertionError(
            'ratio must lie in the open interval (0, 1), got {}'.format(ratio))

    # ``fps_cuda`` is imported at module level only when CUDA is available.
    # Receiving a CUDA tensor implies that guard passed, so the extension is
    # in scope here.
    return fps_cuda.fps(x, batch, ratio, random_start)