Commit 7157576b authored by rusty1s's avatar rusty1s
Browse files

bugfix, asserts

parent 6f0a5a4d
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#define THREADS 1024 #define THREADS 1024
template <typename scalar_t, int64_t Dim> struct Dist {}; template <typename scalar_t, int64_t Dim> struct Dist;
template <typename scalar_t> struct Dist<scalar_t, 1> { template <typename scalar_t> struct Dist<scalar_t, 1> {
static __device__ void static __device__ void
...@@ -118,6 +118,7 @@ fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg, ...@@ -118,6 +118,7 @@ fps_kernel(scalar_t *__restrict__ x, int64_t *__restrict__ cum_deg,
scalar_t best = -1; scalar_t best = -1;
ptrdiff_t best_idx = 0; ptrdiff_t best_idx = 0;
__syncthreads();
Dist<scalar_t, Dim>::compute(idx, start_idx, end_idx, out[m - 1], &best, Dist<scalar_t, Dim>::compute(idx, start_idx, end_idx, out[m - 1], &best,
&best_idx, x, dist, tmp_dist, dim); &best_idx, x, dist, tmp_dist, dim);
......
...@@ -27,6 +27,10 @@ def fps(x, batch=None, ratio=0.5, random_start=True): ...@@ -27,6 +27,10 @@ def fps(x, batch=None, ratio=0.5, random_start=True):
""" """
assert x.is_cuda assert x.is_cuda
assert x.dim() <= 2 and batch.dim() == 1
assert x.size(0) == batch.size(0)
x = x.view(-1, 1) if x.dim() == 1 else x
if batch is None: if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long) batch = x.new_zeros(x.size(0), dtype=torch.long)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment