from typing import Optional

import torch


@torch.jit.script
def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
        ratio: float = 0.5, random_start: bool = True) -> torch.Tensor:
    r"""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
    Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
    most distant point with regard to the rest points.

    When :obj:`batch` is given, it is converted into a pointer vector of
    cumulative per-example point counts (length :math:`B + 1`, starting at
    zero) before the call is forwarded to the compiled
    :obj:`torch_cluster.fps` operator. When :obj:`batch` is :obj:`None`,
    :obj:`None` is forwarded as the pointer vector.

    Args:
        src (Tensor): Point feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must have as many entries as
            :obj:`src` has rows. (default: :obj:`None`)
        ratio (float, optional): Sampling ratio. (default: :obj:`0.5`)
        random_start (bool, optional): If set to :obj:`False`, use the first
            node in :math:`\mathbf{X}` as starting node.
            (default: :obj:`True`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import fps

        src = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        index = fps(src, batch, ratio=0.5)
    """
    ptr: Optional[torch.Tensor] = None
    if batch is not None:
        assert src.size(0) == batch.size(0)

        # Number of examples is inferred from the largest batch index.
        batch_size = int(batch.max()) + 1

        # Count how many points belong to each example.
        deg = src.new_zeros(batch_size, dtype=torch.long)
        deg.scatter_add_(0, batch, torch.ones_like(batch))

        # ptr[0] stays 0; ptr[1:] receives the running total of the counts,
        # so ptr[i]:ptr[i + 1] delimits the points of example i.
        # NOTE(review): this only describes contiguous segments, so `batch`
        # is presumably expected to be sorted -- confirm against the operator.
        ptr = src.new_zeros(batch_size + 1, dtype=torch.long)
        # Use the functional form: `torch.cumsum` documents the `out=`
        # keyword, so the call does not rely on how the method form is
        # resolved by the TorchScript compiler.
        torch.cumsum(deg, 0, out=ptr[1:])

    return torch.ops.torch_cluster.fps(src, ptr, ratio, random_start)