diag.py 2.44 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import torch

rusty1s's avatar
rusty1s committed
3
from torch_sparse.utils import ext
rusty1s's avatar
rusty1s committed
4
5


rusty1s's avatar
rusty1s committed
6
def remove_diag(src, k=0):
    """Return a copy of ``src`` without the entries on its ``k``-th diagonal.

    An entry ``(row, col)`` is treated as lying on the ``k``-th diagonal when
    ``row == col - k``; ``k=0`` is the main diagonal.

    Args:
        src: Sparse matrix object exposing ``coo()``, ``has_value()``,
            ``sparse_size()`` and a ``storage`` attribute.
        k: Offset of the diagonal to remove.

    Returns:
        A new object of the same class as ``src``, built from a fresh storage
        that holds only the entries off the ``k``-th diagonal.
    """
    row, col, value = src.coo()

    # ``keep`` is True for every entry that is NOT on the k-th diagonal.
    if k == 0:
        keep = row != col
    else:
        keep = row != (col - k)

    if src.has_value():
        value = value[keep]

    storage = src.storage
    rowcount = colcount = None
    if storage.has_rowcount() or storage.has_colcount():
        # Cached per-row / per-column counts are carried over on a clone,
        # decremented at the rows / columns of the removed entries.
        dropped = ~keep

        if storage.has_rowcount():
            rowcount = storage.rowcount.clone()
            rowcount[row[dropped]] -= 1

        if storage.has_colcount():
            colcount = storage.colcount.clone()
            colcount[col[dropped]] -= 1

    new_storage = storage.__class__(row=row[keep], col=col[keep], value=value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)
    return src.__class__.from_storage(new_storage)
rusty1s's avatar
todos  
rusty1s committed
32
33


rusty1s's avatar
rusty1s committed
34
35
36
37
def set_diag(src, values=None, k=0):
    """Return a copy of ``src`` with its ``k``-th diagonal fully populated.

    Existing entries on the ``k``-th diagonal are dropped first, then one
    entry per diagonal position is inserted among the remaining entries.

    Args:
        src: Sparse matrix object exposing ``remove_diag()``, ``coo()``,
            ``has_value()``, ``size()``, ``sparse_size()``, ``device`` and a
            ``storage`` attribute.
        values: Value(s) to place on the diagonal. When ``None`` and ``src``
            carries values, the diagonal is filled with ``1``.
        k: Offset of the diagonal to set (``0`` is the main diagonal).

    Returns:
        A new object of the same class as ``src`` built from a fresh storage.

    Raises:
        ValueError: If ``values`` is given but ``src`` has no values.
    """
    if values is not None and not src.has_value():
        raise ValueError('Sparse matrix has no values')

    # Clear the diagonal that is about to be written. This must use the
    # requested offset ``k``; clearing the main diagonal instead would leave
    # stale entries on the k-th diagonal and wrongly drop main-diagonal ones.
    src = src.remove_diag(k=k)

    row, col, value = src.coo()

    # NOTE(review): ``non_diag_mask`` is an extension op not visible here.
    # From its use below it presumably returns a boolean mask over the OUTPUT
    # entries (existing + new diagonal), True where an existing entry goes and
    # False where a diagonal entry is inserted -- confirm against the kernel.
    mask = ext(row.is_cuda).non_diag_mask(row, col, src.size(0), src.size(1),
                                          k)
    inv_mask = ~mask

    # For a negative offset the diagonal starts at row ``-k``; otherwise at
    # row 0. The number of inserted entries is the mask's surplus over the
    # existing entries.
    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=src.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    # Existing entries keep their own column indices.
    new_col[mask] = col
    # Diagonal entries sit ``k`` columns to the right of their row index.
    new_col[inv_mask] = diag + k

    new_value = None
    if src.has_value():
        # Keep any trailing value dimensions; only the leading (entry)
        # dimension grows.
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        new_value[inv_mask] = values if values is not None else 1

    # Cached counts, when present, gain one entry for every row / column
    # touched by the inserted diagonal.
    rowcount = None
    if src.storage.has_rowcount():
        rowcount = src.storage.rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = None
    if src.storage.has_colcount():
        colcount = src.storage.colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = src.storage.__class__(row=new_row, col=new_col, value=new_value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)

    return src.__class__.from_storage(storage)