"examples/sampling/vscode:/vscode.git/clone" did not exist on "80203e34cd21516d2a503c671b0708f0b343ff45"
Unverified Commit 7fc01df2 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #85 from justanhduc/master

FPS with one ratio for each point cloud
parents a5abfee8 882c8e08
...@@ -8,20 +8,20 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) { ...@@ -8,20 +8,20 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
return (x - x[idx]).pow_(2).sum(1); return (x - x[idx]).pow_(2).sum(1);
} }
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio, torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) { bool random_start) {
CHECK_CPU(src); CHECK_CPU(src);
CHECK_CPU(ptr); CHECK_CPU(ptr);
CHECK_CPU(ratio);
CHECK_INPUT(ptr.dim() == 1); CHECK_INPUT(ptr.dim() == 1);
AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input");
src = src.view({src.size(0), -1}).contiguous(); src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous(); ptr = ptr.contiguous();
auto batch_size = ptr.size(0) - 1; auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * (float)ratio; auto out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
auto out = torch::empty(out_ptr[-1].data_ptr<int64_t>()[0], ptr.options()); auto out = torch::empty(out_ptr[-1].data_ptr<int64_t>()[0], ptr.options());
......
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio, torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start); bool random_start);
...@@ -64,21 +64,21 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr, ...@@ -64,21 +64,21 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
} }
} }
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio, torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
bool random_start) { torch::Tensor ratio, bool random_start) {
CHECK_CUDA(src); CHECK_CUDA(src);
CHECK_CUDA(ptr); CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_INPUT(ptr.dim() == 1); CHECK_INPUT(ptr.dim() == 1);
AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input");
cudaSetDevice(src.get_device()); cudaSetDevice(src.get_device());
src = src.view({src.size(0), -1}).contiguous(); src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous(); ptr = ptr.contiguous();
auto batch_size = ptr.size(0) - 1; auto batch_size = ptr.numel() - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * (float)ratio; auto out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0); out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
......
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
#include <torch/extension.h> #include <torch/extension.h>
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio, torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
bool random_start); torch::Tensor ratio, bool random_start);
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
PyMODINIT_FUNC PyInit__fps(void) { return NULL; } PyMODINIT_FUNC PyInit__fps(void) { return NULL; }
#endif #endif
torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio, torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) { bool random_start) {
if (src.device().is_cuda()) { if (src.device().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
......
...@@ -2,11 +2,17 @@ from itertools import product ...@@ -2,11 +2,17 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch import Tensor
from torch_cluster import fps from torch_cluster import fps
from .utils import grad_dtypes, devices, tensor from .utils import grad_dtypes, devices, tensor
@torch.jit.script
def fps2(x: Tensor, ratio: Tensor) -> Tensor:
return fps(x, None, ratio, False)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_fps(dtype, device): def test_fps(dtype, device):
x = tensor([ x = tensor([
...@@ -21,12 +27,35 @@ def test_fps(dtype, device): ...@@ -21,12 +27,35 @@ def test_fps(dtype, device):
], dtype, device) ], dtype, device)
batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
out = fps(x, batch, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=0.5, random_start=False) out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor(0.5, device=device),
random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device),
random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=0.5, random_start=False) out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7] assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=torch.tensor(0.5, device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, ratio=torch.tensor([0.5], device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps2(x, torch.tensor([0.5], device=device))
assert out.sort()[0].tolist() == [0, 5, 6, 7]
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_random_fps(device): def test_random_fps(device):
......
from typing import Optional from typing import Optional
import torch import torch
from torch import Tensor
@torch.jit.script @torch.jit._overload # noqa
def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None, def fps(src, batch=None, ratio=None, random_start=True):
ratio: float = 0.5, random_start: bool = True) -> torch.Tensor: # type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
pass
@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True):
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
pass
def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space" Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
...@@ -17,12 +28,14 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None, ...@@ -17,12 +28,14 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
batch (LongTensor, optional): Batch vector batch (LongTensor, optional): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`) node to a specific example. (default: :obj:`None`)
ratio (float, optional): Sampling ratio. (default: :obj:`0.5`) ratio (float or Tensor, optional): Sampling ratio.
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
:rtype: :class:`LongTensor` :rtype: :class:`LongTensor`
.. code-block:: python .. code-block:: python
import torch import torch
...@@ -33,6 +46,15 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None, ...@@ -33,6 +46,15 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
index = fps(src, batch, ratio=0.5) index = fps(src, batch, ratio=0.5)
""" """
r: Optional[Tensor] = None
if ratio is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
elif isinstance(ratio, float):
r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
else:
r = ratio
assert r is not None
if batch is not None: if batch is not None:
assert src.size(0) == batch.numel() assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1 batch_size = int(batch.max()) + 1
...@@ -45,4 +67,4 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None, ...@@ -45,4 +67,4 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
else: else:
ptr = torch.tensor([0, src.size(0)], device=src.device) ptr = torch.tensor([0, src.size(0)], device=src.device)
return torch.ops.torch_cluster.fps(src, ptr, ratio, random_start) return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment