Commit c2fc5a86 authored by rusty1s's avatar rusty1s
Browse files

gpu rejection sampling

parent 85b4e410
...@@ -58,7 +58,7 @@ void rejection_sampling(const int64_t *rowptr, const int64_t *col, ...@@ -58,7 +58,7 @@ void rejection_sampling(const int64_t *rowptr, const int64_t *col,
int64_t grain_size = at::internal::GRAIN_SIZE / walk_length; int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) { at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
for (auto n = begin; n < end; n++) { for (auto n = begin; n < end; n++) {
int64_t t = start[n], v, x, e_cur, row_start, row_end, idx; int64_t t = start[n], v, x, e_cur, row_start, row_end;
n_out[n * (walk_length + 1)] = t; n_out[n * (walk_length + 1)] = t;
...@@ -75,6 +75,7 @@ void rejection_sampling(const int64_t *rowptr, const int64_t *col, ...@@ -75,6 +75,7 @@ void rejection_sampling(const int64_t *rowptr, const int64_t *col,
for (auto l = 1; l < walk_length; l++) { for (auto l = 1; l < walk_length; l++) {
row_start = rowptr[v], row_end = rowptr[v + 1]; row_start = rowptr[v], row_end = rowptr[v + 1];
if (row_end - row_start == 0) { if (row_end - row_start == 0) {
e_cur = -1; e_cur = -1;
x = v; x = v;
...@@ -88,13 +89,11 @@ void rejection_sampling(const int64_t *rowptr, const int64_t *col, ...@@ -88,13 +89,11 @@ void rejection_sampling(const int64_t *rowptr, const int64_t *col,
auto r = ((double)rand() / (RAND_MAX)); // [0, 1) auto r = ((double)rand() / (RAND_MAX)); // [0, 1)
if (x == t) { if (x == t && r < prob_0)
if (r < prob_0)
break; break;
} else if (is_neighbor(rowptr, col, x, t)) { else if (is_neighbor(rowptr, col, x, t) && r < prob_1)
if (r < prob_1)
break; break;
} else if (r < prob_2) else if (r < prob_2)
break; break;
} }
} }
......
#include "rw_cuda.h" #include "rw_cuda.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <curand.h>
#include <curand_kernel.h>
#include "utils.cuh" #include "utils.cuh"
#define THREADS 1024 #define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS #define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void uniform_random_walk_kernel(const int64_t *rowptr, __global__ void uniform_sampling_kernel(const int64_t *rowptr,
const int64_t *col, const int64_t *col,
const int64_t *start, const int64_t *start, const float *rand,
const float *rand, int64_t *n_out, int64_t *n_out, int64_t *e_out,
int64_t *e_out, int64_t walk_length, const int64_t walk_length,
int64_t numel) { const int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) { if (thread_idx < numel) {
...@@ -35,6 +38,83 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr, ...@@ -35,6 +38,83 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
} }
} }
// Node2vec random walks via rejection sampling (Grover & Leskovec, 2016).
// One thread per walk. CSR graph in (rowptr, col); walks start at start[i].
// Output layout is column-major per step: n_out has (walk_length + 1) rows of
// `numel` nodes, e_out has walk_length rows of `numel` edge ids (-1 marks a
// dead end, i.e. the walk stays put). p is the return parameter, q the
// in-out parameter.
__global__ void
rejection_sampling_kernel(unsigned int seed, const int64_t *rowptr,
                          const int64_t *col, const int64_t *start,
                          int64_t *n_out, int64_t *e_out,
                          const int64_t walk_length, const int64_t numel,
                          const double p, const double q) {

  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Give every thread its own cuRAND sub-sequence. Using sequence 0 for all
  // threads (the previous behavior) makes every thread draw the identical
  // random stream, so all walks sample the same choices.
  curandState_t state;
  curand_init(seed, thread_idx, 0, &state);

  // Acceptance thresholds, normalized by the largest unnormalized weight:
  // prob_0 -> candidate returns to the previous node t,
  // prob_1 -> candidate is also a neighbor of t (distance 1),
  // prob_2 -> candidate moves away from t (distance 2).
  double max_prob = fmax(fmax(1. / p, 1.), 1. / q);
  double prob_0 = 1. / p / max_prob;
  double prob_1 = 1. / max_prob;
  double prob_2 = 1. / q / max_prob;

  if (thread_idx < numel) {
    int64_t t = start[thread_idx], v, x, e_cur, row_start, row_end;

    n_out[thread_idx] = t;

    // First step is a plain uniform neighbor sample (no previous node yet).
    row_start = rowptr[t], row_end = rowptr[t + 1];
    if (row_end - row_start == 0) {
      // Isolated node: stay in place, mark the edge as invalid.
      e_cur = -1;
      v = t;
    } else {
      e_cur = row_start + (curand(&state) % (row_end - row_start));
      v = col[e_cur];
    }
    n_out[numel + thread_idx] = v;
    e_out[thread_idx] = e_cur;

    for (int64_t l = 1; l < walk_length; l++) {
      row_start = rowptr[v], row_end = rowptr[v + 1];

      if (row_end - row_start == 0) {
        e_cur = -1;
        x = v;
      } else if (row_end - row_start == 1) {
        // Single neighbor: the proposal is always accepted.
        e_cur = row_start;
        x = col[e_cur];
      } else {
        // Rejection-sample a neighbor of v. NOTE: row_start/row_end must
        // keep describing v's adjacency range across iterations; the old
        // code clobbered them with x's range inside the neighbor check,
        // which made later proposals sample from the wrong node (and could
        // take `% 0` when x had no neighbors).
        while (true) {
          e_cur = row_start + (curand(&state) % (row_end - row_start));
          x = col[e_cur];

          double r = curand_uniform(&state); // (0, 1]

          if (x == t && r < prob_0)
            break;

          // Linear scan of x's adjacency list to test whether x is also a
          // neighbor of the previous node t.
          bool is_neighbor = false;
          for (int64_t i = rowptr[x]; i < rowptr[x + 1]; i++) {
            if (col[i] == t) {
              is_neighbor = true;
              break;
            }
          }
          if (is_neighbor && r < prob_1)
            break;
          else if (r < prob_2)
            break;
        }
      }

      n_out[(l + 1) * numel + thread_idx] = x;
      e_out[l * numel + thread_idx] = e_cur;
      t = v;
      v = x;
    }
  }
}
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
int64_t walk_length, double p, double q) { int64_t walk_length, double p, double q) {
...@@ -47,18 +127,26 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, ...@@ -47,18 +127,26 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start,
CHECK_INPUT(col.dim() == 1); CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1); CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat));
auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options()); auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options());
auto e_out = torch::empty({walk_length, start.size(0)}, start.options()); auto e_out = torch::empty({walk_length, start.size(0)}, start.options());
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
if (p == 1. && q == 1.) {
auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat));
uniform_sampling_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(), start.data_ptr<int64_t>(), rand.data_ptr<float>(),
n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(), walk_length, n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(), walk_length,
start.numel()); start.numel());
} else {
rejection_sampling_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
time(NULL), rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), n_out.data_ptr<int64_t>(),
e_out.data_ptr<int64_t>(), walk_length, start.numel(), p, q);
}
return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous()); return std::make_tuple(n_out.t().contiguous(), e_out.t().contiguous());
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment