Commit af2325bb authored by rusty1s
Browse files

bugfixes

parent b5aa7bc0
...@@ -91,6 +91,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, ...@@ -91,6 +91,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor> torch::Tensor>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount, padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
torch::Tensor binptr) { torch::Tensor binptr) {
// TODO: Add checks
cudaSetDevice(rowcount.get_device()); cudaSetDevice(rowcount.get_device());
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
...@@ -148,9 +150,9 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src, ...@@ -148,9 +150,9 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
int64_t lane_idx = thread_idx % F; int64_t lane_idx = thread_idx % F;
int64_t index_idx = __ldg(index + row_idx); int64_t index_idx = __ldg(index + row_idx);
scalar_tmp = fill_value; scalar_t tmp = fill_value;
if (index_idx != -1) { if (index_idx != -1) {
tmp = src[__ldg(col + index_idx) + lane_idx]; tmp = src[__ldg(col + index_idx) * F + lane_idx];
} }
out[thread_idx] = tmp; out[thread_idx] = tmp;
...@@ -160,6 +162,8 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src, ...@@ -160,6 +162,8 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col, torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
torch::Tensor index, torch::Tensor index,
torch::Tensor fill_value) { torch::Tensor fill_value) {
// TODO: Add checks
cudaSetDevice(src.get_device()); cudaSetDevice(src.get_device());
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
......
...@@ -19,7 +19,7 @@ padded_index(torch::Tensor rowptr, torch::Tensor rowcount, ...@@ -19,7 +19,7 @@ padded_index(torch::Tensor rowptr, torch::Tensor rowcount,
torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor col, torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor col,
torch::Tensor index, torch::Tensor index,
torch::Tensor fill_value) { torch::Tensor fill_value) {
return padded_index_select(src, col, index, fill_value); return padded_index_select_cuda(src, col, index, fill_value);
} }
static auto registry = static auto registry =
......
...@@ -12,6 +12,28 @@ def test_padded_index_select(device): ...@@ -12,6 +12,28 @@ def test_padded_index_select(device):
start = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3])
col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
adj = SparseTensor(row=row, col=col).to(device)
binptr = torch.tensor([0, 3, 5], device=device)
idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
adj.storage.rowptr(), adj.storage.rowcount(), binptr)
print(size)
print(length)
print(offset)
print(idx)
print(mask)
x = torch.tensor([[0], [1], [2], [3]], dtype=torch.float, device=device)
out = torch.ops.torch_sparse.padded_index_select(x, adj.storage.col(), idx,
torch.tensor(0.))
print(out)
dataset = Planetoid('/tmp/Planetoid', name='PubMed') dataset = Planetoid('/tmp/Planetoid', name='PubMed')
data = dataset[0] data = dataset[0]
row, col = data.edge_index.to(device) row, col = data.edge_index.to(device)
...@@ -43,7 +65,7 @@ def test_padded_index_select(device): ...@@ -43,7 +65,7 @@ def test_padded_index_select(device):
print(mask[:10]) print(mask[:10])
print(idx[:10]) print(idx[:10])
x = torch.randn(data.num_nodes, 128).to(device) x = torch.randn(data.num_nodes, 256).to(device)
for i in range(110): for i in range(110):
if i == 10: if i == 10:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment