Commit 16b976c7 authored by rusty1s's avatar rusty1s
Browse files

new cuda layout

parent 9a9d9732
#include <torch/torch.h>

// NOTE(review): this region was a mangled side-by-side diff; reconstructed as
// the post-commit layout (device check macro, forward declarations, bindings).

#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")

// Forward declarations; definitions live in the per-operator sources.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end);
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes);
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int num_nodes);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("grid", &grid, "Grid (CUDA)");
  m.def("graclus", &graclus, "Graclus (CUDA)");
  // Fix: docstring typo "Weightes" -> "Weighted".
  m.def("weighted_graclus", &weighted_graclus, "Weighted Graclus (CUDA)");
}
#pragma once
#include <ATen/ATen.h>
#include "common.cuh"
#define BLUE_PROB 0.53406
// Grid-stride pass over all `num_nodes` cluster entries. The loop body is
// empty, so this kernel is currently a no-op skeleton — presumably the random
// blue/red coloring step (see BLUE_PROB above) is still to be implemented.
__global__ void color_kernel(int64_t *cluster, size_t num_nodes) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < num_nodes; i += stride) {
  }
}
// Launch the coloring kernel over all entries of `cluster` and report whether
// coloring is finished (always `true` for now, matching the WIP kernel).
// Fix: `color_kernel` is not a template and `scalar_t` is not in scope in
// this non-dispatched function, so `color_kernel<scalar_t><<<...>>>` did not
// compile; the kernel is launched without template arguments instead.
inline bool color(at::Tensor cluster) {
  color_kernel<<<BLOCKS(cluster.size(0)), THREADS>>>(
      cluster.data<int64_t>(), cluster.size(0));
  return true;
}
#pragma once
#include <ATen/ATen.h>

// Default launch configuration: 1024 threads per block, ceil-div grid size.
#define THREADS 1024
// Fix: fully parenthesize the macro. The original expanded to
// `(N + THREADS - 1) / THREADS` with neither `N` nor the whole expression
// protected, which misgroups when used inside a larger expression
// (e.g. `2 * BLOCKS(n)` or `BLOCKS(a + b)`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Per-node degree: scatter-adds ones over `index` into a zero vector of
// length `num_nodes`, in the dtype of `index`.
inline at::Tensor degree(at::Tensor index, int num_nodes) {
  auto zero = at::zeros(index.type(), {num_nodes});
  auto one = at::ones(index.type(), {index.size(0)});
  return zero.scatter_add_(0, index, one);
}
#include <torch/torch.h>
#include "../include/degree.cpp"
#include "../include/loop.cpp"
// CUDA-facing entry point: validates device placement, strips self-loop
// edges, and — for now — returns only the per-node degree vector (the full
// clustering step is evidently still work in progress).
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
  CHECK_CUDA(row);
  CHECK_CUDA(col);
  std::tie(row, col) = remove_self_loops(row, col);
  return degree(row, num_nodes, row.type().scalarType());
}
#include <ATen/ATen.h>
#include "color.cuh"
#include "common.cuh"
// Graclus-style greedy clustering on an unweighted graph given as an edge
// list (row, col). Returns a length-`num_nodes` cluster-id tensor
// (initialized to -1; the propose/response iteration is still commented out,
// so only the initial coloring runs).
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
  // Remove self-loops.
  auto mask = row != col;
  row = row.masked_select(mask);
  // Fix: the original discarded the result of `col.masked_select(mask)`,
  // leaving `col` with a different length than the filtered `row`.
  col = col.masked_select(mask);

  // Sort by row index so each node's outgoing edges are contiguous.
  at::Tensor perm;
  std::tie(row, perm) = row.sort();
  col = col.index_select(0, perm);

  // Generate helper vectors: cluster / proposal ids (-1 = unassigned),
  // degrees, and their cumulative sum (CSR-style row offsets).
  auto cluster = at::full(row.type(), {num_nodes}, -1);
  auto prop = at::full(row.type(), {num_nodes}, -1);
  auto deg = degree(row, num_nodes);
  auto cum_deg = deg.cumsum(0);

  color(cluster);
  /* while (!color(cluster)) { */
  /*   propose(cluster, prop, row, col, weight, deg, cum_deg); */
  /*   response(cluster, prop, row, col, weight, deg, cum_deg); */
  /* } */

  return cluster;
}
// Weighted variant of graclus: identical preprocessing, but the edge weight
// tensor is filtered and permuted alongside the edge indices. Returns a
// length-`num_nodes` cluster-id tensor (initialized to -1; the matching
// iteration is still commented out, so only the initial coloring runs).
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int num_nodes) {
  // Drop self-loop edges from all three edge attributes.
  auto keep = row != col;
  row = row.masked_select(keep);
  col = col.masked_select(keep);
  weight = weight.masked_select(keep);

  // Order edges by source node; apply the same permutation everywhere.
  at::Tensor order;
  std::tie(row, order) = row.sort();
  col = col.index_select(0, order);
  weight = weight.index_select(0, order);

  // Helper vectors: cluster / proposal ids (-1 = unassigned), degrees and
  // their running sum.
  auto cluster = at::full(row.type(), {num_nodes}, -1);
  auto prop = at::full(row.type(), {num_nodes}, -1);
  auto deg = degree(row, num_nodes);
  auto cum_deg = deg.cumsum(0);

  color(cluster);
  /* while (!color(cluster)) { */
  /*   weighted_propose(cluster, prop, row, col, weight, deg, cum_deg); */
  /*   weighted_response(cluster, prop, row, col, weight, deg, cum_deg); */
  /* } */

  return cluster;
}
#include <torch/torch.h>
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end);
// Host-side wrapper for the voxel-grid clustering kernel: asserts that every
// input tensor lives on the GPU, then forwards to the CUDA implementation.
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  CHECK_CUDA(start);
  CHECK_CUDA(end);
  auto cluster = grid_cuda(pos, size, start, end);
  return cluster;
}
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh> #include <ATen/cuda/detail/IndexUtils.cuh>
#define THREADS 1024 #include "common.cuh"
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t> template <typename scalar_t>
__global__ void __global__ void
grid_cuda_kernel(int64_t *cluster, grid_kernel(int64_t *cluster, at::cuda::detail::TensorInfo<scalar_t, int> pos,
at::cuda::detail::TensorInfo<scalar_t, int> pos,
scalar_t *__restrict__ size, scalar_t *__restrict__ start, scalar_t *__restrict__ size, scalar_t *__restrict__ start,
scalar_t *__restrict__ end, size_t num_nodes) { scalar_t *__restrict__ end, size_t num_nodes) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x; const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -16,7 +14,7 @@ grid_cuda_kernel(int64_t *cluster, ...@@ -16,7 +14,7 @@ grid_cuda_kernel(int64_t *cluster,
int64_t c = 0, k = 1; int64_t c = 0, k = 1;
scalar_t tmp; scalar_t tmp;
for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) { for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
tmp = (pos.data[i * pos.strides[0] + d * pos.strides[1]]) - start[d]; tmp = pos.data[i * pos.strides[0] + d * pos.strides[1]] - start[d];
c += (int64_t)(tmp / size[d]) * k; c += (int64_t)(tmp / size[d]) * k;
k += (int64_t)((end[d] - start[d]) / size[d]); k += (int64_t)((end[d] - start[d]) / size[d]);
} }
...@@ -24,18 +22,17 @@ grid_cuda_kernel(int64_t *cluster, ...@@ -24,18 +22,17 @@ grid_cuda_kernel(int64_t *cluster,
} }
} }
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end) { at::Tensor end) {
auto num_nodes = pos.size(0); auto cluster = at::empty(pos.type().toScalarType(at::kLong), {pos.size(0)});
auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] { AT_DISPATCH_ALL_TYPES(pos.type(), "grid_kernel", [&] {
grid_cuda_kernel<scalar_t><<<BLOCKS(num_nodes), THREADS>>>( grid_kernel<scalar_t><<<BLOCKS(pos.size(0)), THREADS>>>(
cluster.data<int64_t>(), cluster.data<int64_t>(),
at::cuda::detail::getTensorInfo<scalar_t, int>(pos), at::cuda::detail::getTensorInfo<scalar_t, int>(pos),
size.toType(pos.type()).data<scalar_t>(), size.toType(pos.type()).data<scalar_t>(),
start.toType(pos.type()).data<scalar_t>(), start.toType(pos.type()).data<scalar_t>(),
end.toType(pos.type()).data<scalar_t>(), num_nodes); end.toType(pos.type()).data<scalar_t>(), pos.size(0));
}); });
return cluster; return cluster;
......
#ifndef DEGREE_INC
#define DEGREE_INC

#include <torch/torch.h>

// Per-node occurrence count: how often each node id appears in `index`,
// returned as a length-`num_nodes` tensor of the requested `scalar_type`.
inline at::Tensor degree(at::Tensor index, int num_nodes,
                         at::ScalarType scalar_type) {
  auto out = at::zeros(index.type().toScalarType(scalar_type), {num_nodes});
  auto ones = at::ones(out.type(), {index.size(0)});
  return out.scatter_add_(0, index, ones);
}

#endif // DEGREE_INC
#ifndef LOOP_INC
#define LOOP_INC

#include <torch/torch.h>

// Filter out edges whose endpoints coincide (self-loops) and return the
// surviving (row, col) pair.
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col) {
  auto keep = row != col;
  return std::make_tuple(row.masked_select(keep), col.masked_select(keep));
}

#endif // LOOP_INC
#ifndef PERM_INC
#define PERM_INC
#include <torch/torch.h>
// Randomizes an edge list while returning it row-sorted: (1) shuffles the
// edge order, (2) applies a random relabeling of node ids before sorting so
// that the ordering among equal rows is randomized, then maps the ids back.
// Net effect: same edge set, rows sorted ascending, secondary order random.
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int num_nodes) {
// Randomly reorder row and column indices.
auto perm = at::randperm(row.type(), row.size(0));
row = row.index_select(0, perm);
col = col.index_select(0, perm);
// Randomly swap row values.
auto node_rid = at::randperm(row.type(), num_nodes);
row = node_rid.index_select(0, row);
// Sort row and column indices row-wise.
std::tie(row, perm) = row.sort();
col = col.index_select(0, perm);
// Revert row value swaps.
// `node_rid.sort()`'s index tensor is the argsort of the permutation, i.e.
// its inverse mapping, so indexing with the relabeled rows restores the
// original node ids.
row = std::get<1>(node_rid.sort()).index_select(0, row);
return {row, col};
}
#endif // PERM_INC
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment