Commit de0216d8 authored by rusty1s

pytorch 1.3 support

parent bd3ae685
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
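The new compatibility header above is the core of the change: PyTorch 1.3 deprecates the raw-pointer accessor Tensor::data<T>() in favor of Tensor::data_ptr<T>(), so every call site below switches to the DATA_PTR macro, which expands to whichever spelling the installed version provides (the VERSION_GE_1_3 define is added by the setup.py changes further down). A minimal sketch of a call site, assuming the header is available as "compat.h" next to the source; the helper and its name name are nothing from the commit, just an illustration:

#include <torch/extension.h>

#include "compat.h"  // DATA_PTR -> data_ptr (>= 1.3) or data (< 1.3)

// Illustrative helper: sum an int64 tensor through a raw pointer, the same
// access pattern used by the kernels touched in this commit.
int64_t sum_int64(at::Tensor t) {
  auto ptr = t.DATA_PTR<int64_t>();
  int64_t total = 0;
  for (int64_t i = 0; i < t.numel(); i++) {
    total += ptr[i];
  }
  return total;
}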
#include <torch/extension.h>
+#include "compat.h"
#include "utils.h"

at::Tensor get_dist(at::Tensor x, ptrdiff_t index) {
@@ -7,19 +8,19 @@ at::Tensor get_dist(at::Tensor x, ptrdiff_t index) {
}

at::Tensor fps(at::Tensor x, at::Tensor batch, float ratio, bool random) {
-  auto batch_size = batch[-1].data<int64_t>()[0] + 1;
+  auto batch_size = batch[-1].DATA_PTR<int64_t>()[0] + 1;
  auto deg = degree(batch, batch_size);
  auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
  auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
  auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);
-  auto out = at::empty(cum_k[-1].data<int64_t>()[0], batch.options());
-  auto cum_deg_d = cum_deg.data<int64_t>();
-  auto k_d = k.data<int64_t>();
-  auto cum_k_d = cum_k.data<int64_t>();
-  auto out_d = out.data<int64_t>();
+  auto out = at::empty(cum_k[-1].DATA_PTR<int64_t>()[0], batch.options());
+  auto cum_deg_d = cum_deg.DATA_PTR<int64_t>();
+  auto k_d = k.DATA_PTR<int64_t>();
+  auto cum_k_d = cum_k.DATA_PTR<int64_t>();
+  auto out_d = out.DATA_PTR<int64_t>();
  for (ptrdiff_t b = 0; b < batch_size; b++) {
    auto index = at::range(cum_deg_d[b], cum_deg_d[b + 1] - 1, out.options());
@@ -27,14 +28,14 @@ at::Tensor fps(at::Tensor x, at::Tensor batch, float ratio, bool random) {
    ptrdiff_t start = 0;
    if (random) {
-      start = at::randperm(y.size(0), batch.options()).data<int64_t>()[0];
+      start = at::randperm(y.size(0), batch.options()).DATA_PTR<int64_t>()[0];
    }
    out_d[cum_k_d[b]] = cum_deg_d[b] + start;
    auto dist = get_dist(y, start);
    for (ptrdiff_t i = 1; i < k_d[b]; i++) {
-      ptrdiff_t argmax = dist.argmax().data<int64_t>()[0];
+      ptrdiff_t argmax = dist.argmax().DATA_PTR<int64_t>()[0];
      out_d[cum_k_d[b] + i] = cum_deg_d[b] + argmax;
      dist = at::min(dist, get_dist(y, argmax));
    }
......
#include <torch/extension.h>
+#include "compat.h"
#include "utils.h"

at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  std::tie(row, col) = remove_self_loops(row, col);
  std::tie(row, col) = rand(row, col);
  std::tie(row, col) = to_csr(row, col, num_nodes);
-  auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>();
+  auto row_data = row.DATA_PTR<int64_t>(), col_data = col.DATA_PTR<int64_t>();
  auto perm = at::randperm(num_nodes, row.options());
-  auto perm_data = perm.data<int64_t>();
+  auto perm_data = perm.DATA_PTR<int64_t>();
  auto cluster = at::full(num_nodes, -1, row.options());
-  auto cluster_data = cluster.data<int64_t>();
+  auto cluster_data = cluster.DATA_PTR<int64_t>();
  for (int64_t i = 0; i < num_nodes; i++) {
    auto u = perm_data[i];
@@ -41,16 +42,16 @@ at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int64_t num_nodes) {
  std::tie(row, col, weight) = remove_self_loops(row, col, weight);
  std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
-  auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>();
+  auto row_data = row.DATA_PTR<int64_t>(), col_data = col.DATA_PTR<int64_t>();
  auto perm = at::randperm(num_nodes, row.options());
-  auto perm_data = perm.data<int64_t>();
+  auto perm_data = perm.DATA_PTR<int64_t>();
  auto cluster = at::full(num_nodes, -1, row.options());
-  auto cluster_data = cluster.data<int64_t>();
+  auto cluster_data = cluster.DATA_PTR<int64_t>();
  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
-    auto weight_data = weight.data<scalar_t>();
+    auto weight_data = weight.DATA_PTR<scalar_t>();
    for (int64_t i = 0; i < num_nodes; i++) {
      auto u = perm_data[i];
......
#include <torch/extension.h>
+#include "compat.h"
#include "utils.h"

at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
@@ -12,12 +13,12 @@ at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
  auto out =
      at::full({start.size(0), (int64_t)walk_length + 1}, -1, start.options());
-  auto deg_d = deg.data<int64_t>();
-  auto cum_deg_d = cum_deg.data<int64_t>();
-  auto col_d = col.data<int64_t>();
-  auto start_d = start.data<int64_t>();
-  auto rand_d = rand.data<float>();
-  auto out_d = out.data<int64_t>();
+  auto deg_d = deg.DATA_PTR<int64_t>();
+  auto cum_deg_d = cum_deg.DATA_PTR<int64_t>();
+  auto col_d = col.DATA_PTR<int64_t>();
+  auto start_d = start.DATA_PTR<int64_t>();
+  auto rand_d = rand.DATA_PTR<float>();
+  auto out_d = out.DATA_PTR<int64_t>();
  for (ptrdiff_t n = 0; n < start.size(0); n++) {
    int64_t cur = start_d[n];
......
#include <torch/extension.h>
+#include "compat.h"

at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
                            float factor) {
-  auto start_ptr = start.data<int64_t>();
-  auto cumdeg_ptr = cumdeg.data<int64_t>();
+  auto start_ptr = start.DATA_PTR<int64_t>();
+  auto cumdeg_ptr = cumdeg.DATA_PTR<int64_t>();
  std::vector<int64_t> e_ids;
  for (ptrdiff_t i = 0; i < start.size(0); i++) {
@@ -29,7 +31,7 @@ at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
      e_ids.insert(e_ids.end(), v.begin(), v.end());
    } else {
      auto sample = at::randperm(num_neighbors, start.options());
-      auto sample_ptr = sample.data<int64_t>();
+      auto sample_ptr = sample.DATA_PTR<int64_t>();
      for (size_t j = 0; j < size_i; j++) {
        e_ids.push_back(sample_ptr[j] + low);
      }
......
@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
+#include "compat.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
@@ -30,8 +32,8 @@ int64_t colorize(at::Tensor cluster) {
  auto props = at::full(numel, BLUE_PROB, cluster.options().dtype(at::kFloat));
  auto bernoulli = props.bernoulli();
-  colorize_kernel<<<BLOCKS(numel), THREADS>>>(cluster.data<int64_t>(),
-                                              bernoulli.data<float>(), numel);
+  colorize_kernel<<<BLOCKS(numel), THREADS>>>(
+      cluster.DATA_PTR<int64_t>(), bernoulli.DATA_PTR<float>(), numel);
  int64_t out;
  cudaMemcpyFromSymbol(&out, done, sizeof(out), 0, cudaMemcpyDeviceToHost);
......
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
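compat.cuh is the same version shim for the CUDA translation units. Inside the .cu files, the host-side launch code pulls typed raw pointers out with DATA_PTR inside the usual AT_DISPATCH macro before handing them to a kernel. A sketch of that pattern under assumed names (scale_launcher and scale_kernel are hypothetical and not from the commit; the kernel launch is left as a comment so the sketch compiles as plain C++):

#include <ATen/ATen.h>

#include "compat.cuh"

// Hypothetical launcher: grab raw device pointers once via DATA_PTR, then
// pass them to a __global__ kernel (the launch itself is only sketched).
void scale_launcher(at::Tensor src, at::Tensor out, double alpha) {
  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "scale_kernel", [&] {
    auto src_ptr = src.DATA_PTR<scalar_t>();  // data_ptr<scalar_t>() on >= 1.3
    auto out_ptr = out.DATA_PTR<scalar_t>();
    // scale_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
    //     src_ptr, out_ptr, static_cast<scalar_t>(alpha), src.numel());
    (void)src_ptr;  // silence unused warnings while the launch is commented out
    (void)out_ptr;
  });
}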
#include <ATen/ATen.h>
#include "atomics.cuh"
+#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024
@@ -164,7 +165,7 @@ fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg,
at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  cudaSetDevice(x.get_device());
  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
-  cudaMemcpy(batch_sizes, batch[-1].data<int64_t>(), sizeof(int64_t),
+  cudaMemcpy(batch_sizes, batch[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = batch_sizes[0] + 1;
@@ -185,15 +186,15 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  auto tmp_dist = at::empty(x.size(0), x.options());
  auto k_sum = (int64_t *)malloc(sizeof(int64_t));
-  cudaMemcpy(k_sum, cum_k[-1].data<int64_t>(), sizeof(int64_t),
+  cudaMemcpy(k_sum, cum_k[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto out = at::empty(k_sum[0], k.options());
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "fps_kernel", [&] {
-    FPS_KERNEL(x.size(1), x.data<scalar_t>(), cum_deg.data<int64_t>(),
-               cum_k.data<int64_t>(), start.data<int64_t>(),
-               dist.data<scalar_t>(), tmp_dist.data<scalar_t>(),
-               out.data<int64_t>());
+    FPS_KERNEL(x.size(1), x.DATA_PTR<scalar_t>(), cum_deg.DATA_PTR<int64_t>(),
+               cum_k.DATA_PTR<int64_t>(), start.DATA_PTR<int64_t>(),
+               dist.DATA_PTR<scalar_t>(), tmp_dist.DATA_PTR<scalar_t>(),
+               out.DATA_PTR<int64_t>());
  });
  return out;
......
@@ -2,6 +2,8 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
+#include "compat.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
@@ -31,10 +33,10 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
    grid_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
-        cluster.data<int64_t>(),
+        cluster.DATA_PTR<int64_t>(),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(pos),
-        size.data<scalar_t>(), start.data<scalar_t>(), end.data<scalar_t>(),
-        cluster.numel());
+        size.DATA_PTR<scalar_t>(), start.DATA_PTR<scalar_t>(),
+        end.DATA_PTR<scalar_t>(), cluster.numel());
  });
  return cluster;
......
#include <ATen/ATen.h>
+#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024
@@ -79,7 +80,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
                    at::Tensor batch_y, bool cosine) {
  cudaSetDevice(x.get_device());
  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
-  cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
+  cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = batch_sizes[0] + 1;
@@ -94,9 +95,10 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
    knn_kernel<scalar_t><<<batch_size, THREADS>>>(
-        x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
-        batch_y.data<int64_t>(), dist.data<scalar_t>(), row.data<int64_t>(),
-        col.data<int64_t>(), k, x.size(1), cosine);
+        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
+        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
+        dist.DATA_PTR<scalar_t>(), row.DATA_PTR<int64_t>(),
+        col.DATA_PTR<int64_t>(), k, x.size(1), cosine);
  });
  auto mask = col != -1;
......
#include <ATen/ATen.h>
+#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024
@@ -62,7 +63,7 @@ at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
                        at::Tensor batch_y) {
  cudaSetDevice(x.get_device());
  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
-  cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
+  cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = batch_sizes[0] + 1;
@@ -73,8 +74,9 @@ at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
    nearest_kernel<scalar_t><<<x.size(0), THREADS>>>(
-        x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
-        batch_y.data<int64_t>(), out.data<int64_t>(), x.size(1));
+        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
+        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
+        out.DATA_PTR<int64_t>(), x.size(1));
  });
  return out;
......
@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
+#include "compat.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
@@ -36,8 +38,8 @@ __global__ void propose_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col) {
  propose_kernel<<<BLOCKS(cluster.numel()), THREADS>>>(
-      cluster.data<int64_t>(), proposal.data<int64_t>(), row.data<int64_t>(),
-      col.data<int64_t>(), cluster.numel());
+      cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
+      row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), cluster.numel());
}

template <typename scalar_t>
@@ -79,7 +81,8 @@ void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col, at::Tensor weight) {
  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
    propose_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
-        cluster.data<int64_t>(), proposal.data<int64_t>(), row.data<int64_t>(),
-        col.data<int64_t>(), weight.data<scalar_t>(), cluster.numel());
+        cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
+        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
+        weight.DATA_PTR<scalar_t>(), cluster.numel());
  });
}
#include <ATen/ATen.h>
+#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024
@@ -50,7 +51,7 @@ at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
                       size_t max_num_neighbors) {
  cudaSetDevice(x.get_device());
  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
-  cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
+  cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = batch_sizes[0] + 1;
@@ -64,9 +65,10 @@ at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
    radius_kernel<scalar_t><<<batch_size, THREADS>>>(
-        x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
-        batch_y.data<int64_t>(), row.data<int64_t>(), col.data<int64_t>(),
-        radius, max_num_neighbors, x.size(1));
+        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
+        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
+        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), radius,
+        max_num_neighbors, x.size(1));
  });
  auto mask = row != -1;
......
@@ -2,6 +2,8 @@
#include <ATen/ATen.h>
+#include "compat.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
@@ -38,8 +40,8 @@ __global__ void respond_kernel(int64_t *__restrict__ cluster, int64_t *proposal,
void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col) {
  respond_kernel<<<BLOCKS(cluster.numel()), THREADS>>>(
-      cluster.data<int64_t>(), proposal.data<int64_t>(), row.data<int64_t>(),
-      col.data<int64_t>(), cluster.numel());
+      cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
+      row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), cluster.numel());
}

template <typename scalar_t>
@@ -84,7 +86,8 @@ void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col, at::Tensor weight) {
  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
    respond_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
-        cluster.data<int64_t>(), proposal.data<int64_t>(), row.data<int64_t>(),
-        col.data<int64_t>(), weight.data<scalar_t>(), cluster.numel());
+        cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
+        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
+        weight.DATA_PTR<scalar_t>(), cluster.numel());
  });
}
#include <ATen/ATen.h>
+#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024
@@ -37,9 +38,9 @@ at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
      at::full({(int64_t)walk_length + 1, start.size(0)}, -1, start.options());
  uniform_rw_kernel<<<BLOCKS(start.numel()), THREADS>>>(
-      row.data<int64_t>(), col.data<int64_t>(), deg.data<int64_t>(),
-      start.data<int64_t>(), rand.data<float>(), out.data<int64_t>(),
-      walk_length, start.numel());
+      row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), deg.DATA_PTR<int64_t>(),
+      start.DATA_PTR<int64_t>(), rand.DATA_PTR<float>(),
+      out.DATA_PTR<int64_t>(), walk_length, start.numel());
  return out.t().contiguous();
}
@@ -2,34 +2,52 @@ from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+extra_compile_args = []
+if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
+    extra_compile_args += ['-DVERSION_GE_1_3']
+
ext_modules = [
-    CppExtension('torch_cluster.graclus_cpu', ['cpu/graclus.cpp']),
+    CppExtension('torch_cluster.graclus_cpu', ['cpu/graclus.cpp'],
+                 extra_compile_args=extra_compile_args),
    CppExtension('torch_cluster.grid_cpu', ['cpu/grid.cpp']),
-    CppExtension('torch_cluster.fps_cpu', ['cpu/fps.cpp']),
-    CppExtension('torch_cluster.rw_cpu', ['cpu/rw.cpp']),
-    CppExtension('torch_cluster.sampler_cpu', ['cpu/sampler.cpp']),
+    CppExtension('torch_cluster.fps_cpu', ['cpu/fps.cpp'],
+                 extra_compile_args=extra_compile_args),
+    CppExtension('torch_cluster.rw_cpu', ['cpu/rw.cpp'],
+                 extra_compile_args=extra_compile_args),
+    CppExtension('torch_cluster.sampler_cpu', ['cpu/sampler.cpp'],
+                 extra_compile_args=extra_compile_args),
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if CUDA_HOME is not None:
    ext_modules += [
        CUDAExtension('torch_cluster.graclus_cuda',
-                      ['cuda/graclus.cpp', 'cuda/graclus_kernel.cu']),
+                      ['cuda/graclus.cpp', 'cuda/graclus_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
        CUDAExtension('torch_cluster.grid_cuda',
-                      ['cuda/grid.cpp', 'cuda/grid_kernel.cu']),
+                      ['cuda/grid.cpp', 'cuda/grid_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
        CUDAExtension('torch_cluster.fps_cuda',
-                      ['cuda/fps.cpp', 'cuda/fps_kernel.cu']),
+                      ['cuda/fps.cpp', 'cuda/fps_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
        CUDAExtension('torch_cluster.nearest_cuda',
-                      ['cuda/nearest.cpp', 'cuda/nearest_kernel.cu']),
+                      ['cuda/nearest.cpp', 'cuda/nearest_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
        CUDAExtension('torch_cluster.knn_cuda',
-                      ['cuda/knn.cpp', 'cuda/knn_kernel.cu']),
+                      ['cuda/knn.cpp', 'cuda/knn_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
        CUDAExtension('torch_cluster.radius_cuda',
-                      ['cuda/radius.cpp', 'cuda/radius_kernel.cu']),
+                      ['cuda/radius.cpp', 'cuda/radius_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
        CUDAExtension('torch_cluster.rw_cuda',
-                      ['cuda/rw.cpp', 'cuda/rw_kernel.cu']),
+                      ['cuda/rw.cpp', 'cuda/rw_kernel.cu'],
+                      extra_compile_args=extra_compile_args),
    ]

-__version__ = '1.4.4'
+__version__ = '1.4.5'
url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['scipy']
......
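The setup.py change above is what makes the macro resolve correctly: for torch >= 1.3, every extension that uses DATA_PTR is now compiled with -DVERSION_GE_1_3 (grid_cpu is the one CPU extension built without the flag; cpu/grid.cpp is not touched in this commit). A standalone way to observe the flag's effect on the header, as a sketch; the stringizing helpers are illustrative and not from the repository:

#include <cstdio>

#include "compat.h"  // the header added in this commit

// Two-level stringize so the macro argument is expanded before being quoted.
#define STRINGIZE_(x) #x
#define STRINGIZE(x) STRINGIZE_(x)

int main() {
  // Prints "data_ptr" when built with -DVERSION_GE_1_3, "data" otherwise.
  std::printf("DATA_PTR expands to: %s\n", STRINGIZE(DATA_PTR));
  return 0;
}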
@@ -7,7 +7,7 @@ from .radius import radius, radius_graph
from .rw import random_walk
from .sampler import neighbor_sampler

-__version__ = '1.4.4'
+__version__ = '1.4.5'
__all__ = [
    'graclus_cluster',
......