Unverified Commit 7fc01df2 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #85 from justanhduc/master

FPS with one ratio for each point cloud
parents a5abfee8 882c8e08
......@@ -8,20 +8,20 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
return (x - x[idx]).pow_(2).sum(1);
}
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) {
CHECK_CPU(src);
CHECK_CPU(ptr);
CHECK_CPU(ratio);
CHECK_INPUT(ptr.dim() == 1);
AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input");
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.size(0) - 1;
auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
auto out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
auto out = torch::empty(out_ptr[-1].data_ptr<int64_t>()[0], ptr.options());
......
......@@ -2,5 +2,5 @@
#include <torch/extension.h>
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start);
......@@ -64,21 +64,21 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
}
}
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start) {
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start) {
CHECK_CUDA(src);
CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_INPUT(ptr.dim() == 1);
AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input");
cudaSetDevice(src.get_device());
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.size(0) - 1;
auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
auto out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
......
......@@ -2,5 +2,5 @@
#include <torch/extension.h>
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start);
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start);
......@@ -11,7 +11,7 @@
PyMODINIT_FUNC PyInit__fps(void) { return NULL; }
#endif
torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
......
......@@ -2,11 +2,17 @@ from itertools import product
import pytest
import torch
from torch import Tensor
from torch_cluster import fps
from .utils import grad_dtypes, devices, tensor
@torch.jit.script
def fps2(x: Tensor, ratio: Tensor) -> Tensor:
return fps(x, None, ratio, False)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_fps(dtype, device):
x = tensor([
......@@ -21,12 +27,35 @@ def test_fps(dtype, device):
], dtype, device)
batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
out = fps(x, batch, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor(0.5, device=device),
random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device),
random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=torch.tensor(0.5, device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=torch.tensor([0.5], device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps2(x, torch.tensor([0.5], device=device))
assert out.sort()[0].tolist() == [0, 5, 6, 7]
@pytest.mark.parametrize('device', devices)
def test_random_fps(device):
......
from typing import Optional
import torch
from torch import Tensor
@torch.jit.script
def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
ratio: float = 0.5, random_start: bool = True) -> torch.Tensor:
@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
pass
@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
......@@ -17,12 +28,14 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
ratio (float, optional): Sampling ratio. (default: :obj:`0.5`)
ratio (float or Tensor, optional): Sampling ratio.
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
:rtype: :class:`LongTensor`
.. code-block:: python
import torch
......@@ -33,6 +46,15 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
index = fps(src, batch, ratio=0.5)
"""
r: Optional[Tensor] = None
if ratio is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
elif isinstance(ratio, float):
r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
else:
r = ratio
assert r is not None
if batch is not None:
assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1
......@@ -45,4 +67,4 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
else:
ptr = torch.tensor([0, src.size(0)], device=src.device)
return torch.ops.torch_cluster.fps(src, ptr, ratio, random_start)
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment