Commit d1d4ec3c authored by rusty1s's avatar rusty1s
Browse files

parallel convert on CPU, fix bug for nnz=0

parent 99117398
#include "convert_cpu.h" #include "convert_cpu.h"
#include <ATen/Parallel.h>
#include "utils.h" #include "utils.h"
// Convert a sorted COO row-index vector `ind` into a CSR row-pointer
// vector of length M + 1, in parallel on CPU.
//
// Assumes `ind` is a non-decreasing int64 CPU tensor with values in
// [0, M) — TODO confirm against callers; validation (if any) lives in
// the hidden preamble upstream.
//
// Returns `out` such that out[j]..out[j + 1] delimits the entries of
// row j.
torch::Tensor ind2ptr_cpu(torch::Tensor ind, int64_t M) {
  auto out = torch::empty(M + 1, ind.options());
  auto ind_data = ind.data_ptr<int64_t>();
  auto out_data = out.data_ptr<int64_t>();

  int64_t numel = ind.numel();

  // nnz == 0: every row is empty, so the whole pointer vector is zero.
  // (This guard is the commit's bug fix — the old code read ind_data[0]
  // unconditionally.)
  if (numel == 0)
    return out.zero_();

  // Rows before (and including) the first populated row start at 0.
  for (int64_t i = 0; i <= ind_data[0]; i++)
    out_data[i] = 0;

  // Each chunk [begin, end) fills the pointer slots between its own first
  // index value and the first index value of the next chunk, so chunks
  // write disjoint regions and the scan parallelizes without locking.
  int64_t grain_size = at::internal::GRAIN_SIZE;
  at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
    int64_t idx = ind_data[begin], next_idx;
    for (int64_t i = begin; i < std::min(end, numel - 1); i++) {
      next_idx = ind_data[i + 1];
      for (; idx < next_idx; idx++)
        out_data[idx + 1] = i + 1;
    }
  });

  // Rows after the last populated row all end at numel.
  for (int64_t i = ind_data[numel - 1] + 1; i < M + 1; i++)
    out_data[i] = numel;

  return out;
}
...@@ -31,13 +40,18 @@ torch::Tensor ptr2ind_cpu(torch::Tensor ptr, int64_t E) { ...@@ -31,13 +40,18 @@ torch::Tensor ptr2ind_cpu(torch::Tensor ptr, int64_t E) {
auto ptr_data = ptr.data_ptr<int64_t>(); auto ptr_data = ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>(); auto out_data = out.data_ptr<int64_t>();
int64_t idx = ptr_data[0], next_idx; int64_t numel = ptr.numel();
for (int64_t i = 0; i < ptr.numel() - 1; i++) {
next_idx = ptr_data[i + 1]; int64_t grain_size = at::internal::GRAIN_SIZE;
for (int64_t e = idx; e < next_idx; e++) at::parallel_for(0, numel - 1, grain_size, [&](int64_t begin, int64_t end) {
out_data[e] = i; int64_t idx = ptr_data[begin], next_idx;
idx = next_idx; for (int64_t i = begin; i < end; i++) {
} next_idx = ptr_data[i + 1];
for (int64_t e = idx; e < next_idx; e++)
out_data[e] = i;
idx = next_idx;
}
});
return out; return out;
} }
...@@ -28,6 +28,10 @@ torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) { ...@@ -28,6 +28,10 @@ torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
cudaSetDevice(ind.get_device()); cudaSetDevice(ind.get_device());
auto out = torch::empty(M + 1, ind.options()); auto out = torch::empty(M + 1, ind.options());
if (ind.numel() == 0)
return out.zero_();
auto ind_data = ind.data_ptr<int64_t>(); auto ind_data = ind.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>(); auto out_data = out.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
......
...@@ -7,6 +7,23 @@ from torch_sparse.storage import SparseStorage ...@@ -7,6 +7,23 @@ from torch_sparse.storage import SparseStorage
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('device', devices)
def test_ind2ptr(device):
    # Round-trip COO row indices <-> CSR row pointers on every device.
    ind2ptr = torch.ops.torch_sparse.ind2ptr
    ptr2ind = torch.ops.torch_sparse.ptr2ind

    index = tensor([2, 2, 4, 5, 5, 6], torch.long, device)
    pointer = ind2ptr(index, 8)
    assert pointer.tolist() == [0, 0, 0, 2, 2, 3, 5, 6, 6]
    assert ptr2ind(pointer, 6).tolist() == [2, 2, 4, 5, 5, 6]

    # nnz == 0: the pointer vector is all zeros and the inverse is empty.
    index = tensor([], torch.long, device)
    pointer = ind2ptr(index, 8)
    assert pointer.tolist() == [0] * 9
    assert ptr2ind(pointer, 0).tolist() == []
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device): def test_storage(dtype, device):
row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device) row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment