Commit 9c33077e authored by rusty1s's avatar rusty1s
Browse files

fix fps implementation

parent aff91e0e
...@@ -35,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio, ...@@ -35,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
int64_t start_idx = 0; int64_t start_idx = 0;
if (random_start) { if (random_start) {
start_idx = rand() % src.size(0); start_idx = rand() % y.size(0);
} }
out_data[out_start] = src_start + start_idx; out_data[out_start] = src_start + start_idx;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "utils.cuh" #include "utils.cuh"
#define THREADS 1024 #define THREADS 256
template <typename scalar_t> template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr, __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
...@@ -31,15 +31,15 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr, ...@@ -31,15 +31,15 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
int64_t best_idx = 0; int64_t best_idx = 0;
for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) { for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
scalar_t tmp; scalar_t tmp, dd = (scalar_t)0.;
scalar_t dd = (scalar_t)0.;
for (int64_t d = 0; d < dim; d++) { for (int64_t d = 0; d < dim; d++) {
tmp = src[dim * old + d] - src[dim * n + d]; tmp = src[dim * old + d] - src[dim * n + d];
dd += tmp * tmp; dd += tmp * tmp;
} }
dist[n] = min(dist[n], dd); dd = min(dist[n], dd);
if (dist[n] > best) { dist[n] = dd;
best = dist[n]; if (dd > best) {
best = dd;
best_idx = n; best_idx = n;
} }
} }
......
...@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov'] ...@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov']
setup( setup(
name='torch_cluster', name='torch_cluster',
version='1.5.2', version='1.5.3',
author='Matthias Fey', author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de', author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_cluster', url='https://github.com/rusty1s/pytorch_cluster',
......
...@@ -26,3 +26,15 @@ def test_fps(dtype, device): ...@@ -26,3 +26,15 @@ def test_fps(dtype, device):
out = fps(x, ratio=0.5, random_start=False) out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7] assert out.sort()[0].tolist() == [0, 5, 6, 7]
@pytest.mark.parametrize('device', devices)
def test_random_fps(device):
N = 1024
for _ in range(5):
pos = torch.randn((2 * N, 3), device=device)
batch_1 = torch.zeros(N, dtype=torch.long, device=device)
batch_2 = torch.ones(N, dtype=torch.long, device=device)
batch = torch.cat([batch_1, batch_2])
idx = fps(pos, batch, ratio=0.5)
assert idx.min() >= 0 and idx.max() < 2 * N
...@@ -3,7 +3,7 @@ import os.path as osp ...@@ -3,7 +3,7 @@ import os.path as osp
import torch import torch
__version__ = '1.5.2' __version__ = '1.5.3'
expected_torch_version = (1, 4) expected_torch_version = (1, 4)
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment