Commit e8d349ee authored by rusty1s's avatar rusty1s
Browse files

test eye

parent 4242e343
......@@ -4,13 +4,8 @@ import pytest
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.diag_cpu import non_diag_mask
from .utils import dtypes, devices, tensor
dtypes = [torch.float]
devices = ['cpu']
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_remove_diag(dtype, device):
......@@ -49,41 +44,5 @@ def test_set_diag(dtype, device):
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(index, value)
print()
k = -8
print("k = ", k)
mat = mat.set_diag(k)
print(mat.to_dense())
# row, col = mat.storage.index
# print('k', k)
# mask = row != col - k
# index = index[:, mask]
# row, col = index
# print(row)
# print(col)
mask = non_diag_mask(mat.storage.index, mat.size(0), mat.size(1), k)
print(mask)
# bla = col - row
# print(bla)
# DETECT VORZEICHEN WECHSEL
# mask = row.new_ones(index.size(1) + 3, dtype=torch.bool)
# mask[1:] = row[1:] != row[:-1]
# # mask = row[1:] != row[:-1]
# print(mask)
# mask = (row <= col)
# print(row)
# print(col)
# print(mask)
# mask = (row[1:] == row[:-1])
# print(mask)
# UNION
# idx1 = ...
# idx2 = ...
from torch_sparse import eye
from itertools import product
import pytest
from torch_sparse.tensor import SparseTensor
def test_eye():
index, value = eye(3)
assert index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert value.tolist() == [1, 1, 1]
from .utils import dtypes, devices
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_eye(dtype, device):
mat = SparseTensor.eye(3, dtype=dtype, device=device)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert mat.storage.value.tolist() == [1, 1, 1]
assert len(mat.cached_keys()) == 0
mat = SparseTensor.eye(3, dtype=dtype, device=device, no_value=True)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert mat.storage.value is None
assert len(mat.cached_keys()) == 0
mat = SparseTensor.eye(3, 4, dtype=dtype, device=device, fill_cache=True)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert len(mat.cached_keys()) == 6
assert mat.storage.rowcount.tolist() == [1, 1, 1]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1, 0]
assert mat.storage.colptr.tolist() == [0, 1, 2, 3, 3]
assert mat.storage.csr2csc.tolist() == [0, 1, 2]
assert mat.storage.csc2csr.tolist() == [0, 1, 2]
mat = SparseTensor.eye(4, 3, dtype=dtype, device=device, fill_cache=True)
assert mat.storage.index.tolist() == [[0, 1, 2], [0, 1, 2]]
assert len(mat.cached_keys()) == 6
assert mat.storage.rowcount.tolist() == [1, 1, 1, 0]
assert mat.storage.rowptr.tolist() == [0, 1, 2, 3, 3]
assert mat.storage.colcount.tolist() == [1, 1, 1]
assert mat.storage.colptr.tolist() == [0, 1, 2, 3]
assert mat.storage.csr2csc.tolist() == [0, 1, 2]
assert mat.storage.csc2csr.tolist() == [0, 1, 2]
......@@ -62,30 +62,37 @@ class SparseTensor(object):
return SparseTensor.from_storage(storage)
@classmethod
def eye(self, m, n=None, device=None, no_value=True, fill_cache=False):
n = m if n is None else n
def eye(self, M, N=None, device=None, dtype=None, no_value=False,
fill_cache=False):
N = M if N is None else N
index = torch.empty((2, min(m, n)), dtype=torch.long, device=device)
index = torch.empty((2, min(M, N)), dtype=torch.long, device=device)
torch.arange(index.size(1), out=index[0])
torch.arange(index.size(1), out=index[1])
value = None
if not no_value:
value = torch.ones(index.size(1), device=device)
value = torch.ones(index.size(1), dtype=dtype, device=device)
rowcount = rowptr = colcount = colptr = csr2csc = csc2csr = None
if fill_cache:
rowcount = index.new_ones(m)
rowptr = torch.arange(m + 1, device=device)
colcount = index.new_ones(n)
colptr = torch.arange(n + 1, device=device)
rowcount = index.new_ones(M)
rowptr = torch.arange(M + 1, device=device)
if M > N:
rowcount[index.size(1):] = 0
rowptr[index.size(1) + 1:] = index.size(1)
colcount = index.new_ones(N)
colptr = torch.arange(N + 1, device=device)
if N > M:
colcount[index.size(1):] = 0
colptr[index.size(1) + 1:] = index.size(1)
csr2csc = torch.arange(index.size(1), device=device)
csc2csr = torch.arange(index.size(1), device=device)
storage = SparseStorage(
index,
value,
torch.Size([m, n]),
torch.Size([M, N]),
rowcount=rowcount,
rowptr=rowptr,
colcount=colcount,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment