Unverified Commit 80b99adb authored by Matthias Fey, committed by GitHub

Merge pull request #57 from rusty1s/wheel

[WIP] Python wheels
parents 0194ebb6 bc476876
#pragma once
#include <torch/extension.h>
torch::Tensor neighbor_sampler_cpu(torch::Tensor start, torch::Tensor rowptr,
int64_t count, double factor);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include "fps_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t>
__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
const int64_t *out_ptr, const int64_t *start,
scalar_t *dist, int64_t *out, int64_t dim) {
const int64_t thread_idx = threadIdx.x;
const int64_t batch_idx = blockIdx.x;
const int64_t start_idx = ptr[batch_idx];
const int64_t end_idx = ptr[batch_idx + 1];
__shared__ scalar_t best_dist[THREADS];
__shared__ int64_t best_dist_idx[THREADS];
if (thread_idx == 0) {
out[out_ptr[batch_idx]] = start_idx + start[batch_idx];
}
for (int64_t m = out_ptr[batch_idx] + 1; m < out_ptr[batch_idx + 1]; m++) {
int64_t old = out[m - 1];
scalar_t best = (scalar_t)-1.;
int64_t best_idx = 0;
for (int64_t n = start_idx + thread_idx; n < end_idx; n += THREADS) {
scalar_t tmp;
scalar_t dd = (scalar_t)0.;
for (int64_t d = 0; d < dim; d++) {
tmp = src[dim * old + d] - src[dim * n + d];
dd += tmp * tmp;
}
dist[n] = min(dist[n], dd);
if (dist[n] > best) {
best = dist[n];
best_idx = n;
}
}
best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx;
for (int64_t i = 1; i < THREADS; i *= 2) {
__syncthreads();
if ((thread_idx + i) < THREADS &&
best_dist[thread_idx] < best_dist[thread_idx + i]) {
best_dist[thread_idx] = best_dist[thread_idx + i];
best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i];
}
}
__syncthreads();
if (thread_idx == 0) {
out[m] = best_dist_idx[0];
}
}
}
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start) {
CHECK_CUDA(src);
CHECK_CUDA(ptr);
CHECK_INPUT(ptr.dim() == 1);
AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input");
cudaSetDevice(src.get_device());
src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.size(0) - 1;
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * (float)ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);
torch::Tensor start;
if (random_start) {
start = torch::rand(batch_size, src.options());
start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
} else {
start = torch::zeros(batch_size, ptr.options());
}
auto dist = torch::full(src.size(0), 1e38, src.options());
int64_t out_size;
cudaMemcpy(&out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
cudaMemcpyDeviceToHost);
auto out = torch::empty(out_size, out_ptr.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
fps_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
src.data_ptr<scalar_t>(), ptr.data_ptr<int64_t>(),
out_ptr.data_ptr<int64_t>(), start.data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), out.data_ptr<int64_t>(), src.size(1));
});
return out;
}
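// Illustrative usage sketch (not part of the original sources; the tensors
// below are made-up example values): farthest point sampling of a single
// example with 8 three-dimensional points, keeping roughly half of them.
torch::Tensor fps_cuda_example() {
  auto src = torch::rand({8, 3}, torch::device(torch::kCUDA));
  // CSR-style batch pointer: one example covering rows [0, 8) of `src`.
  auto ptr = torch::tensor({0, 8},
                           torch::dtype(torch::kLong).device(torch::kCUDA));
  // Returns ceil(8 * 0.5) = 4 indices into `src`.
  return fps_cuda(src, ptr, /*ratio=*/0.5, /*random_start=*/true);
}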
#pragma once
#include <torch/extension.h>
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start);
#include "graclus_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLUE_P 0.53406
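// Node states stored in `out` by the kernels below:
//   -1  : unmatched node, currently colored blue
//   -2  : unmatched node, currently colored red
//   >=0 : matched; the value is the cluster id min(u, v) of the matched pair,
//         or the node's own index if no unmatched neighbor is left.
// `done_d` is reset to false by colorize_kernel while unmatched nodes remain.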
__device__ bool done_d;
__global__ void init_done_kernel() { done_d = true; }
__global__ void colorize_kernel(int64_t *out, const float *bernoulli,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] < 0) {
out[thread_idx] = (int64_t)bernoulli[thread_idx] - 2;
done_d = false;
}
}
}
bool colorize(torch::Tensor out) {
auto stream = at::cuda::getCurrentCUDAStream();
init_done_kernel<<<1, 1, 0, stream>>>();
auto numel = out.size(0);
auto props = torch::full(numel, BLUE_P, out.options().dtype(torch::kFloat));
auto bernoulli = props.bernoulli();
colorize_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);
bool done_h;
cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0,
cudaMemcpyDeviceToHost);
return done_h;
}
__global__ void propose_kernel(int64_t *out, int64_t *proposal,
const int64_t *rowptr, const int64_t *col,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -1)
return; // Only visit blue nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -2) {
proposal[thread_idx] = v; // Propose to first red neighbor.
break;
}
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
template <typename scalar_t>
__global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
const int64_t *rowptr,
const int64_t *col,
const scalar_t *weight, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -1)
return; // Only visit blue nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
// Find maximum weighted red neighbor.
if (out[v] == -2 && weight[i] >= w_max) {
v_max = v;
w_max = weight[i];
}
}
proposal[thread_idx] = v_max; // Propose.
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
auto stream = at::cuda::getCurrentCUDAStream();
if (!optional_weight.has_value()) {
propose_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto weight = optional_weight.value();
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
weighted_propose_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
__global__ void respond_kernel(int64_t *out, const int64_t *proposal,
const int64_t *rowptr, const int64_t *col,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -2)
return; // Only visit red nodes.
bool has_unmatched_neighbor = false;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == thread_idx) {
// Match first blue neighbor v which proposed to u.
out[thread_idx] = min(thread_idx, v);
out[v] = min(thread_idx, v);
break;
}
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
template <typename scalar_t>
__global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
const int64_t *rowptr,
const int64_t *col,
const scalar_t *weight, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
if (out[thread_idx] != -2)
return; // Only visit red nodes.
bool has_unmatched_neighbor = false;
int64_t v_max = -1;
scalar_t w_max = 0;
for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
auto v = col[i];
if (out[v] < 0)
has_unmatched_neighbor = true; // Unmatched neighbor found.
if (out[v] == -1 && proposal[v] == thread_idx && weight[i] >= w_max) {
// Find maximum weighted blue neighbor v which proposed to u.
v_max = v;
w_max = weight[i];
}
}
if (v_max >= 0) {
out[thread_idx] = min(thread_idx, v_max); // Match neighbors.
out[v_max] = min(thread_idx, v_max);
}
if (!has_unmatched_neighbor)
out[thread_idx] = thread_idx;
}
}
void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
auto stream = at::cuda::getCurrentCUDAStream();
if (!optional_weight.has_value()) {
respond_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
} else {
auto weight = optional_weight.value();
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
weighted_respond_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
weight.data_ptr<scalar_t>(), out.numel());
});
}
}
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
if (optional_weight.has_value()) {
CHECK_CUDA(optional_weight.value());
CHECK_INPUT(optional_weight.value().dim() == 1);
CHECK_INPUT(optional_weight.value().numel() == col.numel());
}
cudaSetDevice(rowptr.get_device());
int64_t num_nodes = rowptr.numel() - 1;
auto out = torch::full(num_nodes, -1, rowptr.options());
auto proposal = torch::full(num_nodes, -1, rowptr.options());
while (!colorize(out)) {
propose(out, proposal, rowptr, col, optional_weight);
respond(out, proposal, rowptr, col, optional_weight);
}
return out;
}
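// Illustrative usage sketch (not part of the original sources; the graph below
// is a made-up example): greedy matching on the undirected path 0-1-2 given in
// CSR form, without edge weights.
torch::Tensor graclus_cuda_example() {
  auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
  auto rowptr = torch::tensor({0, 1, 3, 4}, opts); // node degrees 1, 2, 1
  auto col = torch::tensor({1, 0, 2, 1}, opts);    // edges of the path
  // Pass an edge weight tensor instead of an empty optional to prefer heavy
  // edges during matching.
  return graclus_cuda(rowptr, col, torch::optional<torch::Tensor>());
}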
#pragma once
#include <torch/extension.h>
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight);
#include "grid_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
const scalar_t *start, const scalar_t *end,
int64_t *out, int64_t D, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
int64_t c = 0, k = 1;
for (int64_t d = 0; d < D; d++) {
scalar_t p = pos[thread_idx * D + d] - start[d];
c += (int64_t)(p / size[d]) * k;
k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
}
out[thread_idx] = c;
}
}
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) {
CHECK_CUDA(pos);
CHECK_CUDA(size);
cudaSetDevice(pos.get_device());
if (optional_start.has_value())
CHECK_CUDA(optional_start.value());
if (optional_end.has_value())
CHECK_CUDA(optional_end.value());
pos = pos.view({pos.size(0), -1}).contiguous();
size = size.contiguous();
CHECK_INPUT(size.numel() == pos.size(1));
if (!optional_start.has_value())
optional_start = std::get<0>(pos.min(0));
else {
optional_start = optional_start.value().contiguous();
CHECK_INPUT(optional_start.value().numel() == pos.size(1));
}
if (!optional_end.has_value())
optional_end = std::get<0>(pos.max(0));
else {
optional_end = optional_end.value().contiguous();
CHECK_INPUT(optional_end.value().numel() == pos.size(1));
}
auto start = optional_start.value();
auto end = optional_end.value();
auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
out.data_ptr<int64_t>(), pos.size(1), out.numel());
});
return out;
}
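// Illustrative usage sketch (not part of the original sources; values are made
// up): voxelize 2D points into square cells of edge length 0.25. With empty
// optionals, `start` and `end` default to the per-dimension minimum and
// maximum of `pos`.
torch::Tensor grid_cuda_example() {
  auto pos = torch::rand({100, 2}, torch::device(torch::kCUDA));
  auto size = torch::tensor({0.25, 0.25}, pos.options());
  return grid_cuda(pos, size, torch::optional<torch::Tensor>(),
                   torch::optional<torch::Tensor>());
}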
#pragma once
#include <torch/extension.h>
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end);
#include "knn_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t> struct Cosine {
static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
int64_t size) {
scalar_t result = 0;
for (int64_t i = 0; i < size; i++) {
result += a[i] * b[i];
}
return result;
}
static inline __device__ scalar_t norm(const scalar_t *a, int64_t size) {
scalar_t result = 0;
for (int64_t i = 0; i < size; i++) {
result += a[i] * a[i];
}
return sqrt(result);
}
};
template <typename scalar_t>
__global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *ptr_x, const int64_t *ptr_y,
scalar_t *dist, int64_t *row, int64_t *col,
int64_t K, int64_t dim, bool cosine) {
const int64_t batch_idx = blockIdx.x;
const int64_t x_start_idx = ptr_x[batch_idx];
const int64_t x_end_idx = ptr_x[batch_idx + 1];
const int64_t y_start_idx = ptr_y[batch_idx];
const int64_t y_end_idx = ptr_y[batch_idx + 1];
for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
n_y += THREADS) {
for (int64_t k = 0; k < K; k++) {
row[n_y * K + k] = n_y;
}
for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
scalar_t tmp_dist = 0;
if (cosine) {
tmp_dist =
Cosine<scalar_t>::norm(x + n_x * dim, dim) *
Cosine<scalar_t>::norm(y + n_y * dim, dim) -
Cosine<scalar_t>::dot(x + n_x * dim, y + n_y * dim, dim);
} else {
for (int64_t d = 0; d < dim; d++) {
tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
}
for (int64_t k_idx_1 = 0; k_idx_1 < K; k_idx_1++) {
if (dist[n_y * K + k_idx_1] > tmp_dist) {
for (ptrdiff_t k_idx_2 = K - 1; k_idx_2 > k_idx_1; k_idx_2--) {
dist[n_y * K + k_idx_2] = dist[n_y * K + k_idx_2 - 1];
col[n_y * K + k_idx_2] = col[n_y * K + k_idx_2 - 1];
}
dist[n_y * K + k_idx_1] = tmp_dist;
col[n_y * K + k_idx_1] = n_x;
break;
}
}
}
}
}
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, int64_t k, bool cosine) {
CHECK_CUDA(x);
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
cudaSetDevice(x.get_device());
x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
auto dist = torch::full(y.size(0) * k, 1e38, y.options());
auto row = torch::empty(y.size(0) * k, ptr_y.options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
knn_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
col.data_ptr<int64_t>(), k, x.size(1), cosine);
});
auto mask = col != -1;
return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
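// Illustrative usage sketch (not part of the original sources; values are made
// up): for every point in `y`, find its 3 nearest neighbors in `x`; both point
// sets belong to a single example.
torch::Tensor knn_cuda_example() {
  auto x = torch::rand({32, 3}, torch::device(torch::kCUDA));
  auto y = torch::rand({8, 3}, torch::device(torch::kCUDA));
  auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
  auto ptr_x = torch::tensor({0, 32}, opts);
  auto ptr_y = torch::tensor({0, 8}, opts);
  // Returns a [2, 8 * 3] edge index: row indexes into `y`, col into `x`.
  return knn_cuda(x, y, ptr_x, ptr_y, /*k=*/3, /*cosine=*/false);
}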
#pragma once
#include <torch/extension.h>
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, int64_t k, bool cosine);
#include "nearest_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t>
__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *ptr_x, const int64_t *ptr_y,
int64_t *out, int64_t batch_size, int64_t dim) {
const int64_t thread_idx = threadIdx.x;
const int64_t n_x = blockIdx.x;
int64_t batch_idx = 0;
for (int64_t b = 0; b < batch_size; b++) {
if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
batch_idx = b;
break;
}
}
const int64_t y_start_idx = ptr_y[batch_idx];
const int64_t y_end_idx = ptr_y[batch_idx + 1];
__shared__ scalar_t best_dist[THREADS];
__shared__ int64_t best_dist_idx[THREADS];
scalar_t best = 1e38;
int64_t best_idx = 0;
for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx;
n_y += THREADS) {
scalar_t dist = 0;
for (int64_t d = 0; d < dim; d++) {
dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
if (dist < best) {
best = dist;
best_idx = n_y;
}
}
best_dist[thread_idx] = best;
best_dist_idx[thread_idx] = best_idx;
for (int64_t i = 1; i < THREADS; i *= 2) {
__syncthreads();
if ((thread_idx + i) < THREADS &&
best_dist[thread_idx] > best_dist[thread_idx + i]) {
best_dist[thread_idx] = best_dist[thread_idx + i];
best_dist_idx[thread_idx] = best_dist_idx[thread_idx + i];
}
}
__syncthreads();
if (thread_idx == 0) {
out[n_x] = best_dist_idx[0];
}
}
torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
torch::Tensor ptr_x, torch::Tensor ptr_y) {
CHECK_CUDA(x);
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
cudaSetDevice(x.get_device());
x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
auto out = torch::empty({x.size(0)}, ptr_x.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
out.data_ptr<int64_t>(), ptr_x.size(0) - 1, x.size(1));
});
return out;
}
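// Illustrative usage sketch (not part of the original sources; values are made
// up): map every point in `x` to the index of its closest point in `y` within
// the same example.
torch::Tensor nearest_cuda_example() {
  auto x = torch::rand({32, 3}, torch::device(torch::kCUDA));
  auto y = torch::rand({4, 3}, torch::device(torch::kCUDA));
  auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
  auto ptr_x = torch::tensor({0, 32}, opts);
  auto ptr_y = torch::tensor({0, 4}, opts);
  return nearest_cuda(x, y, ptr_x, ptr_y); // shape [32], values in [0, 4)
}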
#pragma once
#include <torch/extension.h>
torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
torch::Tensor ptr_x, torch::Tensor ptr_y);
#include "radius_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t>
__global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
const int64_t *ptr_x, const int64_t *ptr_y,
int64_t *row, int64_t *col, scalar_t radius,
int64_t max_num_neighbors, int64_t dim) {
const int64_t batch_idx = blockIdx.x;
const int64_t x_start_idx = ptr_x[batch_idx];
const int64_t x_end_idx = ptr_x[batch_idx + 1];
const int64_t y_start_idx = ptr_y[batch_idx];
const int64_t y_end_idx = ptr_y[batch_idx + 1];
for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
n_y += THREADS) {
int64_t count = 0;
for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
scalar_t dist = 0;
for (int64_t d = 0; d < dim; d++) {
dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
(x[n_x * dim + d] - y[n_y * dim + d]);
}
dist = sqrt(dist);
if (dist <= radius) {
row[n_y * max_num_neighbors + count] = n_y;
col[n_y * max_num_neighbors + count] = n_x;
count++;
}
if (count >= max_num_neighbors) {
break;
}
}
}
}
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r,
int64_t max_num_neighbors) {
CHECK_CUDA(x);
CHECK_CUDA(y);
CHECK_CUDA(ptr_x);
CHECK_CUDA(ptr_y);
cudaSetDevice(x.get_device());
x = x.view({x.size(0), -1}).contiguous();
y = y.view({y.size(0), -1}).contiguous();
auto row = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
auto col = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.options());
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
radius_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
x.size(1));
});
auto mask = row != -1;
return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
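// Illustrative usage sketch (not part of the original sources; values are made
// up): connect every point in `y` to at most 16 points of `x` that lie within
// a radius of 0.2.
torch::Tensor radius_cuda_example() {
  auto x = torch::rand({64, 3}, torch::device(torch::kCUDA));
  auto y = torch::rand({16, 3}, torch::device(torch::kCUDA));
  auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
  auto ptr_x = torch::tensor({0, 64}, opts);
  auto ptr_y = torch::tensor({0, 16}, opts);
  // Returns a [2, E] edge index with row indexing `y` and col indexing `x`.
  return radius_cuda(x, y, ptr_x, ptr_y, /*r=*/0.2, /*max_num_neighbors=*/16);
}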
#pragma once
#include <torch/extension.h>
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
torch::Tensor ptr_y, double r,
int64_t max_num_neighbors);
#include "rw_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
const int64_t *col,
const int64_t *start,
const float *rand, int64_t *out,
int64_t walk_length, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
out[thread_idx] = start[thread_idx];
int64_t row_start, row_end, i, cur;
for (int64_t l = 1; l <= walk_length; l++) {
i = (l - 1) * numel + thread_idx;
cur = out[i];
row_start = rowptr[cur], row_end = rowptr[cur + 1];
out[l * numel + thread_idx] =
col[row_start + int64_t(rand[i] * (row_end - row_start))];
}
}
}
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat));
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());
auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
out.data_ptr<int64_t>(), walk_length, start.numel());
return out.t().contiguous();
}
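// Illustrative usage sketch (not part of the original sources; the graph below
// is a made-up example): sample a walk of length 4 from every node of a small
// CSR graph. Note that the uniform kernel above picks neighbors uniformly at
// random and does not make use of `p` and `q`.
torch::Tensor random_walk_cuda_example() {
  auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
  auto rowptr = torch::tensor({0, 1, 3, 4}, opts); // path graph 0-1-2
  auto col = torch::tensor({1, 0, 2, 1}, opts);
  auto start = torch::arange(3, opts);             // start from every node
  // Returns a [3, 5] tensor: each row is a start node followed by 4 steps.
  return random_walk_cuda(rowptr, col, start, /*walk_length=*/4,
                          /*p=*/1., /*q=*/1.);
}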
#pragma once
#include <torch/extension.h>
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q);
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include <Python.h>
#include <torch/script.h>
#include "cpu/fps_cpu.h"
#ifdef WITH_CUDA
#include "cuda/fps_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__fps(void) { return NULL; }
#endif
torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return fps_cuda(src, ptr, ratio, random_start);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return fps_cpu(src, ptr, ratio, random_start);
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::fps", &fps);
#include <Python.h>
#include <torch/script.h>
#include "cpu/graclus_cpu.h"
#ifdef WITH_CUDA
#include "cuda/graclus_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__graclus(void) { return NULL; }
#endif
torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight) {
if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
return graclus_cuda(rowptr, col, optional_weight);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return graclus_cpu(rowptr, col, optional_weight);
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::graclus", &graclus);
#include <Python.h>
#include <torch/script.h>
#include "cpu/grid_cpu.h"
#ifdef WITH_CUDA
#include "cuda/grid_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__grid(void) { return NULL; }
#endif
torch::Tensor grid(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) {
if (pos.device().is_cuda()) {
#ifdef WITH_CUDA
return grid_cuda(pos, size, optional_start, optional_end);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return grid_cpu(pos, size, optional_start, optional_end);
}
}
static auto registry =
torch::RegisterOperators().op("torch_cluster::grid", &grid);