Commit 4242e343 authored by rusty1s's avatar rusty1s
Browse files

eye implementation

parent fc183212
......@@ -52,7 +52,7 @@ def test_set_diag(dtype, device):
print()
k = -8
print("k = ", k)
mat = mat.remove_diag(k)
mat = mat.set_diag(k)
print(mat.to_dense())
# row, col = mat.storage.index
......
......@@ -39,7 +39,10 @@ def remove_diag(src, k=0):
return src.__class__.from_storage(storage)
def set_diag(src, value=None, k=0):
def set_diag(src, values=None, k=0):
if values is not None and not src.has_value():
raise ValueError('Sparse matrix has no values')
src = src.remove_diag(k=0)
index, value = src.coo()
......@@ -63,7 +66,7 @@ def set_diag(src, value=None, k=0):
if src.has_value():
new_value = torch.new_empty((mask.size(0), ) + mask.size()[1:])
new_value[mask] = value
new_value[inv_mask] = 1
new_value[inv_mask] = values if values is not None else 1
rowcount = None
if src.storage.has_rowcount():
......
......@@ -61,6 +61,41 @@ class SparseTensor(object):
return SparseTensor.from_storage(storage)
@classmethod
def eye(self, m, n=None, device=None, no_value=True, fill_cache=False):
n = m if n is None else n
index = torch.empty((2, min(m, n)), dtype=torch.long, device=device)
torch.arange(index.size(1), out=index[0])
torch.arange(index.size(1), out=index[1])
value = None
if not no_value:
value = torch.ones(index.size(1), device=device)
rowcount = rowptr = colcount = colptr = csr2csc = csc2csr = None
if fill_cache:
rowcount = index.new_ones(m)
rowptr = torch.arange(m + 1, device=device)
colcount = index.new_ones(n)
colptr = torch.arange(n + 1, device=device)
csr2csc = torch.arange(index.size(1), device=device)
csc2csr = torch.arange(index.size(1), device=device)
storage = SparseStorage(
index,
value,
torch.Size([m, n]),
rowcount=rowcount,
rowptr=rowptr,
colcount=colcount,
colptr=colptr,
csr2csc=csr2csc,
csc2csr=csc2csr,
is_sorted=True,
)
return SparseTensor.from_storage(storage)
def __copy__(self):
return self.from_storage(self.storage)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment