Commit 46dac04f authored by rusty1s's avatar rusty1s
Browse files

get diag

parent 105a60be
......@@ -52,3 +52,15 @@ def test_fill_diag(dtype, device):
mat = mat.fill_diag(-8, k=-1)
mat = mat.fill_diag(-8, k=1)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_get_diag(dtype, device):
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
value = tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype, device)
mat = SparseTensor(row=row, col=col, value=value)
assert mat.get_diag().tolist() == [[1, 1], [0, 0], [4, 4]]
row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
mat = SparseTensor(row=row, col=col)
assert mat.get_diag().tolist() == [1, 0, 1]
......@@ -39,7 +39,7 @@ from .select import select # noqa
from .index_select import index_select, index_select_nnz # noqa
from .masked_select import masked_select, masked_select_nnz # noqa
from .permute import permute # noqa
from .diag import remove_diag, set_diag, fill_diag # noqa
from .diag import remove_diag, set_diag, fill_diag, get_diag # noqa
from .add import add, add_, add_nnz, add_nnz_ # noqa
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa
......@@ -75,6 +75,7 @@ __all__ = [
'remove_diag',
'set_diag',
'fill_diag',
'get_diag',
'add',
'add_',
'add_nnz',
......
from typing import Optional
import torch
from torch import Tensor
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
......@@ -31,7 +32,7 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
return src.from_storage(storage)
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
def set_diag(src: SparseTensor, values: Optional[Tensor] = None,
k: int = 0) -> SparseTensor:
src = remove_diag(src, k=k)
row, col, value = src.coo()
......@@ -51,7 +52,7 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
new_col[mask] = col
new_col[inv_mask] = diag.add_(k)
new_value: Optional[torch.Tensor] = None
new_value: Optional[Tensor] = None
if value is not None:
new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
new_value[mask] = value
......@@ -92,8 +93,25 @@ def fill_diag(src: SparseTensor, fill_value: float,
return set_diag(src, None, k)
def get_diag(src: SparseTensor) -> Tensor:
row, col, value = src.coo()
if value is None:
value = torch.ones(row.size(0))
sizes = list(value.size())
sizes[0] = min(src.size(0), src.size(1))
out = value.new_zeros(sizes)
mask = row == col
out[row[mask]] = value[mask]
return out
SparseTensor.remove_diag = lambda self, k=0: remove_diag(self, k)
SparseTensor.set_diag = lambda self, values=None, k=0: set_diag(
self, values, k)
SparseTensor.fill_diag = lambda self, fill_value, k=0: fill_diag(
self, fill_value, k)
SparseTensor.get_diag = lambda self: get_diag(self)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment