Commit ac25b416 authored by rusty1s's avatar rusty1s
Browse files

fixed tests

parent a1c268a5
...@@ -22,21 +22,21 @@ def test_sparse_add(dtype, device): ...@@ -22,21 +22,21 @@ def test_sparse_add(dtype, device):
mat1 = mat[:, 0:100000] mat1 = mat[:, 0:100000]
mat2 = mat[:, 100000:200000] mat2 = mat[:, 100000:200000]
print(mat1.shape) # print(mat1.shape)
print(mat2.shape) # print(mat2.shape)
# 0.0159 to beat # 0.0159 to beat
t = time.perf_counter() t = time.perf_counter()
mat = sparse_add(mat1, mat2) mat = sparse_add(mat1, mat2)
print(time.perf_counter() - t) # print(time.perf_counter() - t)
print(mat.nnz()) # print(mat.nnz())
mat1 = mat_scipy[:, 0:100000] mat1 = mat_scipy[:, 0:100000]
mat2 = mat_scipy[:, 100000:200000] mat2 = mat_scipy[:, 100000:200000]
t = time.perf_counter() t = time.perf_counter()
mat = mat1 + mat2 mat = mat1 + mat2
print(time.perf_counter() - t) # print(time.perf_counter() - t)
print(mat.nnz) # print(mat.nnz)
# mat1 + mat2 # mat1 + mat2
......
...@@ -8,24 +8,27 @@ from .utils import devices, tensor ...@@ -8,24 +8,27 @@ from .utils import devices, tensor
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_cat(device): def test_cat(device):
index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device) row, col = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
mat1 = SparseTensor(index) mat1 = SparseTensor(row=row, col=col)
mat1.fill_cache_() mat1.fill_cache_()
index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device) row, col = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2 = SparseTensor(index) mat2 = SparseTensor(row=row, col=col)
mat2.fill_cache_() mat2.fill_cache_()
out = cat([mat1, mat2], dim=0) out = cat([mat1, mat2], dim=0)
assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0], assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0],
[0, 1, 0], [1, 0, 0]] [0, 1, 0], [1, 0, 0]]
assert len(out.storage.cached_keys()) == 2 assert out.storage.has_row()
assert out.storage.has_rowcount()
assert out.storage.has_rowptr() assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 1
assert out.storage.has_rowcount()
out = cat([mat1, mat2], dim=1) out = cat([mat1, mat2], dim=1)
assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1], assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]] [0, 0, 0, 1, 0]]
assert out.storage.has_row()
assert not out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 2 assert len(out.storage.cached_keys()) == 2
assert out.storage.has_colcount() assert out.storage.has_colcount()
assert out.storage.has_colptr() assert out.storage.has_colptr()
...@@ -34,9 +37,13 @@ def test_cat(device): ...@@ -34,9 +37,13 @@ def test_cat(device):
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0], assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
[0, 0, 0, 1, 0]] [0, 0, 0, 1, 0]]
assert len(out.storage.cached_keys()) == 6 assert out.storage.has_row()
assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 5
mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device)) mat1.set_value_(torch.randn((mat1.nnz(), 4), device=device))
out = cat([mat1, mat1], dim=-1) out = cat([mat1, mat1], dim=-1)
assert out.storage.value.size() == (mat1.nnz(), 8) assert out.storage.value.size() == (mat1.nnz(), 8)
assert len(out.storage.cached_keys()) == 6 assert out.storage.has_row()
assert out.storage.has_rowptr()
assert len(out.storage.cached_keys()) == 5
...@@ -9,26 +9,25 @@ from .utils import dtypes, devices, tensor ...@@ -9,26 +9,25 @@ from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device): def test_remove_diag(dtype, device):
index = tensor([ row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
[0, 0, 1, 2],
[0, 1, 2, 2],
], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device) value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(index, value) mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_() mat.fill_cache_()
mat = mat.remove_diag() mat = mat.remove_diag()
assert mat.storage.index.tolist() == [[0, 1], [1, 2]] assert mat.storage.row.tolist() == [0, 1]
assert mat.storage.col.tolist() == [1, 2]
assert mat.storage.value.tolist() == [2, 3] assert mat.storage.value.tolist() == [2, 3]
assert len(mat.cached_keys()) == 2 assert len(mat.cached_keys()) == 2
assert mat.storage.rowcount.tolist() == [1, 1, 0] assert mat.storage.rowcount.tolist() == [1, 1, 0]
assert mat.storage.colcount.tolist() == [0, 1, 1] assert mat.storage.colcount.tolist() == [0, 1, 1]
mat = SparseTensor(index, value) mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_() mat.fill_cache_()
mat = mat.remove_diag(k=1) mat = mat.remove_diag(k=1)
assert mat.storage.index.tolist() == [[0, 2], [0, 2]] assert mat.storage.row.tolist() == [0, 2]
assert mat.storage.col.tolist() == [0, 2]
assert mat.storage.value.tolist() == [1, 4] assert mat.storage.value.tolist() == [1, 4]
assert len(mat.cached_keys()) == 2 assert len(mat.cached_keys()) == 2
assert mat.storage.rowcount.tolist() == [1, 0, 1] assert mat.storage.rowcount.tolist() == [1, 0, 1]
...@@ -37,12 +36,9 @@ def test_remove_diag(dtype, device): ...@@ -37,12 +36,9 @@ def test_remove_diag(dtype, device):
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_set_diag(dtype, device): def test_set_diag(dtype, device):
index = tensor([ row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
[0, 0, 9, 9],
[0, 1, 0, 1],
], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device) value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(index, value) mat = SparseTensor(row=row, col=col, value=value)
k = -8 k = -8
mat = mat.set_diag(k) mat = mat.set_diag(k)
...@@ -9,31 +9,37 @@ from .utils import dtypes, devices ...@@ -9,31 +9,37 @@ from .utils import dtypes, devices
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_eye(dtype, device): def test_eye(dtype, device):
mat = SparseTensor.eye(3, dtype=dtype, device=device) mat = SparseTensor.eye(3, dtype=dtype, device=device)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]] assert mat.storage.row.tolist() == [0, 1, 2]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3]
assert mat.storage.col.tolist() == [0, 1, 2]
assert mat.storage.value.tolist() == [1, 1, 1] assert mat.storage.value.tolist() == [1, 1, 1]
assert len(mat.cached_keys()) == 0 assert len(mat.cached_keys()) == 0
mat = SparseTensor.eye(3, dtype=dtype, device=device, no_value=True) mat = SparseTensor.eye(3, dtype=dtype, device=device, has_value=False)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]] assert mat.storage.row.tolist() == [0, 1, 2]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3]
assert mat.storage.col.tolist() == [0, 1, 2]
assert mat.storage.value is None assert mat.storage.value is None
assert len(mat.cached_keys()) == 0 assert len(mat.cached_keys()) == 0
mat = SparseTensor.eye(3, 4, dtype=dtype, device=device, fill_cache=True) mat = SparseTensor.eye(3, 4, dtype=dtype, device=device, fill_cache=True)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]] assert mat.storage.row.tolist() == [0, 1, 2]
assert len(mat.cached_keys()) == 6
assert mat.storage.rowcount.tolist() == [1, 1, 1]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3] assert mat.storage.rowptr.tolist() == [0, 1, 2, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1, 0] assert mat.storage.col.tolist() == [0, 1, 2]
assert len(mat.cached_keys()) == 5
assert mat.storage.rowcount.tolist() == [1, 1, 1]
assert mat.storage.colptr.tolist() == [0, 1, 2, 3, 3] assert mat.storage.colptr.tolist() == [0, 1, 2, 3, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1, 0]
assert mat.storage.csr2csc.tolist() == [0, 1, 2] assert mat.storage.csr2csc.tolist() == [0, 1, 2]
assert mat.storage.csc2csr.tolist() == [0, 1, 2] assert mat.storage.csc2csr.tolist() == [0, 1, 2]
mat = SparseTensor.eye(4, 3, dtype=dtype, device=device, fill_cache=True) mat = SparseTensor.eye(4, 3, dtype=dtype, device=device, fill_cache=True)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]] assert mat.storage.row.tolist() == [0, 1, 2]
assert len(mat.cached_keys()) == 6
assert mat.storage.rowcount.tolist() == [1, 1, 1, 0]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3, 3] assert mat.storage.rowptr.tolist() == [0, 1, 2, 3, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1] assert mat.storage.col.tolist() == [0, 1, 2]
assert len(mat.cached_keys()) == 5
assert mat.storage.rowcount.tolist() == [1, 1, 1, 0]
assert mat.storage.colptr.tolist() == [0, 1, 2, 3] assert mat.storage.colptr.tolist() == [0, 1, 2, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1]
assert mat.storage.csr2csc.tolist() == [0, 1, 2] assert mat.storage.csr2csc.tolist() == [0, 1, 2]
assert mat.storage.csc2csr.tolist() == [0, 1, 2] assert mat.storage.csc2csr.tolist() == [0, 1, 2]
...@@ -10,31 +10,30 @@ from .utils import dtypes, devices, tensor ...@@ -10,31 +10,30 @@ from .utils import dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_storage(dtype, device): def test_storage(dtype, device):
index = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device) row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
storage = SparseStorage(index) storage = SparseStorage(row=row, col=col)
assert storage.index.tolist() == index.tolist()
assert storage.row.tolist() == [0, 0, 1, 1] assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1] assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value is None assert storage.value is None
assert storage.sparse_size() == (2, 2) assert storage.sparse_size == (2, 2)
index = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device) row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
value = tensor([2, 1, 4, 3], dtype, device) value = tensor([2, 1, 4, 3], dtype, device)
storage = SparseStorage(index, value) storage = SparseStorage(row=row, col=col, value=value)
assert storage.index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]
assert storage.row.tolist() == [0, 0, 1, 1] assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1] assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value.tolist() == [1, 2, 3, 4] assert storage.value.tolist() == [1, 2, 3, 4]
assert storage.sparse_size() == (2, 2) assert storage.sparse_size == (2, 2)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_caching(dtype, device): def test_caching(dtype, device):
index = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device) row, col = tensor([[0, 0, 1, 1], [0, 1, 0, 1]], torch.long, device)
storage = SparseStorage(index) storage = SparseStorage(row=row, col=col)
assert storage._index.tolist() == index.tolist() assert storage._row.tolist() == row.tolist()
assert storage._col.tolist() == col.tolist()
assert storage._value is None assert storage._value is None
assert storage._rowcount is None assert storage._rowcount is None
...@@ -52,12 +51,15 @@ def test_caching(dtype, device): ...@@ -52,12 +51,15 @@ def test_caching(dtype, device):
assert storage._csr2csc.tolist() == [0, 2, 1, 3] assert storage._csr2csc.tolist() == [0, 2, 1, 3]
assert storage._csc2csr.tolist() == [0, 2, 1, 3] assert storage._csc2csr.tolist() == [0, 2, 1, 3]
assert storage.cached_keys() == [ assert storage.cached_keys() == [
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr' 'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
] ]
storage = SparseStorage(index, storage.value, storage.sparse_size(), storage = SparseStorage(row=row, rowptr=storage.rowptr, col=col,
storage.rowcount, storage.rowptr, storage.colcount, value=storage.value,
storage.colptr, storage.csr2csc, storage.csc2csr) sparse_size=storage.sparse_size,
rowcount=storage.rowcount, colptr=storage.colptr,
colcount=storage.colcount, csr2csc=storage.csr2csc,
csc2csr=storage.csc2csr)
assert storage._rowcount.tolist() == [2, 2] assert storage._rowcount.tolist() == [2, 2]
assert storage._rowptr.tolist() == [0, 2, 4] assert storage._rowptr.tolist() == [0, 2, 4]
...@@ -66,12 +68,12 @@ def test_caching(dtype, device): ...@@ -66,12 +68,12 @@ def test_caching(dtype, device):
assert storage._csr2csc.tolist() == [0, 2, 1, 3] assert storage._csr2csc.tolist() == [0, 2, 1, 3]
assert storage._csc2csr.tolist() == [0, 2, 1, 3] assert storage._csc2csr.tolist() == [0, 2, 1, 3]
assert storage.cached_keys() == [ assert storage.cached_keys() == [
'rowcount', 'rowptr', 'colcount', 'colptr', 'csr2csc', 'csc2csr' 'rowcount', 'colptr', 'colcount', 'csr2csc', 'csc2csr'
] ]
storage.clear_cache_() storage.clear_cache_()
assert storage._rowcount is None assert storage._rowcount is None
assert storage._rowptr is None assert storage._rowptr is not None
assert storage._colcount is None assert storage._colcount is None
assert storage._colptr is None assert storage._colptr is None
assert storage._csr2csc is None assert storage._csr2csc is None
...@@ -91,9 +93,9 @@ def test_caching(dtype, device): ...@@ -91,9 +93,9 @@ def test_caching(dtype, device):
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_utility(dtype, device): def test_utility(dtype, device):
index = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device) row, col = tensor([[0, 0, 1, 1], [1, 0, 1, 0]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device) value = tensor([1, 2, 3, 4], dtype, device)
storage = SparseStorage(index, value) storage = SparseStorage(row=row, col=col, value=value)
assert storage.has_value() assert storage.has_value()
...@@ -107,20 +109,20 @@ def test_utility(dtype, device): ...@@ -107,20 +109,20 @@ def test_utility(dtype, device):
storage = storage.set_value(value, layout='coo') storage = storage.set_value(value, layout='coo')
assert storage.value.tolist() == [1, 2, 3, 4] assert storage.value.tolist() == [1, 2, 3, 4]
storage.sparse_resize_(3, 3) storage = storage.sparse_resize(3, 3)
assert storage.sparse_size() == (3, 3) assert storage.sparse_size == (3, 3)
new_storage = copy.copy(storage) new_storage = copy.copy(storage)
assert new_storage != storage assert new_storage != storage
assert new_storage.index.data_ptr() == storage.index.data_ptr() assert new_storage.col.data_ptr() == storage.col.data_ptr()
new_storage = storage.clone() new_storage = storage.clone()
assert new_storage != storage assert new_storage != storage
assert new_storage.index.data_ptr() != storage.index.data_ptr() assert new_storage.col.data_ptr() != storage.col.data_ptr()
new_storage = copy.deepcopy(storage) new_storage = copy.deepcopy(storage)
assert new_storage != storage assert new_storage != storage
assert new_storage.index.data_ptr() != storage.index.data_ptr() assert new_storage.col.data_ptr() != storage.col.data_ptr()
storage.apply_value_(lambda x: x + 1) storage.apply_value_(lambda x: x + 1)
assert storage.value.tolist() == [2, 3, 4, 5] assert storage.value.tolist() == [2, 3, 4, 5]
...@@ -128,29 +130,31 @@ def test_utility(dtype, device): ...@@ -128,29 +130,31 @@ def test_utility(dtype, device):
assert storage.value.tolist() == [3, 4, 5, 6] assert storage.value.tolist() == [3, 4, 5, 6]
storage.apply_(lambda x: x.to(torch.long)) storage.apply_(lambda x: x.to(torch.long))
assert storage.index.dtype == torch.long assert storage.col.dtype == torch.long
assert storage.value.dtype == torch.long assert storage.value.dtype == torch.long
storage = storage.apply(lambda x: x.to(torch.long)) storage = storage.apply(lambda x: x.to(torch.long))
assert storage.index.dtype == torch.long assert storage.col.dtype == torch.long
assert storage.value.dtype == torch.long assert storage.value.dtype == torch.long
storage.clear_cache_() storage.clear_cache_()
assert storage.map(lambda x: x.numel()) == [8, 4] assert storage.map(lambda x: x.numel()) == [4, 4, 4]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_coalesce(dtype, device): def test_coalesce(dtype, device):
index = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device) row, col = tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], torch.long, device)
value = tensor([1, 1, 1, 3, 4], dtype, device) value = tensor([1, 1, 1, 3, 4], dtype, device)
storage = SparseStorage(index, value) storage = SparseStorage(row=row, col=col, value=value)
assert storage.index.tolist() == index.tolist() assert storage.row.tolist() == row.tolist()
assert storage.col.tolist() == col.tolist()
assert storage.value.tolist() == value.tolist() assert storage.value.tolist() == value.tolist()
assert not storage.is_coalesced() assert not storage.is_coalesced()
storage = storage.coalesce() storage = storage.coalesce()
assert storage.is_coalesced() assert storage.is_coalesced()
assert storage.index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]] assert storage.row.tolist() == [0, 0, 1, 1]
assert storage.col.tolist() == [0, 1, 0, 1]
assert storage.value.tolist() == [1, 2, 3, 4] assert storage.value.tolist() == [1, 2, 3, 4]
...@@ -3,17 +3,16 @@ import torch ...@@ -3,17 +3,16 @@ import torch
def cat(tensors, dim): def cat(tensors, dim):
assert len(tensors) > 0 assert len(tensors) > 0
has_row = tensors[0].storage.has_row()
has_value = tensors[0].has_value() has_value = tensors[0].has_value()
has_rowcount = tensors[0].storage.has_rowcount() has_rowcount = tensors[0].storage.has_rowcount()
has_rowptr = tensors[0].storage.has_rowptr()
has_colcount = tensors[0].storage.has_colcount()
has_colptr = tensors[0].storage.has_colptr() has_colptr = tensors[0].storage.has_colptr()
has_colcount = tensors[0].storage.has_colcount()
has_csr2csc = tensors[0].storage.has_csr2csc() has_csr2csc = tensors[0].storage.has_csr2csc()
has_csc2csr = tensors[0].storage.has_csc2csr() has_csc2csr = tensors[0].storage.has_csc2csr()
rows, cols, values, sparse_size = [], [], [], [0, 0] rows, rowptrs, cols, values, sparse_size, nnzs = [], [], [], [], [0, 0], 0
rowcounts, rowptrs, colcounts, colptrs = [], [], [], [] rowcounts, colcounts, colptrs, csr2cscs, csc2csrs = [], [], [], [], []
csr2cscs, csc2csrs, nnzs = [], [], 0
if isinstance(dim, int): if isinstance(dim, int):
dim = tensors[0].dim() + dim if dim < 0 else dim dim = tensors[0].dim() + dim if dim < 0 else dim
...@@ -22,29 +21,29 @@ def cat(tensors, dim): ...@@ -22,29 +21,29 @@ def cat(tensors, dim):
if dim == 0: if dim == 0:
for tensor in tensors: for tensor in tensors:
row, col, value = tensor.coo() rowptr, col, value = tensor.csr()
rows += [row + sparse_size[0]] rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
cols += [col] cols += [col]
values += [value] values += [value]
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1)) if has_row:
rows += [tensor.storage.row + sparse_size[0]]
if has_rowcount: if has_rowcount:
rowcounts += [tensor.storage.rowcount] rowcounts += [tensor.storage.rowcount]
if has_rowptr: sparse_size[0] += tensor.sparse_size(0)
rowptr = tensor.storage.rowptr sparse_size[1] = max(sparse_size[1], tensor.sparse_size(1))
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
nnzs += tensor.nnz() nnzs += tensor.nnz()
storage = tensors[0].storage.__class__( storage = tensors[0].storage.__class__(
torch.stack([torch.cat(rows), torch.cat(cols)], dim=0), row=torch.cat(rows) if has_row else None,
rowptr=torch.cat(rowptrs), col=torch.cat(cols),
value=torch.cat(values, dim=0) if has_value else None, value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size, sparse_size=sparse_size,
rowcount=torch.cat(rowcounts) if has_rowcount else None, rowcount=torch.cat(rowcounts) if has_rowcount else None,
rowptr=torch.cat(rowptrs) if has_rowptr else None, is_sorted=True) is_sorted=True)
elif dim == 1: elif dim == 1:
for tensor in tensors: for tensor in tensors:
...@@ -52,8 +51,6 @@ def cat(tensors, dim): ...@@ -52,8 +51,6 @@ def cat(tensors, dim):
rows += [row] rows += [row]
cols += [col + sparse_size[1]] cols += [col + sparse_size[1]]
values += [value] values += [value]
sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0))
sparse_size[1] += tensor.sparse_size(1)
if has_colcount: if has_colcount:
colcounts += [tensor.storage.colcount] colcounts += [tensor.storage.colcount]
...@@ -63,10 +60,13 @@ def cat(tensors, dim): ...@@ -63,10 +60,13 @@ def cat(tensors, dim):
colptr = colptr if len(colptrs) == 0 else colptr[1:] colptr = colptr if len(colptrs) == 0 else colptr[1:]
colptrs += [colptr + nnzs] colptrs += [colptr + nnzs]
sparse_size[0] = max(sparse_size[0], tensor.sparse_size(0))
sparse_size[1] += tensor.sparse_size(1)
nnzs += tensor.nnz() nnzs += tensor.nnz()
storage = tensors[0].storage.__class__( storage = tensors[0].storage.__class__(
torch.stack([torch.cat(rows), torch.cat(cols)], dim=0), row=torch.cat(rows),
col=torch.cat(cols),
value=torch.cat(values, dim=0) if has_value else None, value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size, sparse_size=sparse_size,
colcount=torch.cat(colcounts) if has_colcount else None, colcount=torch.cat(colcounts) if has_colcount else None,
...@@ -76,21 +76,18 @@ def cat(tensors, dim): ...@@ -76,21 +76,18 @@ def cat(tensors, dim):
elif dim == (0, 1) or dim == (1, 0): elif dim == (0, 1) or dim == (1, 0):
for tensor in tensors: for tensor in tensors:
row, col, value = tensor.coo() rowptr, col, value = tensor.csr()
rows += [row + sparse_size[0]] rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
cols += [col + sparse_size[1]] cols += [col + sparse_size[1]]
values += [value] if has_value else [] values += [value]
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] += tensor.sparse_size(1) if has_row:
rows += [tensor.storage.row + sparse_size[0]]
if has_rowcount: if has_rowcount:
rowcounts += [tensor.storage.rowcount] rowcounts += [tensor.storage.rowcount]
if has_rowptr:
rowptr = tensor.storage.rowptr
rowptr = rowptr if len(rowptrs) == 0 else rowptr[1:]
rowptrs += [rowptr + nnzs]
if has_colcount: if has_colcount:
colcounts += [tensor.storage.colcount] colcounts += [tensor.storage.colcount]
...@@ -105,16 +102,19 @@ def cat(tensors, dim): ...@@ -105,16 +102,19 @@ def cat(tensors, dim):
if has_csc2csr: if has_csc2csr:
csc2csrs += [tensor.storage.csc2csr + nnzs] csc2csrs += [tensor.storage.csc2csr + nnzs]
sparse_size[0] += tensor.sparse_size(0)
sparse_size[1] += tensor.sparse_size(1)
nnzs += tensor.nnz() nnzs += tensor.nnz()
storage = tensors[0].storage.__class__( storage = tensors[0].storage.__class__(
torch.stack([torch.cat(rows), torch.cat(cols)], dim=0), row=torch.cat(rows) if has_row else None,
rowptr=torch.cat(rowptrs),
col=torch.cat(cols),
value=torch.cat(values, dim=0) if has_value else None, value=torch.cat(values, dim=0) if has_value else None,
sparse_size=sparse_size, sparse_size=sparse_size,
rowcount=torch.cat(rowcounts) if has_rowcount else None, rowcount=torch.cat(rowcounts) if has_rowcount else None,
rowptr=torch.cat(rowptrs) if has_rowptr else None,
colcount=torch.cat(colcounts) if has_colcount else None,
colptr=torch.cat(colptrs) if has_colptr else None, colptr=torch.cat(colptrs) if has_colptr else None,
colcount=torch.cat(colcounts) if has_colcount else None,
csr2csc=torch.cat(csr2cscs) if has_csr2csc else None, csr2csc=torch.cat(csr2cscs) if has_csr2csc else None,
csc2csr=torch.cat(csc2csrs) if has_csc2csr else None, csc2csr=torch.cat(csc2csrs) if has_csc2csr else None,
is_sorted=True, is_sorted=True,
...@@ -130,7 +130,8 @@ def cat(tensors, dim): ...@@ -130,7 +130,8 @@ def cat(tensors, dim):
rowptr=old_storage._rowptr, rowptr=old_storage._rowptr,
col=old_storage._col, col=old_storage._col,
value=torch.cat(values, dim=dim - 1), value=torch.cat(values, dim=dim - 1),
sparse_size=old_storage.sparse_size(), sparse_size=old_storage.sparse_size,
rowcount=old_storage._rowcount,
colptr=old_storage._colptr, colptr=old_storage._colptr,
colcount=old_storage._colcount, colcount=old_storage._colcount,
csr2csc=old_storage._csr2csc, csr2csc=old_storage._csr2csc,
......
...@@ -11,7 +11,7 @@ except ImportError: ...@@ -11,7 +11,7 @@ except ImportError:
def remove_diag(src, k=0): def remove_diag(src, k=0):
row, col, value = src.coo() row, col, value = src.coo()
inv_mask = row != col if k == 0 else row != (col - k) inv_mask = row != col if k == 0 else row != (col - k)
row, col = row[inv_mask], col[inv_mask] new_row, new_col = row[inv_mask], col[inv_mask]
if src.has_value(): if src.has_value():
value = value[inv_mask] value = value[inv_mask]
...@@ -29,7 +29,7 @@ def remove_diag(src, k=0): ...@@ -29,7 +29,7 @@ def remove_diag(src, k=0):
colcount = src.storage.colcount.clone() colcount = src.storage.colcount.clone()
colcount[col[mask]] -= 1 colcount[col[mask]] -= 1
storage = src.storage.__class__(row=row, col=col, value=value, storage = src.storage.__class__(row=new_row, col=new_col, value=value,
sparse_size=src.sparse_size(), sparse_size=src.sparse_size(),
rowcount=rowcount, colcount=colcount, rowcount=rowcount, colcount=colcount,
is_sorted=True) is_sorted=True)
...@@ -61,7 +61,7 @@ def set_diag(src, values=None, k=0): ...@@ -61,7 +61,7 @@ def set_diag(src, values=None, k=0):
new_value = None new_value = None
if src.has_value(): if src.has_value():
new_value = torch.new_empty((mask.size(0), ) + value.size()[1:]) new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
new_value[mask] = value new_value[mask] = value
new_value[inv_mask] = values if values is not None else 1 new_value[inv_mask] = values if values is not None else 1
......
...@@ -154,7 +154,7 @@ class SparseStorage(object): ...@@ -154,7 +154,7 @@ class SparseStorage(object):
idx = self.col.new_zeros(col.numel() + 1) idx = self.col.new_zeros(col.numel() + 1)
idx[1:] = sparse_size[1] * self.row + self.col idx[1:] = sparse_size[1] * self.row + self.col
if (idx[1:] < idx[:-1]).any(): if (idx[1:] < idx[:-1]).any():
perm = idx.argsort() perm = idx[1:].argsort()
self._row = self.row[perm] self._row = self.row[perm]
self._col = self.col[perm] self._col = self.col[perm]
self._value = self.value[perm] if self.has_value() else None self._value = self.value[perm] if self.has_value() else None
...@@ -313,12 +313,12 @@ class SparseStorage(object): ...@@ -313,12 +313,12 @@ class SparseStorage(object):
return self.csr2csc.argsort() return self.csr2csc.argsort()
def is_coalesced(self): def is_coalesced(self):
idx = self.col.new_zeros(self.col.numel() + 1) idx = self.col.new_full((self.col.numel() + 1, ), -1)
idx[1:] = self.sparse_size[1] * self.row + self.col idx[1:] = self.sparse_size[1] * self.row + self.col
return (idx[1:] > idx[:-1]).all().item() return (idx[1:] > idx[:-1]).all().item()
def coalesce(self, reduce='add'): def coalesce(self, reduce='add'):
idx = self.col.new_zeros(self.col.numel() + 1) idx = self.col.new_full((self.col.numel() + 1, ), -1)
idx[1:] = self.sparse_size[1] * self.row + self.col idx[1:] = self.sparse_size[1] * self.row + self.col
mask = idx[1:] > idx[:-1] mask = idx[1:] > idx[:-1]
...@@ -330,8 +330,9 @@ class SparseStorage(object): ...@@ -330,8 +330,9 @@ class SparseStorage(object):
value = self.value value = self.value
if self.has_value(): if self.has_value():
idx = mask.cumsum(0).sub_(1) ptr = mask.nonzero().flatten()
value = segment_csr(idx, value, reduce=reduce) ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
value = segment_csr(value, ptr, reduce=reduce)
value = value[0] if isinstance(value, tuple) else value value = value[0] if isinstance(value, tuple) else value
return self.__class__(row=row, col=col, value=value, return self.__class__(row=row, col=col, value=value,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment