Commit b0f9f81b authored by rusty1s's avatar rusty1s
Browse files

fps cpu version

parent 0a038334
...@@ -169,7 +169,7 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) { ...@@ -169,7 +169,7 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
auto deg = degree(batch, batch_size); auto deg = degree(batch, batch_size);
auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0); auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
auto k = (deg.toType(at::kFloat) * ratio).round().toType(at::kLong); auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0); auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);
at::Tensor start; at::Tensor start;
......
...@@ -4,12 +4,9 @@ import pytest ...@@ -4,12 +4,9 @@ import pytest
import torch import torch
from torch_cluster import fps from torch_cluster import fps
from .utils import tensor, grad_dtypes from .utils import grad_dtypes, devices, tensor
devices = [torch.device('cuda')]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_fps(dtype, device): def test_fps(dtype, device):
x = tensor([ x = tensor([
...@@ -26,25 +23,3 @@ def test_fps(dtype, device): ...@@ -26,25 +23,3 @@ def test_fps(dtype, device):
out = fps(x, batch, ratio=0.5, random_start=False) out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6] assert out.tolist() == [0, 2, 4, 6]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_fps_speed(dtype, device):
return
batch_size, num_nodes = 100, 10000
x = torch.randn((batch_size * num_nodes, 3), dtype=dtype, device=device)
batch = torch.arange(batch_size, dtype=torch.long, device=device)
batch = batch.view(-1, 1).repeat(1, num_nodes).view(-1)
out = fps(x, batch, ratio=0.5, random_start=True)
assert out.size(0) == batch_size * num_nodes * 0.5
assert out.min().item() >= 0 and out.max().item() < batch_size * num_nodes
batch_size, num_nodes, dim = 100, 300, 128
x = torch.randn((batch_size * num_nodes, dim), dtype=dtype, device=device)
batch = torch.arange(batch_size, dtype=torch.long, device=device)
batch = batch.view(-1, 1).repeat(1, num_nodes).view(-1)
out = fps(x, batch, ratio=0.5, random_start=True)
assert out.size(0) == batch_size * num_nodes * 0.5
assert out.min().item() >= 0 and out.max().item() < batch_size * num_nodes
import torch import torch
import fps_cpu
if torch.cuda.is_available(): if torch.cuda.is_available():
import fps_cuda import fps_cuda
...@@ -39,12 +40,11 @@ def fps(x, batch=None, ratio=0.5, random_start=True): ...@@ -39,12 +40,11 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
x = x.view(-1, 1) if x.dim() == 1 else x x = x.view(-1, 1) if x.dim() == 1 else x
assert x.is_cuda
assert x.dim() == 2 and batch.dim() == 1 assert x.dim() == 2 and batch.dim() == 1
assert x.size(0) == batch.size(0) assert x.size(0) == batch.size(0)
assert ratio > 0 and ratio < 1 assert ratio > 0 and ratio < 1
op = fps_cuda.fps if x.is_cuda else None if x.is_cuda:
out = op(x, batch, ratio, random_start) return fps_cuda.fps(x, batch, ratio, random_start)
else:
return out return fps_cpu.fps(x, batch, ratio, random_start)
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment