Commit ac25b416 authored by rusty1s's avatar rusty1s
Browse files

fixed tests

parent a1c268a5
......@@ -22,21 +22,21 @@ def test_sparse_add(dtype, device):
mat1 = mat[:, 0:100000]
mat2 = mat[:, 100000:200000]
print(mat1.shape)
print(mat2.shape)
# print(mat1.shape)
# print(mat2.shape)
# 0.0159 to beat
t = time.perf_counter()
mat = sparse_add(mat1, mat2)
print(time.perf_counter() - t)
print(mat.nnz())
# print(time.perf_counter() - t)
# print(mat.nnz())
mat1 = mat_scipy[:, 0:100000]
mat2 = mat_scipy[:, 100000:200000]
t = time.perf_counter()
mat = mat1 + mat2
print(time.perf_counter() - t)
print(mat.nnz)
# print(time.perf_counter() - t)
# print(mat.nnz)
# mat1 + mat2
......
......@@ -8,24 +8,27 @@ from .utils import devices, tensor
@pytest.mark.parametrize('device', devices)
def test_cat(device):
index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
mat1 = SparseTensor(index)
row, col = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
mat1 = SparseTensor(row=row, col=col)
mat1.fill_cache_()
index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2 = SparseTensor(index)
row, col = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2 = SparseTensor(row=row, col=col)
mat2.fill_cache_()
out = cat([mat1, mat2], dim=0)
assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0],
[0, 1, 0], [1, 0, 0]]
assert len(out.storage.cached_keys()) == 2
assert out.storage.has_rowcount()
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 1
assert out.storage.has_rowcount()
out = cat([mat1, mat2], dim=1)
assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]]
assert out.storage.has_row()
assert not out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 2
assert out.storage.has_colcount()
assert out.storage.has_colptr()
......@@ -34,9 +37,13 @@ def test_cat(device):
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]]
assert len(out.storage.cached_keys()) == 6
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 5
mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
out = cat([mat1, mat1], dim=-1)
assert out.storage.value.size() == (mat1.nnz(), 8)
assert len(out.storage.cached_keys()) == 6
assert out.storage.has_row()
assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 5
......@@ -9,26 +9,25 @@ from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device):
index = tensor([
[0, 0, 1, 2],
[0, 1, 2, 2],
], torch.long, device)
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(index, value)
mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_()
mat = mat.remove_diag()
assert mat.storage.index.tolist() == [[0, 1], [1, 2]]
assert mat.storage.row.tolist() == [0, 1]
assert mat.storage.col.tolist() == [1, 2]
assert mat.storage.value.tolist() == [2, 3]
assert len(mat.cached_keys()) == 2
assert mat.storage.rowcount.tolist() == [1, 1, 0]
assert mat.storage.colcount.tolist() == [0, 1, 1]
mat = SparseTensor(index, value)
mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_()
mat = mat.remove_diag(k=1)
assert mat.storage.index.tolist() == [[0, 2], [0, 2]]
assert mat.storage.row.tolist() == [0, 2]
assert mat.storage.col.tolist() == [0, 2]
assert mat.storage.value.tolist() == [1, 4]
assert len(mat.cached_keys()) == 2
assert mat.storage.rowcount.tolist() == [1, 0, 1]
......@@ -37,12 +36,9 @@ def test_remove_diag(dtype, device):
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_set_diag(dtype, device):
index = tensor([
[0, 0, 9, 9],
[0, 1, 0, 1],
], torch.long, device)
row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(index, value)
mat = SparseTensor(row=row, col=col, value=value)
k = -8
mat = mat.set_diag(k)
......@@ -9,31 +9,37 @@ from .utils import dtypes, devices
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_eye(dtype, device):
mat = SparseTensor.eye(3, dtype=dtype, device=device)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert mat.storage.row.tolist() == [0, 1, 2]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3]
assert mat.storage.col.tolist() == [0, 1, 2]
assert mat.storage.value.tolist() == [1, 1, 1]
assert len(mat.cached_keys()) == 0
mat = SparseTensor.eye(3, dtype=dtype, device=device, no_value=True)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]]
mat = SparseTensor.eye(3, dtype=dtype, device=device, has_value=False)
assert mat.storage.row.tolist() == [0, 1, 2]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3]
assert mat.storage.col.tolist() == [0, 1, 2]
assert mat.storage.value is None
assert len(mat.cached_keys()) == 0
mat = SparseTensor.eye(3, 4, dtype=dtype, device=device, fill_cache=True)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert len(mat.cached_keys()) == 6
assert mat.storage.rowcount.tolist() == [1, 1, 1]
assert mat.storage.row.tolist() == [0, 1, 2]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1, 0]
assert mat.storage.col.tolist() == [0, 1, 2]
assert len(mat.cached_keys()) == 5
assert mat.storage.rowcount.tolist() == [1, 1, 1]
assert mat.storage.colptr.tolist() == [0, 1, 2, 3, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1, 0]
assert mat.storage.csr2csc.tolist() == [0, 1, 2]
assert mat.storage.csc2csr.tolist() == [0, 1, 2]
mat = SparseTensor.eye(4, 3, dtype=dtype, device=device, fill_cache=True)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert len(mat.cached_keys()) == 6
assert mat.storage.rowcount.tolist() == [1, 1, 1, 0]
assert mat.storage.row.tolist() == [0, 1, 2]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1]
assert mat.storage.col.tolist() == [0, 1, 2]
assert len(mat.cached_keys()) == 5
assert mat.storage.rowcount.tolist() == [1, 1, 1, 0]
assert mat.storage.colptr.tolist() == [0, 1, 2, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1]
assert mat.storage.csr2csc.tolist() == [0, 1, 2]
assert mat.storage.csc2csr.tolist() == [0, 1, 2]
......@@ -10,31 +10,30 @@ from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device):
index = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
storage = SparseStorage(index)
assert storage.index.tolist() == index.tolist()
storage = SparseStorage(row=row, col=col)
assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value is None
assert storage.sparse_size() == (2, 2)
assert storage.sparse_size == (2, 2)
index = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
value = tensor([2, 1, 4, 3], dtype, device)
storage = SparseStorage(index, value)
assert storage.index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
storage = SparseStorage(row=row, col=col, value=value)
assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value.tolist() == [1, 2, 3, 4]
assert storage.sparse_size() == (2, 2)
assert storage.sparse_size == (2, 2)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device):
index = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
storage = SparseStorage(index)
row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
storage = SparseStorage(row=row, col=col)
assert storage._index.tolist() == index.tolist()
assert storage._row.tolist() == row.tolist()
assert storage._col.tolist() == col.tolist()
assert storage._value is None
assert storage._rowcount is None
......@@ -52,12 +51,15 @@ def test_caching(dtype, device):
assert storage._csr2csc.tolist() == [0, 2, 1, 3]
assert storage._csc2csr.tolist() == [0, 2, 1, 3]
assert storage.cached_keys() == [
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
]
storage = SparseStorage(index, storage.value, storage.sparse_size(),
storage.rowcount, storage.rowptr, storage.colcount,
storage.colptr, storage.csr2csc, storage.csc2csr)
storage = SparseStorage(row=row, rowptr=storage.rowptr, col=col,
value=storage.value,
sparse_size=storage.sparse_size,
rowcount=storage.rowcount, colptr=storage.colptr,
colcount=storage.colcount, csr2csc=storage.csr2csc,
csc2csr=storage.csc2csr)
assert storage._rowcount.tolist() == [2, 2]
assert storage._rowptr.tolist() == [0, 2, 4]
......@@ -66,12 +68,12 @@ def test_caching(dtype, device):
assert storage._csr2csc.tolist() == [0, 2, 1, 3]
assert storage._csc2csr.tolist() == [0, 2, 1, 3]
assert storage.cached_keys() == [
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr'
'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
]
storage.clear_cache_()
assert storage._rowcount is None
assert storage._rowptr is None
assert storage._rowptr is not None
assert storage._colcount is None
assert storage._colptr is None
assert storage._csr2csc is None
......@@ -91,9 +93,9 @@ def test_caching(dtype, device):
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device):
index = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
storage = SparseStorage(index, value)
storage = SparseStorage(row=row, col=col, value=value)
assert storage.has_value()
......@@ -107,20 +109,20 @@ def test_utility(dtype, device):
storage = storage.set_value(value, layout='coo')
assert storage.value.tolist() == [1, 2, 3, 4]
storage.sparse_resize_(3, 3)
assert storage.sparse_size() == (3, 3)
storage = storage.sparse_resize(3, 3)
assert storage.sparse_size == (3, 3)
new_storage = copy.copy(storage)
assert new_storage != storage
assert new_storage.index.data_ptr() == storage.index.data_ptr()
assert new_storage.col.data_ptr() == storage.col.data_ptr()
new_storage = storage.clone()
assert new_storage != storage
assert new_storage.index.data_ptr() != storage.index.data_ptr()
assert new_storage.col.data_ptr() != storage.col.data_ptr()
new_storage = copy.deepcopy(storage)
assert new_storage != storage
assert new_storage.index.data_ptr() != storage.index.data_ptr()
assert new_storage.col.data_ptr() != storage.col.data_ptr()
storage.apply_value_(lambda x: x + 1)
assert storage.value.tolist() == [2, 3, 4, 5]
......@@ -128,29 +130,31 @@ def test_utility(dtype, device):
assert storage.value.tolist() == [3, 4, 5, 6]
storage.apply_(lambda x: x.to(torch.long))
assert storage.index.dtype == torch.long
assert storage.col.dtype == torch.long
assert storage.value.dtype == torch.long
storage = storage.apply(lambda x: x.to(torch.long))
assert storage.index.dtype == torch.long
assert storage.col.dtype == torch.long
assert storage.value.dtype == torch.long
storage.clear_cache_()
assert storage.map(lambda x: x.numel()) == [8, 4]
assert storage.map(lambda x: x.numel()) == [4, 4, 4]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device):
index = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
value = tensor([1, 1, 1, 3, 4], dtype, device)
storage = SparseStorage(index, value)
storage = SparseStorage(row=row, col=col, value=value)
assert storage.index.tolist() == index.tolist()
assert storage.row.tolist() == row.tolist()
assert storage.col.tolist() == col.tolist()
assert storage.value.tolist() == value.tolist()
assert not storage.is_coalesced()
storage = storage.coalesce()
assert storage.is_coalesced()
assert storage.index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value.tolist() == [1, 2, 3, 4]
......@@ -3,17 +3,16 @@ import torch
def cat(tensors, dim):
assert len(tensors) > 0
has_row = tensors[0].storage.has_row()
has_value = tensors[0].has_value()
has_rowcount = tensors[0].storage.has_rowcount()
has_rowptr = tensors[0].storage.has_rowptr()
has_colcount = tensors[0].storage.has_colcount()
has_colptr = tensors[0].storage.has_colptr()
has_colcount = tensors[0].storage.has_colcount()
has_csr2csc = tensors[0].storage.has_csr2csc()
has_csc2csr = tensors[0].storage.has_csc2csr()
rows, cols, values, sparse_size = [], [], [], [0, 0]
rowcounts, rowptrs, colcounts, colptrs = [], [], [], []
csr2cscs, csc2csrs, nnzs = [], [], 0
rows, rowptrs, cols, values, sparse_size, nnzs = [], [], [], [], [0, 0], 0
rowcounts, colcounts, colptrs, csr2cscs, csc2csrs = [], [], [], [], []
if isinstance(dim, int):
dim = tensors[0].dim() + dim if dim < 0 else dim
......@@ -22,29 +21,29 @@ def cat(tensors, dim):
if dim == 0:
for tensor in tensors:
row, col, value = tensor.coo()
rows += [row + sparse_size[0]]
rowptr, col, value = tensor.csr()
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
cols += [col]
values += [value]
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))
if has_row:
rows += [tensor.storage.row + sparse_size[0]]
if has_rowcount:
rowcounts += [tensor.storage.rowcount]
if has_rowptr:
rowptr = tensor.storage.rowptr
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))
nnzs += tensor.nnz()
storage = tensors[0].storage.__class__(
torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
row=torch.cat(rows) if has_row else None,
rowptr=torch.cat(rowptrs), col=torch.cat(cols),
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size,
rowcount=torch.cat(rowcounts) if has_rowcount else None,
rowptr=torch.cat(rowptrs) if has_rowptr else None, is_sorted=True)
is_sorted=True)
elif dim == 1:
for tensor in tensors:
......@@ -52,8 +51,6 @@ def cat(tensors, dim):
rows += [row]
cols += [col + sparse_size[1]]
values += [value]
sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0))
sparse_size[1] += tensor.sparse_size(1)
if has_colcount:
colcounts += [tensor.storage.colcount]
......@@ -63,10 +60,13 @@ def cat(tensors, dim):
colptr = colptr if len(colptrs) == 0 else colptr[1:]
colptrs += [colptr + nnzs]
sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0))
sparse_size[1] += tensor.sparse_size(1)
nnzs += tensor.nnz()
storage = tensors[0].storage.__class__(
torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
row=torch.cat(rows),
col=torch.cat(cols),
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size,
colcount=torch.cat(colcounts) if has_colcount else None,
......@@ -76,21 +76,18 @@ def cat(tensors, dim):
elif dim == (0, 1) or dim == (1, 0):
for tensor in tensors:
row, col, value = tensor.coo()
rows += [row + sparse_size[0]]
rowptr, col, value = tensor.csr()
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
cols += [col + sparse_size[1]]
values += [value] if has_value else []
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] += tensor.sparse_size(1)
values += [value]
if has_row:
rows += [tensor.storage.row + sparse_size[0]]
if has_rowcount:
rowcounts += [tensor.storage.rowcount]
if has_rowptr:
rowptr = tensor.storage.rowptr
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
if has_colcount:
colcounts += [tensor.storage.colcount]
......@@ -105,16 +102,19 @@ def cat(tensors, dim):
if has_csc2csr:
csc2csrs += [tensor.storage.csc2csr + nnzs]
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] += tensor.sparse_size(1)
nnzs += tensor.nnz()
storage = tensors[0].storage.__class__(
torch.stack([torch.cat(rows), torch.cat(cols)], dim=0),
row=torch.cat(rows) if has_row else None,
rowptr=torch.cat(rowptrs),
col=torch.cat(cols),
value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size,
rowcount=torch.cat(rowcounts) if has_rowcount else None,
rowptr=torch.cat(rowptrs) if has_rowptr else None,
colcount=torch.cat(colcounts) if has_colcount else None,
colptr=torch.cat(colptrs) if has_colptr else None,
colcount=torch.cat(colcounts) if has_colcount else None,
csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
is_sorted=True,
......@@ -130,7 +130,8 @@ def cat(tensors, dim):
rowptr=old_storage._rowptr,
col=old_storage._col,
value=torch.cat(values, dim=dim - 1),
sparse_size=old_storage.sparse_size(),
sparse_size=old_storage.sparse_size,
rowcount=old_storage._rowcount,
colptr=old_storage._colptr,
colcount=old_storage._colcount,
csr2csc=old_storage._csr2csc,
......
......@@ -11,7 +11,7 @@ except ImportError:
def remove_diag(src, k=0):
row, col, value = src.coo()
inv_mask = row != col if k == 0 else row != (col - k)
row, col = row[inv_mask], col[inv_mask]
new_row, new_col = row[inv_mask], col[inv_mask]
if src.has_value():
value = value[inv_mask]
......@@ -29,7 +29,7 @@ def remove_diag(src, k=0):
colcount = src.storage.colcount.clone()
colcount[col[mask]] -= 1
storage = src.storage.__class__(row=row, col=col, value=value,
storage = src.storage.__class__(row=new_row, col=new_col, value=value,
sparse_size=src.sparse_size(),
rowcount=rowcount, colcount=colcount,
is_sorted=True)
......@@ -61,7 +61,7 @@ def set_diag(src, values=None, k=0):
new_value = None
if src.has_value():
new_value = torch.new_empty((mask.size(0), ) + value.size()[1:])
new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
new_value[mask] = value
new_value[inv_mask] = values if values is not None else 1
......
......@@ -154,7 +154,7 @@ class SparseStorage(object):
idx = self.col.new_zeros(col.numel() + 1)
idx[1:] = sparse_size[1] * self.row + self.col
if (idx[1:] < idx[:-1]).any():
perm = idx.argsort()
perm = idx[1:].argsort()
self._row = self.row[perm]
self._col = self.col[perm]
self._value = self.value[perm] if self.has_value() else None
......@@ -313,12 +313,12 @@ class SparseStorage(object):
return self.csr2csc.argsort()
def is_coalesced(self):
idx = self.col.new_zeros(self.col.numel() + 1)
idx = self.col.new_full((self.col.numel() + 1, ), -1)
idx[1:] = self.sparse_size[1] * self.row + self.col
return (idx[1:] > idx[:-1]).all().item()
def coalesce(self, reduce='add'):
idx = self.col.new_zeros(self.col.numel() + 1)
idx = self.col.new_full((self.col.numel() + 1, ), -1)
idx[1:] = self.sparse_size[1] * self.row + self.col
mask = idx[1:] > idx[:-1]
......@@ -330,8 +330,9 @@ class SparseStorage(object):
value = self.value
if self.has_value():
idx = mask.cumsum(0).sub_(1)
value = segment_csr(idx, value, reduce=reduce)
ptr = mask.nonzero().flatten()
ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value
return self.__class__(row=row, col=col, value=value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment