masked_select.py 3.01 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from typing import Optional
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
4
5
import torch
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
6
7


rusty1s's avatar
rusty1s committed
8
9
def masked_select(src: SparseTensor, dim: int,
                  mask: torch.Tensor) -> SparseTensor:
    """Select the slices of ``src`` along ``dim`` for which ``mask`` is set.

    Args:
        src: The sparse tensor to select from.
        dim: The dimension to select along. Negative values count from the
            end. ``0`` selects rows, ``1`` selects columns, and any larger
            value selects along the dense dimension ``dim - 1`` of the
            stored value tensor.
        mask: A one-dimensional mask over ``src.size(dim)``.
            NOTE(review): assumed to be a boolean tensor. The row and column
            paths index with it directly, so an integer index tensor would
            silently be interpreted as indices. Confirm callers pass
            ``torch.bool``.

    Returns:
        A new :class:`SparseTensor` containing only the selected slices.
        For ``dim`` 0 and 1 the surviving rows or columns are renumbered
        consecutively. The result's sparse size along ``dim`` is therefore
        the number of selected entries.

    Raises:
        ValueError: If ``dim`` refers to a dense dimension but ``src`` has
            no stored values.
    """
    dim = src.dim() + dim if dim < 0 else dim
    assert mask.dim() == 1

    if dim == 0:
        row, col, value = src.coo()
        # Number of non-zeros in each surviving row.
        rowcount = src.storage.rowcount()[mask]

        # Lift the row mask to one entry per non-zero. The non-zeros are
        # grouped by row, so the surviving rows can be renumbered 0..k-1 by
        # repeating each new row index ``rowcount`` times.
        nnz_mask = mask[row]
        row = torch.arange(rowcount.size(0),
                           device=row.device).repeat_interleave(rowcount)
        col = col[nnz_mask]
        if value is not None:
            value = value[nnz_mask]

        sparse_sizes = (rowcount.size(0), src.sparse_size(1))
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colcount=None, colptr=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    if dim == 1:
        row, col, value = src.coo()
        # Permute the non-zeros into column-major order so they are grouped
        # by column. The column mask can then be handled like the row case.
        csr2csc = src.storage.csr2csc()
        row, col = row[csr2csc], col[csr2csc]
        colcount = src.storage.colcount()[mask]

        nnz_mask = mask[col]
        col = torch.arange(colcount.size(0),
                           device=col.device).repeat_interleave(colcount)
        row = row[nnz_mask]

        # Restore row-major order by sorting on (row, col). The key is
        # unique because every new column index is < colcount.size(0).
        # The resulting permutation is exactly the CSC -> CSR mapping of
        # the output, so it is cached in the new storage.
        csc2csr = (colcount.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]
        if value is not None:
            value = value[csr2csc][nnz_mask][csc2csr]

        sparse_sizes = (src.sparse_size(0), colcount.size(0))
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colcount=colcount, colptr=None, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    # Dense dimension: sparse dims are (0, 1), so tensor dimension ``dim``
    # corresponds to dimension ``dim - 1`` of the value tensor.
    value = src.storage.value()
    if value is None:
        raise ValueError(
            f'Cannot select along dense dimension {dim} of a SparseTensor '
            'that has no stored values')
    idx = mask.nonzero().flatten()
    return src.set_value(value.index_select(dim - 1, idx), layout='coo')


def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
                      layout: Optional[str] = None) -> SparseTensor:
    """Keep only the stored non-zero entries of ``src`` flagged by ``mask``.

    Args:
        src: The sparse tensor whose non-zeros are filtered.
        mask: A one-dimensional mask with one entry per stored non-zero,
            listed in the ordering named by ``layout``.
            NOTE(review): presumably boolean; confirm with callers.
        layout: The ordering of ``mask``, resolved through ``get_layout``.
            When it resolves to ``'csc'``, the mask is first permuted with
            ``src.storage.csc2csr()`` into the order used by ``src.coo()``.

    Returns:
        A new :class:`SparseTensor` built from the surviving ``row``,
        ``col`` and ``value`` entries, with the same sparse sizes as
        ``src``.
    """
    assert mask.dim() == 1

    # Bring a column-major mask into the order of ``src.coo()``.
    if get_layout(layout) == 'csc':
        mask = mask[src.storage.csc2csr()]

    row, col, value = src.coo()
    kept_value = None if value is None else value[mask]

    return SparseTensor(row=row[mask], rowptr=None, col=col[mask],
                        value=kept_value, sparse_sizes=src.sparse_sizes(),
                        is_sorted=True)

rusty1s's avatar
rusty1s committed
91

rusty1s's avatar
rusty1s committed
92
93
94
95
96
def _masked_select_method(self, dim, mask):
    """Method form of ``masked_select`` attached to ``SparseTensor``."""
    return masked_select(self, dim, mask)


def _masked_select_nnz_method(src, mask, layout=None):
    """Method form of ``masked_select_nnz`` attached to ``SparseTensor``."""
    return masked_select_nnz(src, mask, layout)


# ``tmp`` is kept as a module-level alias so the module namespace matches
# what it has always exposed.
tmp = _masked_select_nnz_method
SparseTensor.masked_select = _masked_select_method
SparseTensor.masked_select_nnz = tmp