diag.py 2.53 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch


rusty1s's avatar
rusty1s committed
4
def remove_diag(src, k=0):
    """Return a copy of ``src`` without the entries on its ``k``-th diagonal.

    An entry ``(row, col)`` is treated as lying on the ``k``-th diagonal
    when ``row == col - k``; ``k=0`` is the main diagonal.

    Args:
        src: Sparse matrix object providing ``coo()``, ``has_value()``,
            ``sparse_size()``, a ``storage`` attribute and a
            ``from_storage`` constructor on its class.
        k: Offset of the diagonal to remove (default ``0``).

    Returns:
        A new object of the same class as ``src``, built from a new storage
        object that holds only the remaining entries.
    """
    row, col, value = src.coo()

    # True for every entry that is NOT on the k-th diagonal (those are kept).
    # ``col - k`` with k == 0 is just ``col``, so no special case is needed.
    keep = row != (col - k)
    new_row, new_col = row[keep], col[keep]

    if src.has_value():
        value = value[keep]

    storage = src.storage
    rowcount = None
    colcount = None
    if storage.has_rowcount() or storage.has_colcount():
        # Only build the "removed" mask when a cached count has to be
        # adjusted.
        removed = ~keep

        if storage.has_rowcount():
            # Clone so the cached counts held by ``src`` stay untouched.
            rowcount = storage.rowcount.clone()
            rowcount[row[removed]] -= 1

        if storage.has_colcount():
            colcount = storage.colcount.clone()
            colcount[col[removed]] -= 1

    # Masking keeps the relative order of the remaining entries, which is
    # presumably why ``is_sorted=True`` can be passed on -- assumes the
    # entries returned by ``src.coo()`` were already sorted; TODO confirm.
    new_storage = storage.__class__(row=new_row, col=new_col, value=value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)
    return src.__class__.from_storage(new_storage)
rusty1s's avatar
todos  
rusty1s committed
30
31


rusty1s's avatar
rusty1s committed
32
33
34
35
def set_diag(src, values=None, k=0):
    """Return a copy of ``src`` with its ``k``-th diagonal fully populated.

    Existing entries on the ``k``-th diagonal are removed first, then one
    entry per diagonal position is inserted.

    Args:
        src: Sparse matrix object providing ``remove_diag()``, ``coo()``,
            ``has_value()``, ``size()``, ``sparse_size()``, a ``storage``
            attribute and a ``from_storage`` constructor on its class.
        values: Value(s) assigned to the inserted diagonal entries. When
            ``None`` and ``src`` has values, the entries are set to ``1``.
        k: Offset of the diagonal to set (default ``0``).

    Returns:
        A new object of the same class as ``src`` containing the previous
        off-diagonal entries plus the inserted diagonal entries.

    Raises:
        ValueError: If ``values`` is given but ``src`` stores no values.
    """
    if values is not None and not src.has_value():
        raise ValueError('Sparse matrix has no values')

    # Clear the diagonal that is about to be written (the k-th one, not
    # always the main one) so its entries are not duplicated below.
    src = src.remove_diag(k=k)

    row, col, value = src.coo()

    # NOTE(review): ``non_diag_mask`` is a compiled op not visible here. From
    # its use below it yields a boolean mask over the merged entry list:
    # True where an existing entry is placed, False where a diagonal entry
    # is inserted -- confirm against the extension source.
    if row.is_cuda:
        mask = torch.ops.torch_sparse_cuda.non_diag_mask(
            row, col, src.size(0), src.size(1), k)
    else:
        mask = torch.ops.torch_sparse_cpu.non_diag_mask(
            row, col, src.size(0), src.size(1), k)

    inv_mask = ~mask

    # For k < 0 the diagonal starts at row -k, otherwise at row 0. The number
    # of inserted entries is the mask length minus the existing entry count.
    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    # Existing entries keep their own column indices.
    new_col[mask] = col
    # Column of a diagonal entry is its row shifted by k (out-of-place so
    # ``diag`` keeps holding the row indices).
    new_col[inv_mask] = diag + k

    new_value = None
    if src.has_value():
        # Keep any trailing value dimensions; only the entry axis grows.
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        new_value[inv_mask] = values if values is not None else 1

    rowcount = None
    if src.storage.has_rowcount():
        # Clone so the cached counts held by ``src`` stay untouched.
        rowcount = src.storage.rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = None
    if src.storage.has_colcount():
        colcount = src.storage.colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = src.storage.__class__(row=new_row, col=new_col, value=new_value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)

    return src.__class__.from_storage(storage)