Commit 925f9567 authored by rusty1s
Browse files

cat and dim fix

parent 26aee002
import time
from itertools import product
from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add
from .utils import dtypes, devices, tensor
devices = ['cpu']
dtypes = [torch.float]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_add(dtype, device):
name = ('DIMACS10', 'citationCiteseer')[1]
mat_scipy = loadmat(f'benchmark/{name}.mat')['Problem'][0][0][2].tocsr()
mat = SparseTensor.from_scipy(mat_scipy)
mat1 = mat[:, 0:100000]
mat2 = mat[:, 100000:200000]
# print(mat1.shape)
# print(mat2.shape)
# 0.0159 to beat
t = time.perf_counter()
mat = sparse_add(mat1, mat2)
# print(time.perf_counter() - t)
# print(mat.nnz())
mat1 = mat_scipy[:, 0:100000]
mat2 = mat_scipy[:, 100000:200000]
t = time.perf_counter()
mat = mat1 + mat2
# print(time.perf_counter() - t)
# print(mat.nnz)
# mat1 + mat2
# mat1 = mat1.tocoo()
# mat2 = mat2.tocoo()
# row1, col1 = mat1.row, mat1.col
# row2, col2 = mat2.row, mat2.col
# idx1 = row1 * 100000 + col1
# idx2 = row2 * 100000 + col2
# t = time.perf_counter()
# np.union1d(idx1, idx2)
# print(time.perf_counter() - t)
# index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
# mat1 = SparseTensor(index)
# print()
# print(mat1.to_dense())
# index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
# mat2 = SparseTensor(index)
# print(mat2.to_dense())
# add(mat1, mat2)
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat from torch_sparse.cat import cat, cat_diag
from .utils import devices, tensor from .utils import devices, tensor
...@@ -21,29 +21,27 @@ def test_cat(device): ...@@ -21,29 +21,27 @@ def test_cat(device):
[0, 1, 0], [1, 0, 0]] [0, 1, 0], [1, 0, 0]]
assert out.storage.has_row() assert out.storage.has_row()
assert out.storage.has_rowptr() assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 1
assert out.storage.has_rowcount() assert out.storage.has_rowcount()
assert out.storage.num_cached_keys() == 1
out = cat([mat1, mat2], dim=1) out = cat([mat1, mat2], dim=1)
assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1], assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]] [0, 0, 0, 1, 0]]
assert out.storage.has_row() assert out.storage.has_row()
assert not out.storage.has_rowptr() assert not out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 2 assert out.storage.num_cached_keys() == 2
assert out.storage.has_colcount()
assert out.storage.has_colptr()
out = cat([mat1, mat2], dim=(0, 1)) out = cat_diag([mat1, mat2])
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0], assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]] [0, 0, 0, 1, 0]]
assert out.storage.has_row() assert out.storage.has_row()
assert out.storage.has_rowptr() assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 5 assert out.storage.num_cached_keys() == 5
mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device)) mat1 = mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
out = cat([mat1, mat1], dim=-1) out = cat([mat1, mat1], dim=-1)
assert out.storage.value.size() == (mat1.nnz(), 8) assert out.storage.value().size() == (mat1.nnz(), 8)
assert out.storage.has_row() assert out.storage.has_row()
assert out.storage.has_rowptr() assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 5 assert out.storage.num_cached_keys() == 5
import time
from itertools import product
from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add
from .utils import dtypes, devices, tensor
devices = ['cpu']
dtypes = [torch.float]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_index_select(dtype, device):
    """index_select along dim 0 keeps exactly the requested rows.

    Replaces the previous print()-based inspection (which asserted nothing
    and contained a stray ``pass``) with real assertions on the dense views.
    """
    row = torch.tensor([0, 0, 1, 1, 2])
    col = torch.tensor([0, 1, 1, 2, 1])
    mat = SparseTensor(row=row, col=col)
    # With no explicit value tensor, nonzeros render as ones in the dense
    # view (matches the convention used by the other tests in this suite).
    assert mat.to_dense().tolist() == [[1, 1, 0], [0, 1, 1], [0, 1, 0]]

    out = mat.index_select(0, torch.tensor([0, 2]))
    assert out.to_dense().tolist() == [[1, 1, 0], [0, 1, 0]]
from itertools import product
import pytest
import torch
from torch_sparse import rowptr_cpu
from .utils import tensor, devices
if torch.cuda.is_available():
from torch_sparse import rowptr_cuda
tests = [
{
'row': [0, 0, 1, 1, 1, 2, 2],
'size': 5,
'rowptr': [0, 2, 5, 7, 7, 7],
},
{
'row': [0, 0, 1, 1, 1, 4, 4],
'size': 5,
'rowptr': [0, 2, 5, 5, 5, 7],
},
{
'row': [2, 2, 4, 4],
'size': 7,
'rowptr': [0, 0, 0, 2, 2, 4, 4, 4],
},
]
def rowptr(row, size):
    """Build a CSR row pointer from *row* via the device-matching kernel."""
    impl = rowptr_cuda if row.is_cuda else rowptr_cpu
    return impl.rowptr(row, size)
@pytest.mark.parametrize('test,device', product(tests, devices))
def test_rowptr(test, device):
    """The computed row pointer matches the precomputed expectation."""
    row = tensor(test['row'], torch.long, device)
    expected = tensor(test['rowptr'], torch.long, device)
    result = rowptr(row, test['size'])
    assert torch.all(result == expected)
...@@ -3,7 +3,7 @@ from itertools import product ...@@ -3,7 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch_sparse.storage import SparseStorage, no_cache from torch_sparse.storage import SparseStorage
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
...@@ -79,17 +79,6 @@ def test_caching(dtype, device): ...@@ -79,17 +79,6 @@ def test_caching(dtype, device):
assert storage._csr2csc is None assert storage._csr2csc is None
assert storage.cached_keys() == [] assert storage.cached_keys() == []
with no_cache():
storage.fill_cache_()
assert storage.cached_keys() == []
@no_cache()
def do_something(storage):
return storage.fill_cache_()
storage = do_something(storage)
assert storage.cached_keys() == []
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device): def test_utility(dtype, device):
......
import torch
# Load the compiled extension libraries before any Python-side wrappers use
# the registered ops.  CPU kernels are mandatory; CUDA kernels are optional.
torch.ops.load_library('torch_sparse/convert_cpu.so')
torch.ops.load_library('torch_sparse/diag_cpu.so')
torch.ops.load_library('torch_sparse/spmm_cpu.so')
try:
    torch.ops.load_library('torch_sparse/convert_cuda.so')
    torch.ops.load_library('torch_sparse/diag_cuda.so')
    torch.ops.load_library('torch_sparse/spmm_cuda.so')
    torch.ops.load_library('torch_sparse/spspmm_cuda.so')
except OSError:
    # Missing CUDA libraries are fine on CPU-only installs, but if a GPU is
    # available their absence is a packaging error.  Bare `raise` re-raises
    # the active exception with its original traceback (was `raise e`).
    if torch.cuda.is_available():
        raise
from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
from .coalesce import coalesce from .coalesce import coalesce
from .transpose import transpose from .transpose import transpose
...@@ -48,3 +33,4 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_ ...@@ -48,3 +33,4 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max from .reduce import sum, mean, min, max
from .matmul import (spmm_sum, spmm_add, spmm_mean, spmm_min, spmm_max, spmm, from .matmul import (spmm_sum, spmm_add, spmm_mean, spmm_min, spmm_max, spmm,
spspmm_sum, spspmm_add, spspmm, matmul) spspmm_sum, spspmm_add, spspmm, matmul)
from .cat import cat, cat_diag
from typing import List, Optional
import torch import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def cat(tensors, dim): @torch.jit.script
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
assert len(tensors) > 0 assert len(tensors) > 0
has_row = tensors[0].storage.has_row() if dim < 0:
has_value = tensors[0].has_value() dim = tensors[0].dim() + dim
has_rowcount = tensors[0].storage.has_rowcount()
has_colptr = tensors[0].storage.has_colptr()
has_colcount = tensors[0].storage.has_colcount()
has_csr2csc = tensors[0].storage.has_csr2csc()
has_csc2csr = tensors[0].storage.has_csc2csr()
rows, rowptrs, cols, values, sparse_size, nnzs = [], [], [], [], [0, 0], 0
rowcounts, colcounts, colptrs, csr2cscs, csc2csrs = [], [], [], [], []
if isinstance(dim, int):
dim = tensors[0].dim() + dim if dim < 0 else dim
else:
dim = tuple([tensors[0].dim() + d if d < 0 else d for d in dim])
if dim == 0: if dim == 0:
rows: List[torch.Tensor] = []
rowptrs: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
rowcounts: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors: for tensor in tensors:
rowptr, col, value = tensor.csr() row = tensor.storage._row
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:] if row is not None:
rowptrs += [rowptr + nnzs] rows.append(row + sparse_sizes[0])
cols += [col]
values += [value] rowptr = tensor.storage._rowptr
if rowptr is not None:
if has_row: if len(rowptrs) > 0:
rows += [tensor.storage.row + sparse_size[0]] rowptr = rowptr[1:]
rowptrs.append(rowptr + nnz)
if has_rowcount:
rowcounts += [tensor.storage.rowcount] cols.append(tensor.storage._col)
sparse_size[0] += tensor.sparse_size(0) value = tensor.storage._value
sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1)) if value is not None:
nnzs += tensor.nnz() values.append(value)
storage = tensors[0].storage.__class__( rowcount = tensor.storage._rowcount
row=torch.cat(rows) if has_row else None, if rowcount is not None:
rowptr=torch.cat(rowptrs), col=torch.cat(cols), rowcounts.append(rowcount)
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size, sparse_sizes[0] += tensor.sparse_size(0)
rowcount=torch.cat(rowcounts) if has_rowcount else None, sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
is_sorted=True) nnz += tensor.nnz()
row: Optional[torch.Tensor] = None
if len(rows) == len(tensors):
row = torch.cat(rows, dim=0)
rowptr: Optional[torch.Tensor] = None
if len(rowptrs) == len(tensors):
rowptr = torch.cat(rowptrs, dim=0)
col = torch.cat(cols, dim=0)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
rowcount: Optional[torch.Tensor] = None
if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=None, colcount=None, csr2csc=None,
csc2csr=None, is_sorted=True)
return tensors[0].from_storage(storage)
elif dim == 1: elif dim == 1:
rows: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
colptrs: List[torch.Tensor] = []
colcounts: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors: for tensor in tensors:
row, col, value = tensor.coo() row, col, value = tensor.coo()
rows += [row]
cols += [col + sparse_size[1]] rows.append(row)
values += [value]
cols.append(tensor.storage._col + sparse_sizes[1])
if has_colcount:
colcounts += [tensor.storage.colcount] if value is not None:
values.append(value)
if has_colptr:
colptr = tensor.storage.colptr colptr = tensor.storage._colptr
colptr = colptr if len(colptrs) == 0 else colptr[1:] if colptr is not None:
colptrs += [colptr + nnzs] if len(colptrs) > 0:
colptr = colptr[1:]
sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0)) colptrs.append(colptr + nnz)
sparse_size[1] += tensor.sparse_size(1)
nnzs += tensor.nnz() colcount = tensor.storage._colcount
if colcount is not None:
storage = tensors[0].storage.__class__( colcounts.append(colcount)
row=torch.cat(rows),
col=torch.cat(cols), sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
value=torch.cat(values, dim=0) if has_value else None, sparse_sizes[1] += tensor.sparse_size(1)
sparse_size=sparse_size, nnz += tensor.nnz()
colcount=torch.cat(colcounts) if has_colcount else None,
colptr=torch.cat(colptrs) if has_colptr else None, row = torch.cat(rows, dim=0)
is_sorted=False,
) col = torch.cat(cols, dim=0)
elif dim == (0, 1) or dim == (1, 0): value: Optional[torch.Tensor] = None
for tensor in tensors: if len(values) == len(tensors):
rowptr, col, value = tensor.csr() value = torch.cat(values, dim=0)
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs] colptr: Optional[torch.Tensor] = None
cols += [col + sparse_size[1]] if len(colptrs) == len(tensors):
values += [value] colptr = torch.cat(colptrs, dim=0)
if has_row: colcount: Optional[torch.Tensor] = None
rows += [tensor.storage.row + sparse_size[0]] if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0)
if has_rowcount:
rowcounts += [tensor.storage.rowcount] storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=None,
if has_colcount: colptr=colptr, colcount=colcount, csr2csc=None,
colcounts += [tensor.storage.colcount] csc2csr=None, is_sorted=False)
return tensors[0].from_storage(storage)
if has_colptr:
colptr = tensor.storage.colptr elif dim > 1 and dim < tensors[0].dim():
colptr = colptr if len(colptrs) == 0 else colptr[1:] values: List[torch.Tensor] = []
colptrs += [colptr + nnzs]
if has_csr2csc:
csr2cscs += [tensor.storage.csr2csc + nnzs]
if has_csc2csr:
csc2csrs += [tensor.storage.csc2csr + nnzs]
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] += tensor.sparse_size(1)
nnzs += tensor.nnz()
storage = tensors[0].storage.__class__(
row=torch.cat(rows) if has_row else None,
rowptr=torch.cat(rowptrs),
col=torch.cat(cols),
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size,
rowcount=torch.cat(rowcounts) if has_rowcount else None,
colptr=torch.cat(colptrs) if has_colptr else None,
colcount=torch.cat(colcounts) if has_colcount else None,
csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
is_sorted=True,
)
elif isinstance(dim, int) and dim > 1 and dim < tensors[0].dim():
for tensor in tensors: for tensor in tensors:
values += [tensor.storage.value] value = tensor.storage.value()
if value is not None:
old_storage = tensors[0].storage values.append(value)
storage = old_storage.__class__(
row=old_storage._row, value: Optional[torch.Tensor] = None
rowptr=old_storage._rowptr, if len(values) == len(tensors):
col=old_storage._col, value = torch.cat(values, dim=dim - 1)
value=torch.cat(values, dim=dim - 1),
sparse_size=old_storage.sparse_size,
rowcount=old_storage._rowcount,
colptr=old_storage._colptr,
colcount=old_storage._colcount,
csr2csc=old_storage._csr2csc,
csc2csr=old_storage._csc2csr,
is_sorted=True,
)
return tensors[0].set_value(value, layout='coo')
else: else:
raise IndexError( raise IndexError(
(f'Dimension out of range: Expected to be in range of ' (f'Dimension out of range: Expected to be in range of '
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}, but got {dim}]')) f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}, but got {dim}]'))
return tensors[0].__class__.from_storage(storage)
@torch.jit.script
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
assert len(tensors) > 0
rows: List[torch.Tensor] = []
rowptrs: List[torch.Tensor] = []
cols: List[torch.Tensor] = []
values: List[torch.Tensor] = []
sparse_sizes: List[int] = [0, 0]
rowcounts: List[torch.Tensor] = []
colptrs: List[torch.Tensor] = []
colcounts: List[torch.Tensor] = []
csr2cscs: List[torch.Tensor] = []
csc2csrs: List[torch.Tensor] = []
nnz: int = 0
for tensor in tensors:
row = tensor.storage._row
if row is not None:
rows.append(row + sparse_sizes[0])
rowptr = tensor.storage._rowptr
if rowptr is not None:
if len(rowptrs) > 0:
rowptr = rowptr[1:]
rowptrs.append(rowptr + nnz)
cols.append(tensor.storage._col + sparse_sizes[1])
value = tensor.storage._value
if value is not None:
values.append(value)
rowcount = tensor.storage._rowcount
if rowcount is not None:
rowcounts.append(rowcount)
colptr = tensor.storage._colptr
if colptr is not None:
if len(colptrs) > 0:
colptr = colptr[1:]
colptrs.append(colptr + nnz)
colcount = tensor.storage._colcount
if colcount is not None:
colcounts.append(colcount)
csr2csc = tensor.storage._csr2csc
if csr2csc is not None:
csr2cscs.append(csr2csc + nnz)
csc2csr = tensor.storage._csc2csr
if csc2csr is not None:
csc2csrs.append(csc2csr + nnz)
sparse_sizes[0] += tensor.sparse_size(0)
sparse_sizes[1] += tensor.sparse_size(1)
nnz += tensor.nnz()
row: Optional[torch.Tensor] = None
if len(rows) == len(tensors):
row = torch.cat(rows, dim=0)
rowptr: Optional[torch.Tensor] = None
if len(rowptrs) == len(tensors):
rowptr = torch.cat(rowptrs, dim=0)
col = torch.cat(cols, dim=0)
value: Optional[torch.Tensor] = None
if len(values) == len(tensors):
value = torch.cat(values, dim=0)
rowcount: Optional[torch.Tensor] = None
if len(rowcounts) == len(tensors):
rowcount = torch.cat(rowcounts, dim=0)
colptr: Optional[torch.Tensor] = None
if len(colptrs) == len(tensors):
colptr = torch.cat(colptrs, dim=0)
colcount: Optional[torch.Tensor] = None
if len(colcounts) == len(tensors):
colcount = torch.cat(colcounts, dim=0)
csr2csc: Optional[torch.Tensor] = None
if len(csr2cscs) == len(tensors):
csr2csc = torch.cat(csr2cscs, dim=0)
csc2csr: Optional[torch.Tensor] = None
if len(csc2csrs) == len(tensors):
csc2csr = torch.cat(csc2csrs, dim=0)
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
sparse_sizes=sparse_sizes, rowcount=rowcount,
colptr=colptr, colcount=colcount, csr2csc=csr2csc,
csc2csr=csc2csr, is_sorted=True)
return tensors[0].from_storage(storage)
...@@ -199,7 +199,7 @@ class SparseStorage(object): ...@@ -199,7 +199,7 @@ class SparseStorage(object):
def set_value_(self, value: Optional[torch.Tensor], def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
if value is not None: if value is not None:
if get_layout(layout) == 'csc2csr': if get_layout(layout) == 'csc':
value = value[self.csc2csr()] value = value[self.csc2csr()]
value = value.contiguous() value = value.contiguous()
assert value.device == self._col.device assert value.device == self._col.device
...@@ -211,7 +211,7 @@ class SparseStorage(object): ...@@ -211,7 +211,7 @@ class SparseStorage(object):
def set_value(self, value: Optional[torch.Tensor], def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None): layout: Optional[str] = None):
if value is not None: if value is not None:
if get_layout(layout) == 'csc2csr': if get_layout(layout) == 'csc':
value = value[self.csc2csr()] value = value[self.csc2csr()]
value = value.contiguous() value = value.contiguous()
assert value.device == self._col.device assert value.device == self._col.device
...@@ -384,6 +384,20 @@ class SparseStorage(object): ...@@ -384,6 +384,20 @@ class SparseStorage(object):
self._csc2csr = None self._csc2csr = None
return self return self
def num_cached_keys(self) -> int:
    """Return how many optional cached layout tensors are materialized."""
    count = 0
    for present in [self.has_rowcount(), self.has_colptr(),
                    self.has_colcount(), self.has_csr2csc(),
                    self.has_csc2csr()]:
        if present:
            count += 1
    return count
def copy(self): def copy(self):
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col, return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value, value=self._value,
......
...@@ -197,7 +197,7 @@ class SparseTensor(object): ...@@ -197,7 +197,7 @@ class SparseTensor(object):
sizes = self.sparse_sizes() sizes = self.sparse_sizes()
value = self.storage.value() value = self.storage.value()
if value is not None: if value is not None:
sizes += value.size()[1:] sizes = sizes + value.size()[1:]
return sizes return sizes
def size(self, dim: int) -> int: def size(self, dim: int) -> int:
......
import torch
import numpy as np
if torch.cuda.is_available():
import torch_sparse.unique_cuda
def unique(src):
    """Return the sorted unique values of *src* and, for each, the index of
    its first occurrence in the flattened input.

    Dispatches to the CUDA kernel for GPU tensors; otherwise falls back to
    ``np.unique(..., return_index=True)`` and converts back to torch.
    """
    flat = src.contiguous().view(-1)
    if flat.is_cuda:
        return torch_sparse.unique_cuda.unique(flat)
    values, first_idx = np.unique(flat.numpy(), return_index=True)
    return torch.from_numpy(values), torch.from_numpy(first_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment