Commit 7bb94638 authored by rusty1s

new rw cuda implementation

parent ba9f2ed2
@@ -13,16 +13,12 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
   CHECK_INPUT(col.dim() == 1);
   CHECK_INPUT(start.dim() == 1);
 
-  auto num_nodes = rowptr.size(0) - 1;
-  auto deg = rowptr.narrow(0, 1, num_nodes) - rowptr.narrow(0, 0, num_nodes);
-
   auto rand = torch::rand({start.size(0), walk_length},
                           start.options().dtype(torch::kFloat));
 
   auto out = torch::full({start.size(0), walk_length + 1}, -1, start.options());
 
   auto rowptr_data = rowptr.data_ptr<int64_t>();
-  auto deg_data = deg.data_ptr<int64_t>();
   auto col_data = col.data_ptr<int64_t>();
   auto start_data = start.data_ptr<int64_t>();
   auto rand_data = rand.data_ptr<float>();
@@ -33,10 +29,12 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
     auto offset = n * (walk_length + 1);
     out_data[offset] = cur;
 
+    int64_t row_start, row_end;
     for (auto l = 1; l <= walk_length; l++) {
-      cur = col_data[rowptr_data[cur] +
-                     int64_t(rand_data[n * walk_length + (l - 1)] *
-                             deg_data[cur])];
+      row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1];
+      cur = col_data[row_start + int64_t(rand_data[n * walk_length + (l - 1)] *
+                                         (row_end - row_start))];
       out_data[offset + l] = cur;
     }
   }
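For context, a minimal standalone sketch (not part of the commit; `step` and the toy graph are illustrative names) of the CSR transition the rewritten loop performs: a uniform float in [0, 1) is scaled by the node's degree, read directly from adjacent `rowptr` entries, to pick one outgoing edge.

```cpp
// Illustrative sketch of the CSR sampling step used by random_walk_cpu.
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

int64_t step(const std::vector<int64_t> &rowptr,
             const std::vector<int64_t> &col, int64_t cur, float u) {
  int64_t row_start = rowptr[cur], row_end = rowptr[cur + 1];
  // Assumes row_end > row_start (every node has at least one neighbor),
  // exactly as the loop in the diff does.
  return col[row_start + int64_t(u * (row_end - row_start))];
}

int main() {
  // Toy graph: 0 -> {1, 2}, 1 -> {0}, 2 -> {0, 1}.
  std::vector<int64_t> rowptr = {0, 2, 3, 5};
  std::vector<int64_t> col = {1, 2, 0, 0, 1};

  std::mt19937 gen(0);
  std::uniform_real_distribution<float> dist(0.f, 1.f);

  int64_t cur = 0;
  for (int l = 0; l < 4; l++)
    cur = step(rowptr, col, cur, dist(gen));
  std::printf("walk ended at node %lld\n", static_cast<long long>(cur));
  return 0;
}
```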
#include "grid_cpu.h"
#include "grid_cuda.h"
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/CUDAContext.h>
#include "compat.cuh"
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
@@ -12,7 +10,7 @@
 template <typename scalar_t>
 __global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
                             const scalar_t *start, const scalar_t *end,
-                            int64_t *out, int64_t N, int64_t D, int64_t numel) {
+                            int64_t *out, int64_t D, int64_t numel) {
 
   const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (thread_idx < numel) {
@@ -62,11 +60,12 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
   auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));
 
+  auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
-    grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
+    grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
         pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
         start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
-        out.data_ptr<int64_t>(), pos.size(0), pos.size(1), out.numel());
+        out.data_ptr<int64_t>(), pos.size(1), out.numel());
   });
 
   return out;
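The only functional change to this launch is the stream argument. A sketch of the pattern (illustrative `double_kernel` and `launch_double`, not from the commit): passing `at::cuda::getCurrentCUDAStream()` as the fourth launch-config parameter keeps the kernel ordered with surrounding ATen work instead of enqueueing it on the legacy default stream.

```cuda
// Sketch of the stream-aware launch pattern adopted in this diff
// (illustrative kernel; only the launch mechanics mirror the commit).
#include <ATen/cuda/CUDAContext.h>
#include <cstdint>

#define THREADS 1024
// Ceil-division block count, parenthesized defensively.
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

__global__ void double_kernel(int64_t *out, int64_t numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel)
    out[i] *= 2;
}

void launch_double(int64_t *out, int64_t numel) {
  // Enqueue on ATen's current stream rather than the default stream.
  auto stream = at::cuda::getCurrentCUDAStream();
  double_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(out, numel);
}
```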
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");
at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
size_t walk_length, float p, float q, size_t num_nodes);
at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
size_t walk_length, float p, float q, size_t num_nodes) {
CHECK_CUDA(row);
CHECK_CUDA(col);
CHECK_CUDA(start);
return rw_cuda(row, col, start, walk_length, p, q, num_nodes);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rw", &rw, "Random Walk Sampling (CUDA)");
}
#include "rw_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
const int64_t *col,
const int64_t *start,
const float *rand, int64_t *out,
int64_t walk_length, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
out[thread_idx] = start[thread_idx];
int64_t row_start, row_end, i, cur;
for (int64_t l = 1; l <= walk_length; l++) {
i = (l - 1) * numel + thread_idx;
cur = out[i];
row_start = rowptr[cur], row_end = rowptr[cur + 1];
out[l * numel + thread_idx] =
    col[row_start + int64_t(rand[i] * (row_end - row_start))];
}
}
}
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat));
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());
auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
out.data_ptr<int64_t>(), walk_length, start.numel());
return out.t().contiguous();
}
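A hypothetical usage sketch of random_walk_cuda (the toy tensors and `main` are assumptions; the declaration mirrors rw_cuda.h, and linking against the compiled extension is assumed). Note the layout choice: the kernel writes walks as [walk_length + 1, num_walks] so that, at each step, consecutive threads write consecutive addresses (coalesced), and the final `out.t().contiguous()` restores the row-per-walk layout.

```cpp
// Hypothetical caller for random_walk_cuda; the setup code is illustrative.
#include <torch/torch.h>

// Declaration as in rw_cuda.h.
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
                               torch::Tensor start, int64_t walk_length,
                               double p, double q);

int main() {
  auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
  // Toy graph in CSR form: 0 -> {1, 2}, 1 -> {0}, 2 -> {0, 1}.
  auto rowptr = torch::tensor({0, 2, 3, 5}, opts);
  auto col = torch::tensor({1, 2, 0, 0, 1}, opts);
  auto start = torch::tensor({0, 1, 2}, opts);

  // p and q are accepted but ignored by the uniform kernel in this commit.
  auto walks = random_walk_cuda(rowptr, col, start, /*walk_length=*/4,
                                /*p=*/1.0, /*q=*/1.0);
  // walks: shape [3, 5] -- one row per start node, walk_length + 1 columns.
  return 0;
}
```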
#pragma once
#include <torch/extension.h>
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length,
double p, double q);
#include <ATen/ATen.h>
#include "compat.cuh"
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void uniform_rw_kernel(
const int64_t *__restrict__ row, const int64_t *__restrict__ col,
const int64_t *__restrict__ deg, const int64_t *__restrict__ start,
const float *__restrict__ rand, int64_t *__restrict__ out,
const size_t walk_length, const size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t n = index; n < numel; n += stride) {
out[n] = start[n];
for (ptrdiff_t l = 1; l <= walk_length; l++) {
auto i = (l - 1) * numel + n;
auto cur = out[i];
out[l * numel + n] = col[row[cur] + int64_t(rand[i] * deg[cur])];
}
}
}
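Worth noting for the comparison: this old kernel uses a grid-stride loop, while the new uniform_random_walk_kernel above maps one thread per walk and exits early. A side-by-side sketch of the two mappings (illustrative bodies, not from the commit):

```cuda
// Illustrative contrast of the two thread-mapping styles in this diff.
#include <cstddef>
#include <cstdint>

// New style: one thread per item; launch ceil(numel / THREADS) blocks.
__global__ void per_item(float *x, int64_t numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel)
    x[i] *= 2.f;
}

// Old style: grid-stride loop; correct for any launched grid size.
__global__ void grid_stride(float *x, int64_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride)
    x[i] *= 2.f;
}
```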
at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
size_t walk_length, float p, float q, size_t num_nodes) {
cudaSetDevice(row.get_device());
auto deg = degree(row, num_nodes);
row = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
auto rand = at::rand({(int64_t)walk_length, start.size(0)},
start.options().dtype(at::kFloat));
auto out =
at::full({(int64_t)walk_length + 1, start.size(0)}, -1, start.options());
uniform_rw_kernel<<<BLOCKS(start.numel()), THREADS>>>(
row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), deg.DATA_PTR<int64_t>(),
start.DATA_PTR<int64_t>(), rand.DATA_PTR<float>(),
out.DATA_PTR<int64_t>(), walk_length, start.numel());
return out.t().contiguous();
}
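The deleted file built the rowptr on the fly from a COO row vector via degree counting plus a cumulative sum. A minimal sketch of that conversion, assuming `degree(row, num_nodes)` counts the occurrences of each node in `row` (`rowptr_from_row` is an illustrative name):

```cpp
// Sketch of the COO-row -> CSR-rowptr conversion the old rw_cuda performed.
#include <torch/torch.h>

torch::Tensor rowptr_from_row(torch::Tensor row, int64_t num_nodes) {
  // Assumed equivalent of the degree() helper from utils.cuh: count
  // occurrences of each source node (row must be a 1-D kLong tensor,
  // with edges grouped by source node as the old kernel required).
  auto deg = torch::zeros({num_nodes}, row.options())
                 .scatter_add_(0, row, torch::ones_like(row));
  // Prepend a zero so rowptr[v] .. rowptr[v + 1] spans node v's edges.
  return torch::cat({torch::zeros({1}, deg.options()), deg.cumsum(0)}, 0);
}
```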