Commit f1628ea0 authored by rusty1s's avatar rusty1s
Browse files

return edge indices

parent e68dcf3b
...@@ -18,35 +18,40 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -18,35 +18,40 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto rand = torch::rand({start.size(0), walk_length}, auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat)); start.options().dtype(torch::kFloat));
auto out = torch::empty({start.size(0), walk_length + 1}, start.options()); auto n_out = torch::empty({start.size(0), walk_length + 1}, start.options());
auto e_out = torch::empty({start.size(0), walk_length}, start.options());
auto rowptr_data = rowptr.data_ptr<int64_t>(); auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>(); auto col_data = col.data_ptr<int64_t>();
auto start_data = start.data_ptr<int64_t>(); auto start_data = start.data_ptr<int64_t>();
auto rand_data = rand.data_ptr<float>(); auto rand_data = rand.data_ptr<float>();
auto out_data = out.data_ptr<int64_t>(); auto n_out_data = n_out.data_ptr<int64_t>();
auto e_out_data = e_out.data_ptr<int64_t>();
int64_t grain_size = at::internal::GRAIN_SIZE / walk_length; int64_t grain_size = at::internal::GRAIN_SIZE / walk_length;
at::parallel_for(0, start.numel(), grain_size, [&](int64_t b, int64_t e) { at::parallel_for(0, start.numel(), grain_size, [&](int64_t b, int64_t e) {
for (auto n = b; n < e; n++) { for (auto n = b; n < e; n++) {
auto cur = start_data[n]; auto n_cur = start_data[n];
int64_t e_cur = -1;
auto offset = n * (walk_length + 1); auto offset = n * (walk_length + 1);
out_data[offset] = cur; n_out_data[offset] = n_cur;
int64_t row_start, row_end, rnd; int64_t row_start, row_end, rnd;
for (auto l = 1; l <= walk_length; l++) { for (auto l = 0; l < walk_length; l++) {
row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1]; row_start = rowptr_data[n_cur], row_end = rowptr_data[n_cur + 1];
if (row_end - row_start == 0) { if (row_end - row_start == 0) {
cur = n; n_cur = n;
e_cur = -1;
} else { } else {
rnd = int64_t(rand_data[n * walk_length + (l - 1)] * rnd = int64_t(rand_data[n * walk_length + l] * (row_end - row_start));
(row_end - row_start)); e_cur = row_start + rnd;
cur = col_data[row_start + rnd]; n_cur = col_data[e_cur];
} }
out_data[offset + l] = cur; n_out_data[offset + l + 1] = n_cur;
e_out_data[n * walk_length + l] = e_cur;
} }
} }
}); });
return out; return n_out;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment