Commit 7636e1d1 authored by rusty1s's avatar rusty1s
Browse files

diag fix

parent c86527dc
......@@ -15,23 +15,23 @@ def test_remove_diag(dtype, device):
mat.fill_cache_()
mat = mat.remove_diag()
assert mat.storage.row.tolist() == [0, 1]
assert mat.storage.col.tolist() == [1, 2]
assert mat.storage.value.tolist() == [2, 3]
assert len(mat.cached_keys()) == 2
assert mat.storage.rowcount.tolist() == [1, 1, 0]
assert mat.storage.colcount.tolist() == [0, 1, 1]
assert mat.storage.row().tolist() == [0, 1]
assert mat.storage.col().tolist() == [1, 2]
assert mat.storage.value().tolist() == [2, 3]
assert mat.storage.num_cached_keys() == 2
assert mat.storage.rowcount().tolist() == [1, 1, 0]
assert mat.storage.colcount().tolist() == [0, 1, 1]
mat = SparseTensor(row=row, col=col, value=value)
mat.fill_cache_()
mat = mat.remove_diag(k=1)
assert mat.storage.row.tolist() == [0, 2]
assert mat.storage.col.tolist() == [0, 2]
assert mat.storage.value.tolist() == [1, 4]
assert len(mat.cached_keys()) == 2
assert mat.storage.rowcount.tolist() == [1, 0, 1]
assert mat.storage.colcount.tolist() == [1, 0, 1]
assert mat.storage.row().tolist() == [0, 2]
assert mat.storage.col().tolist() == [0, 2]
assert mat.storage.value().tolist() == [1, 4]
assert mat.storage.num_cached_keys() == 2
assert mat.storage.rowcount().tolist() == [1, 0, 1]
assert mat.storage.colcount().tolist() == [1, 0, 1]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
......@@ -40,5 +40,15 @@ def test_set_diag(dtype, device):
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
k = -8
mat = mat.set_diag(k)
mat = mat.set_diag(tensor([-8, -8], dtype, device), k=-1)
mat = mat.set_diag(tensor([-8], dtype, device), k=1)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_fill_diag(dtype, device):
row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
value = tensor([1, 2, 3, 4], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
mat = mat.fill_diag(-8, k=-1)
mat = mat.fill_diag(-8, k=1)
......@@ -27,7 +27,7 @@ from .narrow import narrow
from .select import select
from .index_select import index_select, index_select_nnz
from .masked_select import masked_select, masked_select_nnz
from .diag import set_diag, remove_diag
from .diag import remove_diag, set_diag, fill_diag
from .add import add, add_, add_nnz, add_nnz_
from .mul import mul, mul_, mul_nnz, mul_nnz_
from .reduce import sum, mean, min, max
......
......@@ -50,7 +50,7 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
@torch.jit.script
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
k: int = 0) -> SparseTensor:
src = remove_diag(src, k=0)
src = remove_diag(src, k=k)
row, col, value = src.coo()
mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
......@@ -65,7 +65,7 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
new_row[inv_mask] = diag
new_col = col.new_empty(mask.size(0))
new_col[mask] = row
new_col[mask] = col
new_col[inv_mask] = diag.add_(k)
new_value: Optional[torch.Tensor] = None
......@@ -95,6 +95,22 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
return src.from_storage(storage)
@torch.jit.script
def fill_diag(src: SparseTensor, fill_value: int, k: int = 0) -> SparseTensor:
num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
if k < 0:
num_diag = min(src.sparse_size(0) + k, src.sparse_size(1))
value = src.storage.value()
if value is not None:
sizes = [num_diag] + src.sizes()[2:]
return set_diag(src, value.new_full(sizes, fill_value), k)
else:
return set_diag(src, None, k)
SparseTensor.remove_diag = lambda self, k=0: remove_diag(self, k)
SparseTensor.set_diag = lambda self, values=None, k=0: set_diag(
self, values, k)
SparseTensor.fill_diag = lambda self, fill_value, k=0: fill_diag(
self, fill_value, k)
......@@ -197,7 +197,7 @@ class SparseTensor(object):
sizes = self.sparse_sizes()
value = self.storage.value()
if value is not None:
sizes = sizes + value.size()[1:]
sizes = list(sizes) + list(value.size())[1:]
return sizes
def size(self, dim: int) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment