Commit 7157576b authored by rusty1s's avatar rusty1s
Browse files

bugfix, asserts

parent 6f0a5a4d
......@@ -5,7 +5,7 @@
#define THREADS 1024
template <typename scalar_t, int64_t Dim> struct Dist {};
template <typename scalar_t, int64_t Dim> struct Dist;
template <typename scalar_t> struct Dist<scalar_t, 1> {
static __device__ void
......@@ -118,6 +118,7 @@ fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
scalar_t best = -1;
ptrdiff_t best_idx = 0;
__syncthreads();
Dist<scalar_t, Dim>::compute(idx, start_idx, end_idx, out[m - 1], &best,
&best_idx, x, dist, tmp_dist, dim);
......
......@@ -27,6 +27,10 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
"""
assert x.is_cuda
assert x.dim() <= 2 and batch.dim() == 1
assert x.size(0) == batch.size(0)
x = x.view(-1, 1) if x.dim() == 1 else x
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment