diag.py 2.47 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
import torch

from torch_sparse import diag_cpu

try:
    from torch_sparse import diag_cuda
except ImportError:
    diag_cuda = None


rusty1s's avatar
rusty1s committed
11
12
13
14
def remove_diag(src, k=0):
    """Return a copy of ``src`` with every entry on the ``k``-th diagonal dropped.

    An entry ``(r, c)`` is treated as lying on the ``k``-th diagonal when
    ``r == c - k`` (so ``k=0`` is the main diagonal). Cached per-row and
    per-column counts on the storage, when present, are cloned and
    decremented once for every dropped entry; the input is not modified.

    Args:
        src: sparse matrix exposing ``coo()``, ``has_value()``,
            ``sparse_size()`` and a ``storage`` attribute.
        k (int): diagonal offset.

    Returns:
        A new object of the same class as ``src``, built through
        ``from_storage`` with ``is_sorted=True`` (filtering preserves the
        existing entry order).
    """
    index, value = src.coo()
    row, col = index
    storage = src.storage

    # Boolean mask over nnz entries: True for entries that are kept.
    shifted_col = col if k == 0 else col - k
    keep = row != shifted_col
    dropped = ~keep

    kept_index = index[:, keep]
    kept_value = value[keep] if src.has_value() else value

    # Every dropped entry removes exactly one nonzero from its row/column.
    new_rowcount = None
    if storage.has_rowcount():
        new_rowcount = storage.rowcount.clone()
        new_rowcount[row[dropped]] -= 1

    new_colcount = None
    if storage.has_colcount():
        new_colcount = storage.colcount.clone()
        new_colcount[col[dropped]] -= 1

    new_storage = type(storage)(
        kept_index,
        kept_value,
        sparse_size=src.sparse_size(),
        rowcount=new_rowcount,
        colcount=new_colcount,
        is_sorted=True,
    )
    return type(src).from_storage(new_storage)
rusty1s's avatar
todos  
rusty1s committed
40
41


rusty1s's avatar
rusty1s committed
42
43
44
45
def set_diag(src, values=None, k=0):
    """Return a copy of ``src`` whose ``k``-th diagonal is (re)filled.

    Entries already on the ``k``-th diagonal are removed first, then one new
    entry is inserted for every position of that diagonal. An entry
    ``(r, c)`` is on the ``k``-th diagonal when ``c - r == k``.

    Args:
        src: sparse matrix exposing ``remove_diag``, ``coo()``,
            ``has_value()``, ``size()``, ``sparse_size()`` and ``storage``.
        values: values written into the inserted diagonal entries (anything
            assignable into a boolean-masked slice of the value tensor). If
            ``None``, the diagonal entries are set to ``1``. Only allowed
            when ``src`` carries values.
        k (int): diagonal offset; ``k > 0`` is above, ``k < 0`` below the
            main diagonal.

    Returns:
        A new object of the same class as ``src``, built through
        ``from_storage``.

    Raises:
        ValueError: if ``values`` is given but ``src`` has no values.
    """
    if values is not None and not src.has_value():
        raise ValueError('Sparse matrix has no values')

    # Strip the k-th diagonal (not the main one) so existing entries on it
    # are not duplicated by the insertion below.
    src = src.remove_diag(k=k)

    index, value = src.coo()

    # NOTE(review): if ``index`` is on CUDA but the CUDA extension failed to
    # import, ``diag_cuda`` is None and this fails with an AttributeError.
    func = diag_cuda if index.is_cuda else diag_cpu
    # ``mask`` spans the old entries plus the new diagonal slots: True where
    # an existing entry is placed, False where a diagonal entry is inserted.
    mask = func.non_diag_mask(index, src.size(0), src.size(1), k)
    inv_mask = ~mask

    new_index = index.new_empty((2, mask.size(0)))
    new_index[:, mask] = index

    num_diag = mask.numel() - index.size(1)
    # First row that intersects the k-th diagonal.
    start = -k if k < 0 else 0

    diag_row = torch.arange(start, start + num_diag, device=index.device)
    new_index[0, inv_mask] = diag_row
    # Column of each diagonal entry is row + k (in-place; the row values
    # were already copied into ``new_index`` above).
    new_index[1, inv_mask] = diag_row.add_(k)

    new_value = None
    if src.has_value():
        # Allocate from ``value`` so dtype, device and any trailing value
        # dimensions match the existing values.
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        new_value[inv_mask] = values if values is not None else 1

    # Each diagonal row/column gains exactly one entry.
    rowcount = None
    if src.storage.has_rowcount():
        rowcount = src.storage.rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = None
    if src.storage.has_colcount():
        colcount = src.storage.colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    # is_sorted=True presumes non_diag_mask places diagonal slots in sorted
    # order among the existing entries — TODO confirm against the extension.
    storage = src.storage.__class__(new_index, new_value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)
    return src.__class__.from_storage(storage)