Commit eb66a19d authored by rusty1s

compile

parent 6bf96692
@@ -11,7 +11,7 @@ template <typename scalar_t>
 __global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
                             const scalar_t *start, const scalar_t *end,
                             int64_t *out, int64_t D, int64_t numel) {
-  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (thread_idx < numel) {
     int64_t c = 0, k = 1;
...
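For reference, the only change in this hunk is the type of the thread index: computing it as `int64_t` instead of `size_t` keeps the `thread_idx < numel` guard a signed 64-bit comparison (`numel` is `int64_t`) rather than a signed/unsigned mix. A minimal sketch of the same guarded one-thread-per-element pattern (the kernel name and buffers are illustrative, not from this repository):

```cuda
// Illustrative only: the guarded one-thread-per-element pattern used by
// grid_kernel above. `numel` is signed 64-bit, so the index is computed as
// int64_t and the bounds check stays a purely signed comparison.
__global__ void copy_kernel(const float *in, float *out, int64_t numel) {
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (thread_idx < numel) // threads past the end of the buffer do nothing
    out[thread_idx] = in[thread_idx];
}
```

A launch of the form `copy_kernel<<<(numel + THREADS - 1) / THREADS, THREADS>>>(...)` then covers every element, with the guard absorbing the remainder of the last block.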
@@ -64,8 +64,8 @@ torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
     radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
         x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
         ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
-        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), radius,
-        max_num_neighbors, x.size(1));
+        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
+        x.size(1));
   });
   auto mask = row != -1;
...
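This hunk only renames the radius argument passed to the kernel (`radius` becomes `r`, matching the host function's parameter) and re-wraps the argument list; the launch still uses one block per example, with `ptr_x`/`ptr_y` giving each example's slice of `x`/`y`. A hedged sketch of a kernel with a matching signature, assuming a squared-distance comparison and a `-1`-filled output that the later `row != -1` mask compacts (parameter names and the neighbor-search loop are assumptions, not this repository's implementation):

```cuda
// Illustrative sketch: one block per example; each thread scans a strided
// subset of that example's y-points and records up to `max_num_neighbors`
// x-points within distance `r` of each y-point.
template <typename scalar_t>
__global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
                              const int64_t *ptr_x, const int64_t *ptr_y,
                              int64_t *row, int64_t *col, const scalar_t r,
                              const int64_t max_num_neighbors,
                              const int64_t dim) {
  const int64_t batch_idx = blockIdx.x; // one block per example
  const int64_t x_start = ptr_x[batch_idx], x_end = ptr_x[batch_idx + 1];
  const int64_t y_start = ptr_y[batch_idx], y_end = ptr_y[batch_idx + 1];

  for (int64_t n_y = y_start + threadIdx.x; n_y < y_end; n_y += blockDim.x) {
    int64_t count = 0;
    for (int64_t n_x = x_start; n_x < x_end && count < max_num_neighbors; n_x++) {
      scalar_t dist = 0;
      for (int64_t d = 0; d < dim; d++) {
        const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
        dist += diff * diff;
      }
      if (dist < r * r) {
        // Output slots were pre-filled with -1; untouched slots are removed
        // on the host by the `row != -1` mask.
        row[n_y * max_num_neighbors + count] = n_y;
        col[n_y * max_num_neighbors + count] = n_x;
        count++;
      }
    }
  }
}
```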
@@ -23,7 +23,7 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
       cur = out[i];
       row_start = rowptr[cur], row_end = rowptr[cur + 1];
 
-      out[l * numel + n] =
+      out[l * numel + thread_idx] =
           col[row_start + int64_t(rand[i] * (row_end - row_start))];
     }
   }
...
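The fix here replaces a stale loop variable `n` with `thread_idx`, so each walk writes step `l` into its own column of the `(walk_length, numel)`-shaped output buffer. A self-contained sketch of one uniform random-walk step in that layout (the names, the `rand` layout, and the isolated-node guard are assumptions for illustration, not this repository's kernel):

```cuda
// Illustrative sketch of one step of a uniform random walk over a CSR graph
// (rowptr/col). Walk `thread_idx` reads its current node from step l - 1 and
// writes its next node at step l; `rand` holds one uniform [0, 1) draw per
// walk and step.
__global__ void uniform_walk_step(const int64_t *rowptr, const int64_t *col,
                                  const float *rand, int64_t *out,
                                  int64_t numel, int64_t l) {
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (thread_idx < numel) {
    const int64_t i = (l - 1) * numel + thread_idx;
    const int64_t cur = out[i];
    const int64_t row_start = rowptr[cur], row_end = rowptr[cur + 1];
    if (row_end > row_start) // pick a uniformly random neighbor of `cur`
      out[l * numel + thread_idx] =
          col[row_start + int64_t(rand[i] * (row_end - row_start))];
    else // isolated node: stay in place
      out[l * numel + thread_idx] = cur;
  }
}
```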
@@ -23,7 +23,7 @@ except OSError as e:
         raise OSError(e)
 
 if torch.version.cuda is not None:  # pragma: no cover
-    cuda_version = torch.ops.torch_sparse.cuda_version()
+    cuda_version = torch.ops.torch_cluster.cuda_version()
 
     if cuda_version == -1:
         major = minor = 0
...
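The Python-side change points the version query at the `torch_cluster` operator namespace instead of the copy-pasted `torch_sparse` one; `torch.ops.<namespace>.<op>` only resolves if the C++ extension registered the operator under that namespace. A rough sketch of how such a `cuda_version` operator is typically registered (the file layout and `WITH_CUDA` macro are assumptions, not necessarily this repository's exact code):

```cpp
// Illustrative sketch: expose the CUDA toolkit version the extension was
// compiled against as torch.ops.torch_cluster.cuda_version(), returning -1
// for CPU-only builds so the Python side can detect a mismatch.
#include <torch/script.h>

#ifdef WITH_CUDA
#include <cuda.h>
#endif

int64_t cuda_version() {
#ifdef WITH_CUDA
  return CUDA_VERSION; // e.g. 10020 for CUDA 10.2
#else
  return -1;
#endif
}

static auto registry =
    torch::RegisterOperators().op("torch_cluster::cuda_version", &cuda_version);
```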