Commit 5d12a68a authored by rusty1s's avatar rusty1s
Browse files

fix rw isolated nodes bug

parent 2bf5e763
...@@ -16,7 +16,7 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -16,7 +16,7 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto rand = torch::rand({start.size(0), walk_length}, auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat)); start.options().dtype(torch::kFloat));
auto out = torch::full({start.size(0), walk_length + 1}, -1, start.options()); auto out = torch::empty({start.size(0), walk_length + 1}, start.options());
auto rowptr_data = rowptr.data_ptr<int64_t>(); auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>(); auto col_data = col.data_ptr<int64_t>();
...@@ -29,12 +29,16 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -29,12 +29,16 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
auto offset = n * (walk_length + 1); auto offset = n * (walk_length + 1);
out_data[offset] = cur; out_data[offset] = cur;
int64_t row_start, row_end; int64_t row_start, row_end, rnd;
for (auto l = 1; l <= walk_length; l++) { for (auto l = 1; l <= walk_length; l++) {
row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1]; row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1];
if (row_end - row_start == 0) {
cur = col_data[row_start + int64_t(rand_data[n * walk_length + (l - 1)] * cur = n;
(row_end - row_start))]; } else {
rnd = int64_t(rand_data[n * walk_length + (l - 1)] *
(row_end - row_start));
cur = col_data[row_start + rnd];
}
out_data[offset + l] = cur; out_data[offset + l] = cur;
} }
} }
......
...@@ -23,8 +23,12 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr, ...@@ -23,8 +23,12 @@ __global__ void uniform_random_walk_kernel(const int64_t *rowptr,
cur = out[i]; cur = out[i];
row_start = rowptr[cur], row_end = rowptr[cur + 1]; row_start = rowptr[cur], row_end = rowptr[cur + 1];
out[l * numel + thread_idx] = if (row_end - row_start == 0) {
col[row_start + int64_t(rand[i] * (row_end - row_start))]; out[l * numel + thread_idx] = cur;
} else {
out[l * numel + thread_idx] =
col[row_start + int64_t(rand[i] * (row_end - row_start))];
}
} }
} }
} }
...@@ -43,7 +47,7 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, ...@@ -43,7 +47,7 @@ torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
auto rand = torch::rand({start.size(0), walk_length}, auto rand = torch::rand({start.size(0), walk_length},
start.options().dtype(torch::kFloat)); start.options().dtype(torch::kFloat));
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options()); auto out = torch::empty({walk_length + 1, start.size(0)}, start.options());
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>( uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
......
...@@ -12,7 +12,7 @@ def test_rw(device): ...@@ -12,7 +12,7 @@ def test_rw(device):
start = tensor([0, 1, 2, 3, 4], torch.long, device) start = tensor([0, 1, 2, 3, 4], torch.long, device)
walk_length = 10 walk_length = 10
out = random_walk(row, col, start, walk_length, coalesced=True) out = random_walk(row, col, start, walk_length)
assert out[:, 0].tolist() == start.tolist() assert out[:, 0].tolist() == start.tolist()
for n in range(start.size(0)): for n in range(start.size(0)):
...@@ -20,3 +20,11 @@ def test_rw(device): ...@@ -20,3 +20,11 @@ def test_rw(device):
for i in range(1, walk_length): for i in range(1, walk_length):
assert out[n, i].item() in col[row == cur].tolist() assert out[n, i].item() in col[row == cur].tolist()
cur = out[n, i].item() cur = out[n, i].item()
row = tensor([0, 1], torch.long, device)
col = tensor([1, 0], torch.long, device)
start = tensor([0, 1, 2], torch.long, device)
walk_length = 4
out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1], [1, 0, 1, 0], [2, 2, 2, 2]]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment