Commit c11f343b authored by rusty1s

first tests

parent 7157576b
@@ -15,7 +15,7 @@ template <typename scalar_t> struct Dist<scalar_t, 1> {
     scalar_t *__restrict__ tmp_dist, size_t dim) {
   for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
-    scalar_t d = x[old * 3 + 0] - x[n * 3 + 0];
+    scalar_t d = x[old] - x[n];
     dist[n] = min(dist[n], d * d);
     if (dist[n] > *best) {
       *best = dist[n];
@@ -33,8 +33,8 @@ template <typename scalar_t> struct Dist<scalar_t, 2> {
     scalar_t *__restrict__ tmp_dist, size_t dim) {
   for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
-    scalar_t a = x[old * 3 + 0] - x[n * 3 + 0];
-    scalar_t b = x[old * 3 + 1] - x[n * 3 + 1];
+    scalar_t a = x[2 * old + 0] - x[2 * n + 0];
+    scalar_t b = x[2 * old + 1] - x[2 * n + 1];
     scalar_t d = a * a + b * b;
     dist[n] = min(dist[n], d);
     if (dist[n] > *best) {
@@ -53,9 +53,9 @@ template <typename scalar_t> struct Dist<scalar_t, 3> {
     scalar_t *__restrict__ tmp_dist, size_t dim) {
   for (ptrdiff_t n = start_idx + idx; n < end_idx; n += THREADS) {
-    scalar_t a = x[old * 3 + 0] - x[n * 3 + 0];
-    scalar_t b = x[old * 3 + 1] - x[n * 3 + 1];
-    scalar_t c = x[old * 3 + 2] - x[n * 3 + 2];
+    scalar_t a = x[3 * old + 0] - x[3 * n + 0];
+    scalar_t b = x[3 * old + 1] - x[3 * n + 1];
+    scalar_t c = x[3 * old + 2] - x[3 * n + 2];
     scalar_t d = a * a + b * b + c * c;
     dist[n] = min(dist[n], d);
     if (dist[n] > *best) {
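The kernel change above fixes the flat-index stride in the one- and two-dimensional Dist specializations: with the point cloud stored as a contiguous row-major (N, dim) tensor, coordinate k of point n sits at offset n * dim + k, so the previously hard-coded stride of 3 was only correct for Dist<scalar_t, 3>. A minimal sketch of that layout rule in plain PyTorch (illustrative only, the names are made up):

import torch

# Row-major (N, dim) storage: flattening maps x[n, k] to x_flat[n * dim + k],
# so the per-point stride must equal `dim` (1, 2 or 3), not a fixed 3.
for dim in (1, 2, 3):
    x = torch.randn(8, dim)
    x_flat = x.contiguous().view(-1)
    n, k = 5, dim - 1
    assert x_flat[n * dim + k].item() == x[n, k].item()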
@@ -2,15 +2,14 @@ from itertools import product
 import pytest
 import torch
 
-import fps_cuda
+from torch_cluster import fps
 
-from .utils import tensor
+from .utils import tensor, grad_dtypes
 
-dtypes = [torch.float]
 devices = [torch.device('cuda')]
 
 
-@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_fps(dtype, device):
     x = tensor([
         [-1, -1],
@@ -24,5 +23,26 @@ def test_fps(dtype, device):
     ], dtype, device)
     batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
 
-    out = fps_cuda.fps(x, batch, 0.5, False)
-    print(out)
+    out = fps(x, batch, ratio=0.5, random_start=False)
+    assert out.tolist() == [0, 2, 4, 6]
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_fps_speed(dtype, device):
+    return
+    batch_size, num_nodes = 100, 10000
+    x = torch.randn((batch_size * num_nodes, 3), dtype=dtype, device=device)
+    batch = torch.arange(batch_size, dtype=torch.long, device=device)
+    batch = batch.view(-1, 1).repeat(1, num_nodes).view(-1)
+
+    out = fps(x, batch, ratio=0.5, random_start=True)
+    assert out.size(0) == batch_size * num_nodes * 0.5
+    assert out.min().item() >= 0 and out.max().item() < batch_size * num_nodes
+
+    batch_size, num_nodes, dim = 100, 300, 128
+    x = torch.randn((batch_size * num_nodes, dim), dtype=dtype, device=device)
+    batch = torch.arange(batch_size, dtype=torch.long, device=device)
+    batch = batch.view(-1, 1).repeat(1, num_nodes).view(-1)
+
+    out = fps(x, batch, ratio=0.5, random_start=True)
+    assert out.size(0) == batch_size * num_nodes * 0.5
+    assert out.min().item() >= 0 and out.max().item() < batch_size * num_nodes
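For context, the expected result [0, 2, 4, 6] in test_fps follows from what farthest point sampling computes: within each example of the batch, starting from a fixed first point, the sampler repeatedly adds the point with the largest squared distance to the set selected so far until roughly ratio * N points are chosen. A rough pure-PyTorch reference of that behaviour (a sketch assuming a deterministic start at the first point of each example, not the library's CUDA implementation):

import torch

def fps_reference(x, batch, ratio=0.5):
    # Naive O(N^2) farthest point sampling, run independently per example.
    out = []
    for b in batch.unique():
        idx = (batch == b).nonzero(as_tuple=False).view(-1)
        pos = x[idx]
        num_samples = int(idx.numel() * ratio)
        selected = [0]  # deterministic start, mirroring random_start=False
        dist = ((pos - pos[0]) ** 2).sum(dim=-1)  # squared distance to selected set
        for _ in range(num_samples - 1):
            farthest = int(dist.argmax())
            selected.append(farthest)
            dist = torch.min(dist, ((pos - pos[farthest]) ** 2).sum(dim=-1))
        out.append(idx[torch.tensor(selected, device=idx.device)])
    return torch.cat(out)

In each four-point example this picks the start point and then the point farthest from it, so on data laid out like the test's it should reproduce the asserted [0, 2, 4, 6].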
@@ -4,6 +4,8 @@ from torch.testing import get_all_dtypes
 dtypes = get_all_dtypes()
 dtypes.remove(torch.half)
 
+grad_dtypes = [torch.float, torch.double]
+
 devices = [torch.device('cpu')]
 if torch.cuda.is_available():
     devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]