Commit af2325bb authored by rusty1s's avatar rusty1s
Browse files

bugfixes

parent b5aa7bc0
......@@ -91,6 +91,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
torch::Tensor binptr) {
// TODO: Add checks
cudaSetDevice(rowcount.get_device());
auto stream = at::cuda::getCurrentCUDAStream();
size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
......@@ -148,9 +150,9 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
int64_t lane_idx = thread_idx % F;
int64_t index_idx = __ldg(index + row_idx);
scalar_tmp = fill_value;
scalar_t tmp = fill_value;
if (index_idx != -1) {
tmp = src[__ldg(col + index_idx) + lane_idx];
tmp = src[__ldg(col + index_idx) * F + lane_idx];
}
out[thread_idx] = tmp;
......@@ -160,6 +162,8 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
torch::Tensor index,
torch::Tensor fill_value) {
// TODO: Add checks
cudaSetDevice(src.get_device());
auto stream = at::cuda::getCurrentCUDAStream();
size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
......
......@@ -19,7 +19,7 @@ padded_index(torch::Tensor rowptr, torch::Tensor rowcount,
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor col,
torch::Tensor index,
torch::Tensor fill_value) {
return padded_index_select(src, col, index, fill_value);
return padded_index_select_cuda(src, col, index, fill_value);
}
static auto registry =
......
......@@ -12,6 +12,28 @@ def test_padded_index_select(device):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3])
col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
adj = SparseTensor(row=row, col=col).to(device)
binptr = torch.tensor([0, 3, 5], device=device)
idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
adj.storage.rowptr(), adj.storage.rowcount(), binptr)
print(size)
print(length)
print(offset)
print(idx)
print(mask)
x = torch.tensor([[0], [1], [2], [3]], dtype=torch.float, device=device)
out = torch.ops.torch_sparse.padded_index_select(x, adj.storage.col(), idx,
torch.tensor(0.))
print(out)
dataset = Planetoid('/tmp/Planetoid', name='PubMed')
data = dataset[0]
row, col = data.edge_index.to(device)
......@@ -43,7 +65,7 @@ def test_padded_index_select(device):
print(mask[:10])
print(idx[:10])
x = torch.randn(data.num_nodes, 128).to(device)
x = torch.randn(data.num_nodes, 256).to(device)
for i in range(110):
if i == 10:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment