"src/vscode:/vscode.git/clone" did not exist on "6394d905da45236670570ae87803afd5c4cddb07"
Commit 62891fa0 authored by rusty1s
Browse files

arange interleave implementation within PyTorch

parent e61e3d45
#include <torch/extension.h>
#include "compat.h"
// Build a 1-D tensor that, for each row i, contains the consecutive run
//   start[i], start[i] + 1, ..., start[i] + repeat[i] - 1,
// with all runs concatenated in order (a fused arange + repeat_interleave).
//
// start:  1-D tensor of any type handled by AT_DISPATCH_ALL_TYPES (run
//         origins).
// repeat: 1-D tensor of run lengths, read through int64 pointers, so it
//         must be a torch.long tensor; assumed the same length as `start`
//         (the Python wrapper asserts this — not re-checked here).
// Returns a tensor of length repeat.sum() with `start`'s dtype/options.
at::Tensor arange_interleave(at::Tensor start, at::Tensor repeat) {
  // Total output length. DATA_PTR comes from compat.h and papers over the
  // data_ptr/data rename across PyTorch versions.
  auto count = repeat.sum().DATA_PTR<int64_t>()[0];
  auto out = at::empty(count, start.options());
  auto repeat_data = repeat.DATA_PTR<int64_t>();
  AT_DISPATCH_ALL_TYPES(start.scalar_type(), "arange_interleave", [&] {
    auto start_data = start.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();
    // Use 64-bit counters: `count` and start.size(0) are int64_t, so the
    // original 32-bit `int` indices could overflow on large inputs.
    int64_t i = 0;
    for (int64_t start_idx = 0; start_idx < start.size(0); start_idx++) {
      scalar_t init = start_data[start_idx];
      // rep_idx stays scalar_t so `init + rep_idx` adds in the output
      // dtype, exactly as before.
      for (scalar_t rep_idx = 0; rep_idx < repeat_data[start_idx]; rep_idx++) {
        out_data[i] = init + rep_idx;
        i++;
      }
    }
  });
  return out;
}
// Expose the CPU kernel to Python; imported as
// torch_sparse.arange_interleave_cpu (see setup.py).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("arange_interleave", &arange_interleave, "Arange Interleave (CPU)");
}
......@@ -12,11 +12,11 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
extra_compile_args += ['-DVERSION_GE_1_3']
# NOTE(review): this span is a merged diff view — 'torch_sparse.spspmm_cpu'
# appears TWICE below (the pre- and post-reformat spellings of the same
# CppExtension). A real setup.py must keep only one of them; verify against
# the actual commit before building.
ext_modules = [
CppExtension('torch_sparse.arange_interleave_cpu',
['cpu/arange_interleave.cpp'],
extra_compile_args=extra_compile_args),
CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'],
extra_compile_args=extra_compile_args),
CppExtension(
'torch_sparse.spspmm_cpu',
['cpu/spspmm.cpp'],
extra_compile_args=extra_compile_args,
),
]
# Standard hook so the C++ extensions build with torch's compiler settings.
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
......@@ -33,17 +33,22 @@ if CUDA_HOME is not None and GPU:
extra_link_args = ['-lcusparse', '-l', 'cusparse']
# NOTE(review): merged diff view — each CUDAExtension below appears twice
# (old single-line-call style, then reformatted multi-line style). Also note
# the two spmm_cuda variants differ: the old one passes extra_link_args, the
# new one does not — confirm which is intended before building.
ext_modules += [
CUDAExtension('torch_sparse.spmm_cuda',
['cuda/spmm.cpp', 'cuda/spmm_kernel.cu'],
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args),
CUDAExtension('torch_sparse.spspmm_cuda',
['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args),
CUDAExtension('torch_sparse.unique_cuda',
['cuda/unique.cpp', 'cuda/unique_kernel.cu'],
extra_compile_args=extra_compile_args),
CUDAExtension(
'torch_sparse.spmm_cuda',
['cuda/spmm.cpp', 'cuda/spmm_kernel.cu'],
extra_compile_args=extra_compile_args,
),
CUDAExtension(
'torch_sparse.spspmm_cuda',
['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
extra_link_args=extra_link_args,
extra_compile_args=extra_compile_args,
),
CUDAExtension(
'torch_sparse.unique_cuda',
['cuda/unique.cpp', 'cuda/unique_kernel.cu'],
extra_compile_args=extra_compile_args,
),
]
__version__ = '0.4.3'
......
import torch
from torch_sparse.storage import get_layout
import torch_sparse.arange_interleave_cpu as arange_interleave_cpu
def arange_interleave(start, repeat):
    """Return the concatenation of ``arange(start[i], start[i] + repeat[i])``
    over all ``i``, computed by the C++ CPU extension.

    Both arguments must be 1-D tensors of equal length on the same device,
    and ``repeat`` must hold ``torch.long`` counts.
    """
    # Validate the contract up front: matching devices, int64 counts, and
    # two 1-D tensors of the same length.
    assert start.device == repeat.device
    assert repeat.dtype == torch.long
    assert start.dim() == 1 and repeat.dim() == 1
    assert start.numel() == repeat.numel()
    # Only the CPU kernel exists so far.
    if start.is_cuda:
        raise NotImplementedError
    return arange_interleave_cpu.arange_interleave(start, repeat)
# NOTE(review): this function body is a merged diff view — old and new lines
# of the same hunk are interleaved without +/- markers, and raw `@@` hunk
# headers are embedded below. It is NOT runnable as shown (duplicate
# assignments, duplicate storage constructions). The comments below label
# which lines belong to the old vs. new revision; restore one consistent
# revision from the actual commit before using this code.
def index_select(src, dim, idx):
# Normalize a negative dim, presumably against src's dense+sparse dims
# — TODO confirm against SparseTensor.dim().
dim = src.dim() + dim if dim < 0 else dim
assert idx.dim() == 1
idx = idx.to(src.device)
if dim == 0:
# OLD revision line (row unused):
(_, col), value = src.coo()
# NEW revision line (row needed below? actually recomputed — verify):
(row, col), value = src.coo()
rowcount = src.storage.rowcount
# OLD name vs NEW name for the same pointer tensor:
rowptr = src.storage.rowptr
old_rowptr = src.storage.rowptr
rowcount = rowcount[idx]
# New row ids: 0..len(idx)-1, each repeated by its selected row's count.
tmp = torch.arange(rowcount.size(0), device=rowcount.device)
row = tmp.repeat_interleave(rowcount)
# OLD revision: permutation built by the C++ arange_interleave kernel.
perm = arange_interleave(rowptr[idx], rowcount)
# Creates an "arange interleave" tensor of col indices.
# NEW revision: same permutation built with pure PyTorch ops
# (cumsum-based rowptr + per-row offset correction).
rowptr = torch.cat([row.new_zeros(1), rowcount.cumsum(0)], dim=0)
perm = torch.arange(row.size(0), device=row.device)
perm += (old_rowptr[idx] - rowptr[:-1])[row]
col = col[perm]
index = torch.stack([row, col], dim=0)
......@@ -38,17 +30,23 @@ def index_select(src, dim, idx):
sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
# OLD storage construction (no rowptr cached):
storage = src.storage.__class__(
index, value, sparse_size, rowcount=rowcount, is_sorted=True)
# NEW storage construction (also caches the recomputed rowptr):
storage = src.storage.__class__(index, value, sparse_size,
rowcount=rowcount, rowptr=rowptr,
is_sorted=True)
elif dim == 1:
# Column selection mirrors the dim == 0 branch in CSC layout.
# OLD vs NEW naming of the colptr tensor:
colptr, row, value = src.csc()
old_colptr, row, value = src.csc()
colcount = src.storage.colcount
colcount = colcount[idx]
tmp = torch.arange(colcount.size(0), device=row.device)
col = tmp.repeat_interleave(colcount)
# OLD revision (C++ kernel):
perm = arange_interleave(colptr[idx], colcount)
# Creates an "arange interleave" tensor of row indices.
# NEW revision (pure PyTorch equivalent):
colptr = torch.cat([col.new_zeros(1), colcount.cumsum(0)], dim=0)
perm = torch.arange(col.size(0), device=col.device)
perm += (old_colptr[idx] - colptr[:-1])[col]
row = row[perm]
# Re-sort CSC-ordered entries into CSR (row-major) order.
csc2csr = (colcount.size(0) * row + col).argsort()
index = torch.stack([row, col], dim=0)[:, csc2csr]
......@@ -58,17 +56,13 @@ def index_select(src, dim, idx):
sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
# OLD formatting of the storage construction:
storage = src.storage.__class__(
index,
value,
sparse_size,
colcount=colcount,
csc2csr=csc2csr,
is_sorted=True)
# NEW formatting of the same construction:
storage = src.storage.__class__(index, value, sparse_size,
colcount=colcount, csc2csr=csc2csr,
is_sorted=True)
else:
# Dense-dimension select: delegate to value.index_select, shifting dim
# past the leading sparse dimension. OLD then NEW formatting:
storage = src.storage.apply_value(lambda x: x.index_select(
dim - 1, idx))
storage = src.storage.apply_value(
lambda x: x.index_select(dim - 1, idx))
return src.from_storage(storage)
......@@ -86,7 +80,7 @@ def index_select_nnz(src, idx, layout=None):
value = value[idx]
# There is no other information we can maintain...
storage = src.storage.__class__(
index, value, src.sparse_size(), is_sorted=True)
storage = src.storage.__class__(index, value, src.sparse_size(),
is_sorted=True)
return src.from_storage(storage)
......@@ -487,23 +487,36 @@ if __name__ == '__main__':
# NOTE(review): ad-hoc benchmark body of the `if __name__ == '__main__':`
# guard (the guard itself is only visible in the hunk header above). This is
# also a merged diff view: several statements appear in both their old and
# new form back to back (two `dataset = ...`, two `perm = ...`, two
# `mat = ...`). Not runnable as shown; keep only one of each pair.
import time  # noqa
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
# dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/Cora', 'Cora')
dataset = Planetoid('/tmp/PubMed', 'PubMed')
data = dataset[0].to(device)
value = torch.randn(data.num_edges, 10)
mat = SparseTensor(data.edge_index, value)
# value = torch.randn(data.num_edges, 10)
mat = SparseTensor(data.edge_index)
perm = torch.arange(data.num_nodes)
perm = torch.randperm(data.num_nodes)
index = torch.tensor([
[0, 1, 1, 2, 2],
[1, 2, 2, 2, 3],
])
value = torch.tensor([1, 2, 3, 4, 5])
# Warm-up: presumably to trigger CUDA context init before timing — confirm.
for _ in range(10):
x = torch.randn(1000, 1000, device=device).sum()
mat = SparseTensor(index, value)
print(mat)
print(mat.coalesce())
# Time 100 row-selections of the sparse matrix.
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
mat[perm]
torch.cuda.synchronize()
print(time.perf_counter() - t)
# index = torch.tensor([
# [0, 1, 1, 2, 2],
# [1, 2, 2, 2, 3],
# ])
# value = torch.tensor([1, 2, 3, 4, 5])
# mat = SparseTensor(index, value)
# print(mat)
# print(mat.coalesce())
# index = torch.tensor([0, 1, 2])
# mask = torch.zeros(data.num_nodes, dtype=torch.bool)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment