"vscode:/vscode.git/clone" did not exist on "17393cb79aeef5cacedbf07da5dcaad9b63367e6"
Commit 9c33077e authored by rusty1s's avatar rusty1s
Browse files

fix fps implementation

parent aff91e0e
......@@ -35,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, double ratio,
int64_t start_idx = 0;
if (random_start) {
start_idx = rand() % src.size(0);
start_idx = rand() % y.size(0);
}
out_data[out_start] = src_start + start_idx;
......
......@@ -4,7 +4,7 @@
#include "utils.cuh"
#define THREADS 1024
#define THREADS 256
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
......@@ -31,15 +31,15 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
int64_t best_idx = 0;
for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
scalar_t tmp;
scalar_t dd = (scalar_t)0.;
scalar_t tmp, dd = (scalar_t)0.;
for (int64_t d = 0; d < dim; d++) {
tmp = src[dim * old + d] - src[dim * n + d];
dd += tmp * tmp;
}
dist[n] = min(dist[n], dd);
if (dist[n] > best) {
best = dist[n];
dd = min(dist[n], dd);
dist[n] = dd;
if (dd > best) {
best = dd;
best_idx = n;
}
}
......
......@@ -63,7 +63,7 @@ tests_require = ['pytest', 'pytest-cov']
setup(
name='torch_cluster',
version='1.5.2',
version='1.5.3',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_cluster',
......
......@@ -26,3 +26,15 @@ def test_fps(dtype, device):
out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
@pytest.mark.parametrize('device', devices)
def test_random_fps(device):
N = 1024
for _ in range(5):
pos = torch.randn((2 * N, 3), device=device)
batch_1 = torch.zeros(N, dtype=torch.long, device=device)
batch_2 = torch.ones(N, dtype=torch.long, device=device)
batch = torch.cat([batch_1, batch_2])
idx = fps(pos, batch, ratio=0.5)
assert idx.min() >= 0 and idx.max() < 2 * N
......@@ -3,7 +3,7 @@ import os.path as osp
import torch
__version__ = '1.5.2'
__version__ = '1.5.3'
expected_torch_version = (1, 4)
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment