fps.py 1.92 KB
Newer Older
rusty1s's avatar
update  
rusty1s committed
1
from typing import Optional
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
update  
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
5


rusty1s's avatar
update  
rusty1s committed
6
7
@torch.jit.script
def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
        ratio: torch.Tensor = torch.tensor(0.5),
        random_start: bool = True) -> torch.Tensor:
    r"""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
    Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
    most distant point with regard to the rest points.

    Args:
        src (Tensor): Point feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. (default: :obj:`None`)
        ratio (Tensor, optional): Sampling ratio. Either a 0-dimensional
            tensor shared by all examples, or a 1-dimensional tensor holding
            one ratio per example (of size :math:`B`, or of size :math:`1`
            when :obj:`batch` is :obj:`None`).
            (default: :obj:`torch.tensor(0.5)`)
        random_start (bool, optional): If set to :obj:`False`, use the first
            node in :math:`\mathbf{X}` as starting node. (default: :obj:`True`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import fps

        src = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        index = fps(src, batch, ratio=torch.tensor(0.5))
    """
    # Only a shared (0-dim) or per-example (1-dim) ratio is accepted.
    assert ratio.dim() < 2, 'Invalid ratio'
    ratio = ratio.to(src.device)

    if batch is not None:
        assert src.size(0) == batch.numel()
        batch_size = int(batch.max()) + 1

        # Number of nodes assigned to each example.
        deg = src.new_zeros(batch_size, dtype=torch.long)
        deg.scatter_add_(0, batch, torch.ones_like(batch))

        # Offsets: ptr[0] = 0 and ptr[i + 1] = ptr[i] + deg[i]. These delimit
        # each example's node range only if nodes are grouped by example
        # (sorted batch) -- NOTE(review): assumed by the op, confirm.
        ptr = deg.new_zeros(batch_size + 1)
        torch.cumsum(deg, 0, out=ptr[1:])
    else:
        # Without a batch vector every node belongs to one single example.
        batch_size = 1
        ptr = torch.tensor([0, src.size(0)], device=src.device)

    # Checked after `batch_size` is known in both branches, so `batch` is
    # never dereferenced while it may still be None.
    if ratio.dim() == 1:
        assert ratio.size(0) == batch_size, \
            'Mismatched input and ratio numbers'

    return torch.ops.torch_cluster.fps(src, ptr, ratio, random_start)