"...text-generation-inference.git" did not exist on "0d96468ebb1ca0141d7a23b2fdfcef9a7ef7bb81"
Commit 62891fa0 authored by rusty1s's avatar rusty1s
Browse files

arange interleave implementation within PyTorch

parent e61e3d45
#include <torch/extension.h>
#include "compat.h"
// Build a concatenation of integer ranges: for each i, append the sequence
// [start[i], start[i]+1, ..., start[i]+repeat[i]-1] to the output.
//
// Args:
//   start:  1-D tensor of range starting values (any type dispatched by
//           AT_DISPATCH_ALL_TYPES).
//   repeat: 1-D int64 tensor of per-range lengths; assumed same length as
//           `start` (the Python caller asserts this — not re-checked here).
//
// Returns: 1-D tensor of length repeat.sum() with `start`'s dtype/options.
at::Tensor arange_interleave(at::Tensor start, at::Tensor repeat) {
  // Total number of output elements. `sum()` of an int64 tensor is int64,
  // so read it back as such.
  auto count = repeat.sum().DATA_PTR<int64_t>()[0];
  auto out = at::empty(count, start.options());
  auto repeat_data = repeat.DATA_PTR<int64_t>();
  AT_DISPATCH_ALL_TYPES(start.scalar_type(), "arange_interleave", [&] {
    auto start_data = start.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();
    // int64_t counters throughout: `count` and size(0) are int64, and a
    // plain `int` would overflow for outputs beyond INT_MAX elements.
    int64_t i = 0;
    for (int64_t start_idx = 0; start_idx < start.size(0); start_idx++) {
      scalar_t init = start_data[start_idx];
      // The loop counter must be int64_t, not scalar_t: with a narrow
      // integer dtype (e.g. int8) a repeat count larger than the dtype's
      // range would overflow the counter and loop incorrectly.
      for (int64_t rep_idx = 0; rep_idx < repeat_data[start_idx]; rep_idx++) {
        out_data[i] = init + (scalar_t)rep_idx;
        i++;
      }
    }
  });
  return out;
}
// Register the extension module so Python code can call
// `arange_interleave(start, repeat)` on this CPU build.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("arange_interleave", &arange_interleave, "Arange Interleave (CPU)");
}
...@@ -12,11 +12,11 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): ...@@ -12,11 +12,11 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
extra_compile_args += ['-DVERSION_GE_1_3'] extra_compile_args += ['-DVERSION_GE_1_3']
ext_modules = [ ext_modules = [
CppExtension('torch_sparse.arange_interleave_cpu', CppExtension(
['cpu/arange_interleave.cpp'], 'torch_sparse.spspmm_cpu',
extra_compile_args=extra_compile_args), ['cpu/spspmm.cpp'],
CppExtension('torch_sparse.spspmm_cpu', ['cpu/spspmm.cpp'], extra_compile_args=extra_compile_args,
extra_compile_args=extra_compile_args), ),
] ]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension} cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
...@@ -33,17 +33,22 @@ if CUDA_HOME is not None and GPU: ...@@ -33,17 +33,22 @@ if CUDA_HOME is not None and GPU:
extra_link_args = ['-lcusparse', '-l', 'cusparse'] extra_link_args = ['-lcusparse', '-l', 'cusparse']
ext_modules += [ ext_modules += [
CUDAExtension('torch_sparse.spmm_cuda', CUDAExtension(
['cuda/spmm.cpp', 'cuda/spmm_kernel.cu'], 'torch_sparse.spmm_cuda',
extra_link_args=extra_link_args, ['cuda/spmm.cpp', 'cuda/spmm_kernel.cu'],
extra_compile_args=extra_compile_args), extra_compile_args=extra_compile_args,
CUDAExtension('torch_sparse.spspmm_cuda', ),
['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'], CUDAExtension(
extra_link_args=extra_link_args, 'torch_sparse.spspmm_cuda',
extra_compile_args=extra_compile_args), ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
CUDAExtension('torch_sparse.unique_cuda', extra_link_args=extra_link_args,
['cuda/unique.cpp', 'cuda/unique_kernel.cu'], extra_compile_args=extra_compile_args,
extra_compile_args=extra_compile_args), ),
CUDAExtension(
'torch_sparse.unique_cuda',
['cuda/unique.cpp', 'cuda/unique_kernel.cu'],
extra_compile_args=extra_compile_args,
),
] ]
__version__ = '0.4.3' __version__ = '0.4.3'
......
import torch import torch
from torch_sparse.storage import get_layout from torch_sparse.storage import get_layout
import torch_sparse.arange_interleave_cpu as arange_interleave_cpu
def arange_interleave(start, repeat):
assert start.device == repeat.device
assert repeat.dtype == torch.long
assert start.dim() == 1
assert repeat.dim() == 1
assert start.numel() == repeat.numel()
if start.is_cuda:
raise NotImplementedError
return arange_interleave_cpu.arange_interleave(start, repeat)
def index_select(src, dim, idx): def index_select(src, dim, idx):
dim = src.dim() + dim if dim < 0 else dim dim = src.dim() + dim if dim < 0 else dim
assert idx.dim() == 1 assert idx.dim() == 1
idx = idx.to(src.device)
if dim == 0: if dim == 0:
(_, col), value = src.coo() (row, col), value = src.coo()
rowcount = src.storage.rowcount rowcount = src.storage.rowcount
rowptr = src.storage.rowptr old_rowptr = src.storage.rowptr
rowcount = rowcount[idx] rowcount = rowcount[idx]
tmp = torch.arange(rowcount.size(0), device=rowcount.device) tmp = torch.arange(rowcount.size(0), device=rowcount.device)
row = tmp.repeat_interleave(rowcount) row = tmp.repeat_interleave(rowcount)
perm = arange_interleave(rowptr[idx], rowcount)
# Creates an "arange interleave" tensor of col indices.
rowptr = torch.cat([row.new_zeros(1), rowcount.cumsum(0)], dim=0)
perm = torch.arange(row.size(0), device=row.device)
perm += (old_rowptr[idx] - rowptr[:-1])[row]
col = col[perm] col = col[perm]
index = torch.stack([row, col], dim=0) index = torch.stack([row, col], dim=0)
...@@ -38,17 +30,23 @@ def index_select(src, dim, idx): ...@@ -38,17 +30,23 @@ def index_select(src, dim, idx):
sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)]) sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
storage = src.storage.__class__( storage = src.storage.__class__(index, value, sparse_size,
index, value, sparse_size, rowcount=rowcount, is_sorted=True) rowcount=rowcount, rowptr=rowptr,
is_sorted=True)
elif dim == 1: elif dim == 1:
colptr, row, value = src.csc() old_colptr, row, value = src.csc()
colcount = src.storage.colcount colcount = src.storage.colcount
colcount = colcount[idx] colcount = colcount[idx]
tmp = torch.arange(colcount.size(0), device=row.device) tmp = torch.arange(colcount.size(0), device=row.device)
col = tmp.repeat_interleave(colcount) col = tmp.repeat_interleave(colcount)
perm = arange_interleave(colptr[idx], colcount)
# Creates an "arange interleave" tensor of row indices.
colptr = torch.cat([col.new_zeros(1), colcount.cumsum(0)], dim=0)
perm = torch.arange(col.size(0), device=col.device)
perm += (old_colptr[idx] - colptr[:-1])[col]
row = row[perm] row = row[perm]
csc2csr = (colcount.size(0) * row + col).argsort() csc2csr = (colcount.size(0) * row + col).argsort()
index = torch.stack([row, col], dim=0)[:, csc2csr] index = torch.stack([row, col], dim=0)[:, csc2csr]
...@@ -58,17 +56,13 @@ def index_select(src, dim, idx): ...@@ -58,17 +56,13 @@ def index_select(src, dim, idx):
sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)]) sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])
storage = src.storage.__class__( storage = src.storage.__class__(index, value, sparse_size,
index, colcount=colcount, csc2csr=csc2csr,
value, is_sorted=True)
sparse_size,
colcount=colcount,
csc2csr=csc2csr,
is_sorted=True)
else: else:
storage = src.storage.apply_value(lambda x: x.index_select( storage = src.storage.apply_value(
dim - 1, idx)) lambda x: x.index_select(dim - 1, idx))
return src.from_storage(storage) return src.from_storage(storage)
...@@ -86,7 +80,7 @@ def index_select_nnz(src, idx, layout=None): ...@@ -86,7 +80,7 @@ def index_select_nnz(src, idx, layout=None):
value = value[idx] value = value[idx]
# There is no other information we can maintain... # There is no other information we can maintain...
storage = src.storage.__class__( storage = src.storage.__class__(index, value, src.sparse_size(),
index, value, src.sparse_size(), is_sorted=True) is_sorted=True)
return src.from_storage(storage) return src.from_storage(storage)
...@@ -487,23 +487,36 @@ if __name__ == '__main__': ...@@ -487,23 +487,36 @@ if __name__ == '__main__':
import time # noqa import time # noqa
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
# dataset = Reddit('/tmp/Reddit') # dataset = Reddit('/tmp/Reddit')
dataset = Planetoid('/tmp/Cora', 'Cora') dataset = Planetoid('/tmp/PubMed', 'PubMed')
data = dataset[0].to(device) data = dataset[0].to(device)
value = torch.randn(data.num_edges, 10) # value = torch.randn(data.num_edges, 10)
mat = SparseTensor(data.edge_index, value) mat = SparseTensor(data.edge_index)
perm = torch.arange(data.num_nodes)
perm = torch.randperm(data.num_nodes)
index = torch.tensor([ for _ in range(10):
[0, 1, 1, 2, 2], x = torch.randn(1000, 1000, device=device).sum()
[1, 2, 2, 2, 3],
])
value = torch.tensor([1, 2, 3, 4, 5])
mat = SparseTensor(index, value) torch.cuda.synchronize()
print(mat) t = time.perf_counter()
print(mat.coalesce()) for _ in range(100):
mat[perm]
torch.cuda.synchronize()
print(time.perf_counter() - t)
# index = torch.tensor([
# [0, 1, 1, 2, 2],
# [1, 2, 2, 2, 3],
# ])
# value = torch.tensor([1, 2, 3, 4, 5])
# mat = SparseTensor(index, value)
# print(mat)
# print(mat.coalesce())
# index = torch.tensor([0, 1, 2]) # index = torch.tensor([0, 1, 2])
# mask = torch.zeros(data.num_nodes, dtype=torch.bool) # mask = torch.zeros(data.num_nodes, dtype=torch.bool)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment