Commit 2dd14df1 authored by rusty1s
Browse files

return edge indices on GPU

parent f1628ea0
......@@ -31,23 +31,20 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, start.numel(), grain_size, [&](int64_t b, int64_t e) {
for (auto n = b; n < e; n++) {
auto n_cur = start_data[n];
int64_t e_cur = -1;
auto offset = n * (walk_length + 1);
n_out_data[offset] = n_cur;
int64_t n_cur = start_data[n], e_cur, row_start, row_end, rnd;
n_out_data[n * (walk_length + 1)] = n_cur;
int64_t row_start, row_end, rnd;
for (auto l = 0; l < walk_length; l++) {
row_start = rowptr_data[n_cur], row_end = rowptr_data[n_cur + 1];
if (row_end - row_start == 0) {
n_cur = n;
e_cur = -1;
} else {
rnd = int64_t(rand_data[n * walk_length + l] * (row_end - row_start));
e_cur = row_start + rnd;
n_cur = col_data[e_cur];
}
n_out_data[offset + l + 1] = n_cur;
n_out_data[n * (walk_length + 1) + (l + 1)] = n_cur;
e_out_data[n * walk_length + l] = e_cur;
}
}
......
......@@ -10,25 +10,27 @@
// NOTE(review): this span is a diff hunk whose +/- markers were lost, so
// removed and added lines are interleaved and it does not compile as shown.
// Judging by the hunk header (-10,25 +10,27) and the names used, lines that
// mention `out`, `cur` or `i` look like the pre-change version and lines that
// mention `n_out`, `e_out`, `n_cur` or `e_cur` look like the post-change
// version — confirm against the repository.
//
// Per-thread random walk: each thread with thread_idx < numel starts at
// start[thread_idx] and takes walk_length steps. At every step the current
// node's row is the range [rowptr[node], rowptr[node + 1]) of `col`
// (presumably a CSR adjacency — verify against the caller).
// In the post-change lines the node reached after step l is written to
// n_out[(l + 1) * numel + thread_idx] (slot 0 holds the start node), and the
// chosen index into `col` is written to e_out[l * numel + thread_idx], or -1
// when the current node's row is empty.
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
const int64_t *col,
const int64_t *start,
const float *rand, int64_t *out,
int64_t walk_length, int64_t numel) {
const float *rand, int64_t *n_out,
int64_t *e_out, int64_t walk_length,
int64_t numel) {
// One walk per thread; threads whose index is >= numel do nothing.
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
out[thread_idx] = start[thread_idx];
int64_t n_cur = start[thread_idx], e_cur, row_start, row_end, rnd;
int64_t row_start, row_end, i, cur;
for (int64_t l = 1; l <= walk_length; l++) {
i = (l - 1) * numel + thread_idx;
cur = out[i];
row_start = rowptr[cur], row_end = rowptr[cur + 1];
// Slot 0 of this thread's output column is the start node.
n_out[thread_idx] = n_cur;
for (int64_t l = 0; l < walk_length; l++) {
row_start = rowptr[n_cur], row_end = rowptr[n_cur + 1];
// Empty row: there is no entry of `col` to follow, so the walk stays on the
// current node.
if (row_end - row_start == 0) {
out[l * numel + thread_idx] = cur;
// -1 is stored in e_out for a step that followed no edge.
e_cur = -1;
} else {
out[l * numel + thread_idx] =
col[row_start + int64_t(rand[i] * (row_end - row_start))];
// Scale the random value by the row length and truncate to get an offset
// into the row; assumes rand values lie in [0, 1) so the offset stays
// inside the row — TODO confirm.
rnd = int64_t(rand[l * numel + thread_idx] * (row_end - row_start));
e_cur = row_start + rnd;
n_cur = col[e_cur];
}
// Outputs are indexed step-first: element (step, thread) lives at
// step * numel + thread_idx.
n_out[(l + 1) * numel + thread_idx] = n_cur;
e_out[l * numel + thread_idx] = e_cur;
}
}
}
......@@ -47,13 +49,16 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat));
auto out = torch::empty({walk_length + 1, start.size(0)}, start.options());
auto n_out = torch::empty({walk_length + 1, start.size(0)}, start.options());
auto e_out = torch::empty({walk_length, start.size(0)}, start.options());
auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
out.data_ptr<int64_t>(), walk_length, start.numel());
n_out.data_ptr<int64_t>(), e_out.data_ptr<int64_t>(), walk_length,
start.numel());
return out.t().contiguous();
return n_out.t().contiguous();
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment