diag.py 3.84 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch import Tensor
rusty1s's avatar
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
7
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
8
9


rusty1s's avatar
rusty1s committed
10
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` with every entry on the ``k``-th diagonal
    removed.

    The ``k``-th diagonal consists of the entries with ``row == col - k``
    (``k = 0`` is the main diagonal).

    Args:
        src: Input sparse matrix.
        k: Diagonal offset.

    Returns:
        A new :class:`SparseTensor` with the same sparse sizes. If ``src``
        has cached row/column counts, they are carried over and reduced by
        the number of removed entries per row/column. ``rowptr``, ``colptr``
        and the CSR/CSC permutations are not carried over.
    """
    row, col, value = src.coo()
    # True for every entry that is *not* on the k-th diagonal.
    keep = row != col if k == 0 else row != (col - k)
    new_row, new_col = row[keep], col[keep]

    if value is not None:
        value = value[keep]

    rowcount = src.storage._rowcount
    colcount = src.storage._colcount
    if rowcount is not None or colcount is not None:
        drop = ~keep
        # ``bincount`` counts every removed entry. An indexed in-place
        # ``count[idx] -= 1`` decrements a repeated index only once, which
        # undercounts when an uncoalesced matrix stores duplicate diagonal
        # entries.
        if rowcount is not None:
            rowcount = rowcount - torch.bincount(
                row[drop], minlength=rowcount.numel()).to(rowcount.dtype)
        if colcount is not None:
            colcount = colcount - torch.bincount(
                col[drop], minlength=colcount.numel()).to(colcount.dtype)

    # Dropping entries from a sorted COO list keeps it sorted.
    storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
                            sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
                            colptr=None, colcount=colcount, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return src.from_storage(storage)


rusty1s's avatar
rusty1s committed
36
def set_diag(src: SparseTensor, values: Optional[Tensor] = None,
             k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` whose ``k``-th diagonal is fully populated.

    Any existing entries on the ``k``-th diagonal (``row == col - k``) are
    removed first. Then one entry is inserted for every position on that
    diagonal.

    Args:
        src: Input sparse matrix.
        values: Values for the inserted diagonal entries. The first
            dimension must match the number of inserted entries. If
            ``None``, the diagonal entries get the value one: explicitly if
            ``src`` stores values, implicitly (value-less) otherwise.
        k: Diagonal offset.

    Returns:
        A new :class:`SparseTensor` with the same sparse sizes. Cached
        row/column counts, if present on ``src``, are incremented for the
        inserted entries.
    """
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    # ``mask`` covers the output entries: True for entries carried over from
    # ``src``, False for the slots where diagonal entries are inserted.
    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask
    nnz = mask.size(0)

    # The k-th diagonal starts at row ``-k`` for k < 0 and at row 0 otherwise.
    start = -k if k < 0 else 0
    num_diag = nnz - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(nnz)
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(nnz)
    new_col[mask] = col
    new_col[inv_mask] = diag + k

    new_value: Optional[Tensor] = None
    if value is not None:
        trailing = value.size()[1:]
        new_value = value.new_empty((nnz, ) + trailing)
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            # Keep the trailing dims of ``value``. A 1-D ones tensor would
            # not broadcast against a ``[num_diag, F]`` selection.
            new_value[inv_mask] = value.new_ones((num_diag, ) + trailing)
    elif values is not None:
        # ``src`` is value-less, so each of its entries implicitly holds one.
        # Materialize those ones so the explicit diagonal ``values`` are not
        # silently dropped.
        new_value = values.new_ones((nnz, ) + values.size()[1:])
        new_value[inv_mask] = values

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
                            value=new_value, sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount, colptr=None, colcount=colcount,
                            csr2csc=None, csc2csr=None, is_sorted=True)
    return src.from_storage(storage)


def fill_diag(src: SparseTensor, fill_value: float,
              k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` whose ``k``-th diagonal is set to
    ``fill_value``.

    Args:
        src: Input sparse matrix.
        fill_value: Value written to every position of the diagonal.
        k: Diagonal offset.

    Returns:
        A new :class:`SparseTensor` (see :func:`set_diag`). If ``src`` is
        value-less and ``fill_value == 1``, the result stays value-less.
        Otherwise a value-less ``src`` gets explicit values: ones off the
        diagonal and ``fill_value`` on it.
    """
    num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
    if k < 0:
        num_diag = min(src.sparse_size(0) + k, src.sparse_size(1))

    value = src.storage.value()
    if value is not None:
        sizes = [num_diag] + src.sizes()[2:]
        return set_diag(src, value.new_full(sizes, fill_value), k)

    if fill_value == 1:
        # Entries of a value-less matrix are implicitly one already.
        return set_diag(src, None, k)

    # A value-less matrix needs an explicit fill tensor, otherwise
    # ``fill_value`` would be lost.
    _, col, _ = src.coo()
    fill = torch.full((num_diag, ), float(fill_value), device=col.device)
    return set_diag(src, fill, k)


rusty1s's avatar
rusty1s committed
97
98
99
100
def get_diag(src: SparseTensor) -> Tensor:
    """Return the main diagonal of ``src`` as a dense tensor.

    The result has ``min(src.size(0), src.size(1))`` rows, followed by any
    trailing dimensions of the stored values. Diagonal positions with no
    stored entry are zero. A value-less ``src`` is treated as having a value
    of one (float, default dtype) for every stored entry.

    NOTE(review): if ``src`` stores duplicate diagonal entries
    (uncoalesced), the indexed assignment keeps one of them rather than
    summing them.
    """
    row, col, value = src.coo()
    if value is None:
        value = torch.ones(row.size(0), device=row.device)

    diag_len = min(src.size(0), src.size(1))
    out = value.new_zeros([diag_len] + list(value.size())[1:])

    on_diag = row == col
    out[row[on_diag]] = value[on_diag]
    return out


rusty1s's avatar
rusty1s committed
113
114
115
# Attach the module-level diagonal functions to ``SparseTensor`` so they can
# be called as methods, e.g. ``adj.set_diag(values, k=1)``.
SparseTensor.remove_diag = lambda self, k=0: remove_diag(self, k=k)
SparseTensor.set_diag = lambda self, values=None, k=0: set_diag(
    self, values=values, k=k)
SparseTensor.fill_diag = lambda self, fill_value, k=0: fill_diag(
    self, fill_value=fill_value, k=k)
SparseTensor.get_diag = lambda self: get_diag(self)