diag.py 2.52 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
import torch

from torch_sparse import diag_cpu

try:
    from torch_sparse import diag_cuda
except ImportError:
    diag_cuda = None


rusty1s's avatar
rusty1s committed
11
def remove_diag(src, k=0):
    """Return a copy of ``src`` with every entry on the ``k``-th diagonal removed.

    An entry ``(row, col)`` lies on the ``k``-th diagonal when
    ``col - row == k``: ``k = 0`` is the main diagonal, ``k > 0`` lies above
    it and ``k < 0`` below it.

    Args:
        src: Sparse matrix exposing ``coo()``, ``has_value()``,
            ``sparse_size()`` and a ``storage`` object.
        k (int, optional): Diagonal offset. (default: ``0``)

    Returns:
        A new sparse matrix of the same class as ``src``, built from a new
        storage of the same class. Cached ``rowcount``/``colcount`` tensors
        are carried over (cloned and decremented) only if ``src`` has them.
        Relative order of the kept entries is preserved, so the result is
        marked ``is_sorted=True``.
    """
    row, col, value = src.coo()

    # ``True`` for entries that are *not* on the k-th diagonal (kept).
    keep = row != (col - k)
    new_row, new_col = row[keep], col[keep]

    if src.has_value():
        value = value[keep]

    rowcount = None
    if src.storage.has_rowcount():
        removed_rows = row[~keep]
        rowcount = src.storage.rowcount.clone()
        # ``index_add_`` accumulates over repeated indices, whereas
        # ``rowcount[idx] -= 1`` would apply a repeated index only once.
        rowcount.index_add_(
            0, removed_rows, rowcount.new_full((removed_rows.numel(), ), -1))

    colcount = None
    if src.storage.has_colcount():
        removed_cols = col[~keep]
        colcount = src.storage.colcount.clone()
        colcount.index_add_(
            0, removed_cols, colcount.new_full((removed_cols.numel(), ), -1))

    storage = src.storage.__class__(row=new_row, col=new_col, value=value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)
    return src.__class__.from_storage(storage)
rusty1s's avatar
todos  
rusty1s committed
37
38


rusty1s's avatar
rusty1s committed
39
40
41
42
def set_diag(src, values=None, k=0):
    """Return a copy of ``src`` whose ``k``-th diagonal is (re)filled.

    Entries already on the ``k``-th diagonal (``col - row == k``) are removed
    first. New diagonal entries ``(i, i + k)`` are then interleaved with the
    remaining ones at the positions given by the extension's
    ``non_diag_mask``, and the result is marked ``is_sorted=True``.

    Args:
        src: Sparse matrix exposing ``remove_diag``, ``coo()``,
            ``has_value()``, ``size()``, ``sparse_size()`` and a ``storage``
            object.
        values (Tensor, optional): Values for the inserted diagonal entries.
            If omitted and ``src`` carries values, inserted entries get ``1``.
            (default: ``None``)
        k (int, optional): Diagonal offset. (default: ``0``)

    Returns:
        A new sparse matrix of the same class as ``src``.

    Raises:
        ValueError: If ``values`` is given but ``src`` has no values.
    """
    if values is not None and not src.has_value():
        raise ValueError('Sparse matrix has no values')

    # Drop existing entries on the *same* diagonal ``k`` so the entries
    # inserted below do not duplicate them.
    src = src.remove_diag(k=k)

    row, col, value = src.coo()

    # ``non_diag_mask`` comes from the compiled extension. As used below, it
    # is a boolean mask over the *output* entries: ``True`` marks a slot for
    # an existing entry, ``False`` a slot for a new diagonal entry.
    func = diag_cuda if row.is_cuda else diag_cpu
    mask = func.non_diag_mask(row, col, src.size(0), src.size(1), k)
    inv_mask = ~mask

    # The k-th diagonal starts at row ``-k`` for ``k < 0`` and at row 0
    # otherwise; all output slots not used by existing entries are new.
    start = -k if k < 0 else 0
    num_diag = mask.numel() - row.numel()
    diag_row = torch.arange(start, start + num_diag, device=row.device)
    diag_col = diag_row + k

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag_row

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    new_col[inv_mask] = diag_col

    new_value = None
    if src.has_value():
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        # Inserted entries take ``values`` if given, otherwise 1.
        new_value[inv_mask] = values if values is not None else 1

    # Each row/column touched by the diagonal gains exactly one entry.
    rowcount = None
    if src.storage.has_rowcount():
        rowcount = src.storage.rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = None
    if src.storage.has_colcount():
        colcount = src.storage.colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = src.storage.__class__(row=new_row, col=new_col, value=new_value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)
    return src.__class__.from_storage(storage)