Commit 925f9567 authored by rusty1s's avatar rusty1s
Browse files

cat and dim fix

parent 26aee002
import time
from itertools import product
from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add
from .utils import dtypes, devices, tensor
devices = ['cpu']
dtypes = [torch.float]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_sparse_add(dtype, device):
    # Benchmark-style smoke test: times `sparse_add` on two column slices of
    # the SuiteSparse 'citationCiteseer' matrix and times scipy's CSR `+` on
    # the same slices for comparison.
    # NOTE(review): requires 'benchmark/citationCiteseer.mat' on disk; every
    # assertion/print is commented out, so this only checks that the calls
    # run without raising -- presumably a WIP benchmark, not a real test.
    name = ('DIMACS10', 'citationCiteseer')[1]
    mat_scipy = loadmat(f'benchmark/{name}.mat')['Problem'][0][0][2].tocsr()
    mat = SparseTensor.from_scipy(mat_scipy)
    mat1 = mat[:, 0:100000]
    mat2 = mat[:, 100000:200000]
    # print(mat1.shape)
    # print(mat2.shape)
    # 0.0159 to beat
    t = time.perf_counter()  # timer kept for the commented print below
    mat = sparse_add(mat1, mat2)
    # print(time.perf_counter() - t)
    # print(mat.nnz())
    # scipy reference path on the same slices:
    mat1 = mat_scipy[:, 0:100000]
    mat2 = mat_scipy[:, 100000:200000]
    t = time.perf_counter()
    mat = mat1 + mat2
    # print(time.perf_counter() - t)
    # print(mat.nnz)
    # mat1 + mat2
    # mat1 = mat1.tocoo()
    # mat2 = mat2.tocoo()
    # row1, col1 = mat1.row, mat1.col
    # row2, col2 = mat2.row, mat2.col
    # idx1 = row1 * 100000 + col1
    # idx2 = row2 * 100000 + col2
    # t = time.perf_counter()
    # np.union1d(idx1, idx2)
    # print(time.perf_counter() - t)
    # index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
    # mat1 = SparseTensor(index)
    # print()
    # print(mat1.to_dense())
    # index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
    # mat2 = SparseTensor(index)
    # print(mat2.to_dense())
    # add(mat1, mat2)
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.cat import cat
from torch_sparse.cat import cat, cat_diag
from .utils import devices, tensor
......@@ -21,29 +21,27 @@ def test_cat(device):
[0, 1, 0], [1, 0, 0]]
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 1
assert out.storage.has_rowcount()
assert out.storage.num_cached_keys() == 1
out = cat([mat1, mat2], dim=1)
assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]]
assert out.storage.has_row()
assert not out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 2
assert out.storage.has_colcount()
assert out.storage.has_colptr()
assert out.storage.num_cached_keys() == 2
out = cat([mat1, mat2], dim=(0, 1))
out = cat_diag([mat1, mat2])
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]]
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 5
assert out.storage.num_cached_keys() == 5
mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
mat1 = mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
out = cat([mat1, mat1], dim=-1)
assert out.storage.value.size() == (mat1.nnz(), 8)
assert out.storage.value().size() == (mat1.nnz(), 8)
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 5
assert out.storage.num_cached_keys() == 5
import time
from itertools import product
from scipy.io import loadmat
import numpy as np
import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.add import sparse_add
from .utils import dtypes, devices, tensor
devices = ['cpu']
dtypes = [torch.float]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_index_select(dtype, device):
    """WIP smoke test: build a tiny SparseTensor, select rows 0 and 2 along
    dimension 0, and print the dense views (no assertions yet)."""
    mat = SparseTensor(row=torch.tensor([0, 0, 1, 1, 2]),
                       col=torch.tensor([0, 1, 1, 2, 1]))
    print()
    print(mat.to_dense())
    selected = mat.index_select(0, torch.tensor([0, 2]))
    print(selected.to_dense())
from itertools import product
import pytest
import torch
from torch_sparse import rowptr_cpu
from .utils import tensor, devices
if torch.cuda.is_available():
from torch_sparse import rowptr_cuda
# Fixture table for the rowptr kernel: each case maps a *sorted* `row` index
# vector and a row count `size` to the expected CSR row-pointer vector
# (length ``size + 1``).
tests = [
    {
        # Trailing empty rows (3 and 4): pointer repeats its final value.
        'row': [0, 0, 1, 1, 1, 2, 2],
        'size': 5,
        'rowptr': [0, 2, 5, 7, 7, 7],
    },
    {
        # A gap of empty rows (2 and 3) in the middle.
        'row': [0, 0, 1, 1, 1, 4, 4],
        'size': 5,
        'rowptr': [0, 2, 5, 5, 5, 7],
    },
    {
        # Leading and trailing empty rows.
        'row': [2, 2, 4, 4],
        'size': 7,
        'rowptr': [0, 0, 0, 2, 2, 4, 4, 4],
    },
]
def rowptr(row, size):
    """Dispatch to the CUDA or CPU rowptr kernel based on `row`'s device."""
    impl = rowptr_cuda if row.is_cuda else rowptr_cpu
    return impl.rowptr(row, size)
@pytest.mark.parametrize('test,device', product(tests, devices))
def test_rowptr(test, device):
    """Check rowptr() against the precomputed fixtures on each device."""
    row = tensor(test['row'], torch.long, device)
    result = rowptr(row, test['size'])
    assert torch.all(result == tensor(test['rowptr'], torch.long, device))
......@@ -3,7 +3,7 @@ from itertools import product
import pytest
import torch
from torch_sparse.storage import SparseStorage, no_cache
from torch_sparse.storage import SparseStorage
from .utils import dtypes, devices, tensor
......@@ -79,17 +79,6 @@ def test_caching(dtype, device):
assert storage._csr2csc is None
assert storage.cached_keys() == []
with no_cache():
storage.fill_cache_()
assert storage.cached_keys() == []
@no_cache()
def do_something(storage):
return storage.fill_cache_()
storage = do_something(storage)
assert storage.cached_keys() == []
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
......
import torch

# Eagerly load the compiled CPU extensions; these must always be present.
torch.ops.load_library('torch_sparse/convert_cpu.so')
torch.ops.load_library('torch_sparse/diag_cpu.so')
torch.ops.load_library('torch_sparse/spmm_cpu.so')
# The CUDA extensions are optional: swallow the load failure on CPU-only
# installs, but re-raise when a GPU is visible (the build is then broken).
try:
    torch.ops.load_library('torch_sparse/convert_cuda.so')
    torch.ops.load_library('torch_sparse/diag_cuda.so')
    torch.ops.load_library('torch_sparse/spmm_cuda.so')
    torch.ops.load_library('torch_sparse/spspmm_cuda.so')
except OSError as e:
    if torch.cuda.is_available():
        raise e

from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
from .coalesce import coalesce
from .transpose import transpose
......@@ -48,3 +33,4 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max
from .matmul import (spmm_sum, spmm_add, spmm_mean, spmm_min, spmm_max, spmm,
spspmm_sum, spspmm_add, spspmm, matmul)
from .cat import cat, cat_diag
from typing import List, Optional
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
@torch.jit.script
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
    """Concatenate a list of :class:`SparseTensor` along dimension `dim`.

    ``dim == 0`` stacks rows, ``dim == 1`` stacks columns, and ``dim > 1``
    concatenates along a (dense) value dimension.  Cached fields that remain
    valid after the concatenation (`rowcount` for dim 0; `colptr`/`colcount`
    for dim 1) are stitched together instead of being recomputed, and are
    only kept when *every* input provides them.

    NOTE(review): this region of the dump was an interleaved old/new diff;
    the body below reconstructs the new-side implementation.
    """
    assert len(tensors) > 0
    if dim < 0:
        dim = tensors[0].dim() + dim

    if dim == 0:
        rows: List[torch.Tensor] = []
        rowptrs: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        rowcounts: List[torch.Tensor] = []
        nnz: int = 0
        for tensor in tensors:
            row = tensor.storage._row
            if row is not None:
                # Shift row indices by the number of rows seen so far.
                rows.append(row + sparse_sizes[0])
            rowptr = tensor.storage._rowptr
            if rowptr is not None:
                if len(rowptrs) > 0:
                    # Drop the leading zero of every rowptr but the first.
                    rowptr = rowptr[1:]
                rowptrs.append(rowptr + nnz)
            cols.append(tensor.storage._col)
            value = tensor.storage._value
            if value is not None:
                values.append(value)
            rowcount = tensor.storage._rowcount
            if rowcount is not None:
                rowcounts.append(rowcount)
            sparse_sizes[0] += tensor.sparse_size(0)
            sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
            nnz += tensor.nnz()

        # Keep a cached field only if every input contributed one.
        row: Optional[torch.Tensor] = None
        if len(rows) == len(tensors):
            row = torch.cat(rows, dim=0)
        rowptr: Optional[torch.Tensor] = None
        if len(rowptrs) == len(tensors):
            rowptr = torch.cat(rowptrs, dim=0)
        col = torch.cat(cols, dim=0)
        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)
        rowcount: Optional[torch.Tensor] = None
        if len(rowcounts) == len(tensors):
            rowcount = torch.cat(rowcounts, dim=0)

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return tensors[0].from_storage(storage)

    elif dim == 1:
        rows: List[torch.Tensor] = []
        cols: List[torch.Tensor] = []
        values: List[torch.Tensor] = []
        sparse_sizes: List[int] = [0, 0]
        colptrs: List[torch.Tensor] = []
        colcounts: List[torch.Tensor] = []
        nnz: int = 0
        for tensor in tensors:
            row, col, value = tensor.coo()
            rows.append(row)
            # Shift column indices by the number of columns seen so far.
            cols.append(tensor.storage._col + sparse_sizes[1])
            if value is not None:
                values.append(value)
            colptr = tensor.storage._colptr
            if colptr is not None:
                if len(colptrs) > 0:
                    colptr = colptr[1:]
                colptrs.append(colptr + nnz)
            colcount = tensor.storage._colcount
            if colcount is not None:
                colcounts.append(colcount)
            sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
            sparse_sizes[1] += tensor.sparse_size(1)
            nnz += tensor.nnz()

        row = torch.cat(rows, dim=0)
        col = torch.cat(cols, dim=0)
        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=0)
        colptr: Optional[torch.Tensor] = None
        if len(colptrs) == len(tensors):
            colptr = torch.cat(colptrs, dim=0)
        colcount: Optional[torch.Tensor] = None
        if len(colcounts) == len(tensors):
            colcount = torch.cat(colcounts, dim=0)

        # Rows are no longer sorted after interleaving inputs column-wise.
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=False)
        return tensors[0].from_storage(storage)

    elif dim > 1 and dim < tensors[0].dim():
        # Concatenate along a dense value dimension; sparsity is untouched.
        values: List[torch.Tensor] = []
        for tensor in tensors:
            value = tensor.storage.value()
            if value is not None:
                values.append(value)
        value: Optional[torch.Tensor] = None
        if len(values) == len(tensors):
            value = torch.cat(values, dim=dim - 1)
        return tensors[0].set_value(value, layout='coo')

    else:
        # Fixed bracket placement in the message (was `... - 1}, but got {dim}]`).
        raise IndexError(
            (f'Dimension out of range: Expected to be in range of '
             f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}'))
@torch.jit.script
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
    """Concatenate sparse tensors block-diagonally: both row and column
    indices of each input are shifted by the running sparse sizes, so the
    result places the inputs along the diagonal of a larger matrix.

    Every cached field (`rowcount`, `colptr`, `colcount`, `csr2csc`,
    `csc2csr`) stays valid under this construction, so each is stitched
    together -- but only kept when *every* input provides it.
    """
    assert len(tensors) > 0
    rows: List[torch.Tensor] = []
    rowptrs: List[torch.Tensor] = []
    cols: List[torch.Tensor] = []
    values: List[torch.Tensor] = []
    sparse_sizes: List[int] = [0, 0]
    rowcounts: List[torch.Tensor] = []
    colptrs: List[torch.Tensor] = []
    colcounts: List[torch.Tensor] = []
    csr2cscs: List[torch.Tensor] = []
    csc2csrs: List[torch.Tensor] = []
    nnz: int = 0
    for tensor in tensors:
        row = tensor.storage._row
        if row is not None:
            # Shift row indices by the rows accumulated so far.
            rows.append(row + sparse_sizes[0])
        rowptr = tensor.storage._rowptr
        if rowptr is not None:
            if len(rowptrs) > 0:
                # Drop the leading zero of every rowptr but the first.
                rowptr = rowptr[1:]
            rowptrs.append(rowptr + nnz)
        # Shift column indices by the columns accumulated so far.
        cols.append(tensor.storage._col + sparse_sizes[1])
        value = tensor.storage._value
        if value is not None:
            values.append(value)
        rowcount = tensor.storage._rowcount
        if rowcount is not None:
            rowcounts.append(rowcount)
        colptr = tensor.storage._colptr
        if colptr is not None:
            if len(colptrs) > 0:
                colptr = colptr[1:]
            colptrs.append(colptr + nnz)
        colcount = tensor.storage._colcount
        if colcount is not None:
            colcounts.append(colcount)
        csr2csc = tensor.storage._csr2csc
        if csr2csc is not None:
            # Permutations index into the nnz dimension, so shift by nnz.
            csr2cscs.append(csr2csc + nnz)
        csc2csr = tensor.storage._csc2csr
        if csc2csr is not None:
            csc2csrs.append(csc2csr + nnz)
        sparse_sizes[0] += tensor.sparse_size(0)
        sparse_sizes[1] += tensor.sparse_size(1)
        nnz += tensor.nnz()
    # Keep each cached field only if every input contributed one.
    row: Optional[torch.Tensor] = None
    if len(rows) == len(tensors):
        row = torch.cat(rows, dim=0)
    rowptr: Optional[torch.Tensor] = None
    if len(rowptrs) == len(tensors):
        rowptr = torch.cat(rowptrs, dim=0)
    col = torch.cat(cols, dim=0)
    value: Optional[torch.Tensor] = None
    if len(values) == len(tensors):
        value = torch.cat(values, dim=0)
    rowcount: Optional[torch.Tensor] = None
    if len(rowcounts) == len(tensors):
        rowcount = torch.cat(rowcounts, dim=0)
    colptr: Optional[torch.Tensor] = None
    if len(colptrs) == len(tensors):
        colptr = torch.cat(colptrs, dim=0)
    colcount: Optional[torch.Tensor] = None
    if len(colcounts) == len(tensors):
        colcount = torch.cat(colcounts, dim=0)
    csr2csc: Optional[torch.Tensor] = None
    if len(csr2cscs) == len(tensors):
        csr2csc = torch.cat(csr2cscs, dim=0)
    csc2csr: Optional[torch.Tensor] = None
    if len(csc2csrs) == len(tensors):
        csc2csr = torch.cat(csc2csrs, dim=0)
    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return tensors[0].from_storage(storage)
......@@ -199,7 +199,7 @@ class SparseStorage(object):
def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc2csr':
if get_layout(layout) == 'csc':
value = value[self.csc2csr()]
value = value.contiguous()
assert value.device == self._col.device
......@@ -211,7 +211,7 @@ class SparseStorage(object):
def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
if value is not None:
if get_layout(layout) == 'csc2csr':
if get_layout(layout) == 'csc':
value = value[self.csc2csr()]
value = value.contiguous()
assert value.device == self._col.device
......@@ -384,6 +384,20 @@ class SparseStorage(object):
self._csc2csr = None
return self
def num_cached_keys(self) -> int:
    """Count how many optional cached fields are currently present.

    Covers `rowcount`, `colptr`, `colcount`, `csr2csc` and `csc2csr`;
    `row` and `rowptr` are not treated as cache entries.
    """
    total = 0
    for present in [self.has_rowcount(), self.has_colptr(),
                    self.has_colcount(), self.has_csr2csc(),
                    self.has_csc2csr()]:
        if present:
            total += 1
    return total
def copy(self):
return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
value=self._value,
......
......@@ -197,7 +197,7 @@ class SparseTensor(object):
sizes = self.sparse_sizes()
value = self.storage.value()
if value is not None:
sizes += value.size()[1:]
sizes = sizes + value.size()[1:]
return sizes
def size(self, dim: int) -> int:
......
import torch
import numpy as np
if torch.cuda.is_available():
import torch_sparse.unique_cuda
def unique(src):
    """Return the sorted unique values of `src` together with the index of
    each value's first occurrence in the flattened input."""
    flat = src.contiguous().view(-1)
    if flat.is_cuda:
        return torch_sparse.unique_cuda.unique(flat)
    vals, first_idx = np.unique(flat.numpy(), return_index=True)
    return torch.from_numpy(vals), torch.from_numpy(first_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment