# masked_select.py (2.43 KB) -- last committed by rusty1s
import torch

from torch_sparse.storage import get_layout


def masked_select(src, dim, mask):
    """Keep only the slices of ``src`` along ``dim`` for which ``mask`` is set.

    Args:
        src: The sparse tensor to select from. It must expose ``dim()``,
            ``storage``, ``coo()``, ``has_value()``, ``sparse_size()`` and
            ``from_storage()`` (presumably a ``SparseTensor`` -- TODO confirm).
        dim (int): Dimension to select along; negative values count from the
            end. ``0`` selects rows, ``1`` selects columns, and any larger
            dimension is applied to the stored values along axis ``dim - 1``.
        mask (Tensor): 1-D mask over the entries of ``dim``. It is used for
            boolean indexing below, so it is presumably a ``bool`` tensor of
            length ``src.size(dim)`` -- TODO confirm.

    Returns:
        A new sparse tensor created through ``src.from_storage``.
    """
    dim = src.dim() + dim if dim < 0 else dim
    assert mask.dim() == 1

    if dim == 0:
        storage = _masked_select_rows(src, mask)
    elif dim == 1:
        storage = _masked_select_cols(src, mask)
    else:
        storage = _masked_select_dense(src, dim, mask)

    return src.from_storage(storage)


def _masked_select_rows(src, mask):
    """Build a storage containing only the rows of ``src`` where ``mask`` is set."""
    row, col, value = src.coo()

    # Number of non-zeros in each row that survives the mask.
    rowcount = src.storage.rowcount[mask]

    # Lift the per-row mask to a per-non-zero mask.
    nnz_mask = mask[row]

    # New row indices are generated in ascending row order with
    # ``repeat_interleave``, which assumes ``src.coo()`` yields entries
    # sorted by row (consistent with ``is_sorted=True`` below).
    row = torch.arange(rowcount.size(0),
                       device=row.device).repeat_interleave(rowcount)
    col = col[nnz_mask]
    if src.has_value():
        value = value[nnz_mask]

    sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])
    return src.storage.__class__(row=row, col=col, value=value,
                                 sparse_size=sparse_size,
                                 rowcount=rowcount, is_sorted=True)


def _masked_select_cols(src, mask):
    """Build a storage containing only the columns of ``src`` where ``mask`` is set."""
    row, col, value = src.coo()

    # Visit the non-zeros in column-major order so that the entries of each
    # column are contiguous.
    csr2csc = src.storage.csr2csc
    row, col = row[csr2csc], col[csr2csc]

    # Number of non-zeros in each column that survives the mask.
    colcount = src.storage.colcount[mask]
    num_cols = colcount.size(0)

    # Lift the per-column mask to a per-non-zero mask (column-major order).
    nnz_mask = mask[col]
    col = torch.arange(num_cols, device=col.device).repeat_interleave(colcount)
    row = row[nnz_mask]

    # Return to row-major order: sorting by the linearised index
    # ``row * num_cols + col`` orders entries by row, then by column.
    csc2csr = (num_cols * row + col).argsort()
    row, col = row[csc2csr], col[csc2csr]

    if src.has_value():
        # Compose the permutation, selection and re-permutation into a single
        # index so ``value`` is gathered once instead of three times.
        value = value[csr2csc[nnz_mask][csc2csr]]

    sparse_size = torch.Size([src.sparse_size(0), num_cols])
    return src.storage.__class__(row=row, col=col, value=value,
                                 sparse_size=sparse_size,
                                 colcount=colcount, csc2csr=csc2csr,
                                 is_sorted=True)


def _masked_select_dense(src, dim, mask):
    """Build a storage whose values keep only the entries where ``mask`` is set.

    Dimension ``dim`` of ``src`` is mapped to axis ``dim - 1`` of the stored
    value tensor.
    """
    idx = mask.nonzero().view(-1)
    return src.storage.apply_value(lambda x: x.index_select(dim - 1, idx))


def masked_select_nnz(src, mask, layout=None):
    """Keep only the non-zero entries of ``src`` flagged by ``mask``.

    Args:
        src: The sparse tensor to select from (presumably a ``SparseTensor``
            -- TODO confirm).
        mask (Tensor): 1-D mask with one entry per non-zero of ``src``. It is
            used for boolean indexing, so it is presumably a ``bool`` tensor
            -- TODO confirm.
        layout: Ordering of ``mask``, resolved through ``get_layout``. When it
            resolves to ``'csc'``, ``mask`` is taken to be in column-major
            order and is re-ordered to match ``src.coo()`` before use.

    Returns:
        A new sparse tensor created through ``src.from_storage``.
    """
    assert mask.dim() == 1

    if get_layout(layout) == 'csc':
        # Bring a column-major mask into the order returned by ``src.coo()``.
        mask = mask[src.storage.csc2csr]

    row, col, value = src.coo()
    kept_row = row[mask]
    kept_col = col[mask]
    kept_value = value[mask] if src.has_value() else value

    # Only coordinates, values and the overall sparse size can be carried
    # over; no other cached information is maintained. Boolean indexing
    # keeps the relative order of the surviving entries.
    new_storage = src.storage.__class__(
        row=kept_row,
        col=kept_col,
        value=kept_value,
        sparse_size=src.sparse_size(),
        is_sorted=True,
    )
    return src.from_storage(new_storage)