# diag.py (869 bytes) — last committed by rusty1s.
import torch


def add_diag(src, value=None, k=0):
    """Placeholder for inserting entries on the ``k``-th diagonal of ``src``.

    Not implemented yet: ``src``, ``value`` and ``k`` are ignored and the
    function returns ``None``.
    """
    return None


def remove_diag(src, k=0):
    """Return a copy of ``src`` with every entry on its ``k``-th diagonal removed.

    Args:
        src: Sparse matrix object. It is accessed only through ``coo()``,
            ``has_value()``, ``sparse_size()``, ``storage`` (with
            ``has_rowcount``/``rowcount``/``has_colcount``/``colcount``) and
            ``src.__class__.from_storage``.
        k (int, optional): Diagonal offset, using the same convention as
            ``torch.diagonal(offset=...)`` and ``numpy.diag(k=...)``:
            ``k == 0`` is the main diagonal, ``k > 0`` selects entries
            ``(i, i + k)`` above it and ``k < 0`` selects entries below it.
            (Default: ``0``)

    Returns:
        A new object of ``src``'s class built from a fresh storage that holds
        only the entries not on the ``k``-th diagonal. If the input storage
        carries row/column counts, they are cloned and decremented for the
        removed entries.
    """
    index, value = src.coo()
    row, col = index

    # (row, col) lies on the k-th diagonal iff col - row == k. This also covers
    # k == 0 (row == col), so no special case is needed.
    mask = row == (col - k)
    inv_mask = ~mask

    index = index[:, inv_mask]

    # Only filter values when src actually stores them; otherwise `value` is
    # passed through unchanged.
    if src.has_value():
        value = value[inv_mask]

    # Each removed entry lowers the count of its row/column by one.
    # NOTE(review): `t[idx] -= 1` applies at most one decrement per distinct
    # index, so this assumes the index is coalesced (at most one entry per
    # row/column on a given diagonal) — confirm storages are always coalesced.
    rowcount = None
    if src.storage.has_rowcount():
        rowcount = src.storage.rowcount.clone()
        rowcount[row[mask]] -= 1

    colcount = None
    if src.storage.has_colcount():
        colcount = src.storage.colcount.clone()
        colcount[col[mask]] -= 1

    # Dropping entries does not reorder the remaining ones, so a sorted input
    # index stays sorted (assumes src's index is sorted — TODO confirm).
    storage = src.storage.__class__(index, value,
                                    sparse_size=src.sparse_size(),
                                    rowcount=rowcount, colcount=colcount,
                                    is_sorted=True)
    return src.__class__.from_storage(storage)