index_select.py 3.25 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch_scatter import gather_csr
rusty1s's avatar
rusty1s committed
5
6
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
10
def _index_select_rows(src: SparseTensor, idx: torch.Tensor) -> SparseTensor:
    """Return the rows ``idx`` of ``src`` (the ``dim == 0`` case)."""
    old_rowptr, col, value = src.csr()

    # Number of non-zeros in each selected row, in the order given by idx.
    rowcount = src.storage.rowcount()[idx]

    # New row pointer: exclusive cumulative sum of the selected row lengths.
    rowptr = col.new_zeros(idx.size(0) + 1)
    torch.cumsum(rowcount, dim=0, out=rowptr[1:])

    # Output row index of every selected non-zero, in CSR order.
    row = torch.arange(idx.size(0),
                       device=col.device).repeat_interleave(rowcount)

    # Position of each output non-zero inside ``src``: its offset within the
    # output row (k - rowptr[i]) plus the start of source row idx[i].  The
    # per-row shift is spread over that row's entries with ``gather_csr``.
    perm = torch.arange(row.size(0), device=row.device)
    perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)

    col = col[perm]
    if value is not None:
        value = value[perm]

    sparse_sizes = (idx.size(0), src.sparse_size(1))

    # Rows are emitted in increasing output order and each row keeps the
    # column order it had in ``src.csr()``, so the result is row-major.
    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=None, colcount=None, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return src.from_storage(storage)


def _index_select_cols(src: SparseTensor, idx: torch.Tensor) -> SparseTensor:
    """Return the columns ``idx`` of ``src`` (the ``dim == 1`` case)."""
    # Work on the CSC view so every selected column is a contiguous segment.
    old_colptr, row, value = src.csc()

    # Number of non-zeros in each selected column, in the order given by idx.
    colcount = src.storage.colcount()[idx]

    # New column pointer: exclusive cumulative sum of the column lengths.
    colptr = row.new_zeros(idx.size(0) + 1)
    torch.cumsum(colcount, dim=0, out=colptr[1:])

    # Output column index of every selected non-zero, in CSC order.
    col = torch.arange(idx.size(0),
                       device=row.device).repeat_interleave(colcount)

    # Position of each output non-zero inside the CSC view of ``src``.
    perm = torch.arange(col.size(0), device=col.device)
    perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)

    row = row[perm]

    # Reorder the gathered (CSC-ordered) entries into row-major order by
    # sorting on row first, then on the new column index.  This permutation
    # is exactly the ``csc2csr`` mapping of the resulting storage.
    csc2csr = (idx.size(0) * row + col).argsort()
    row, col = row[csc2csr], col[csc2csr]

    if value is not None:
        value = value[perm][csc2csr]

    sparse_sizes = (src.sparse_size(0), idx.size(0))

    storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=None,
                            colptr=colptr, colcount=colcount, csr2csc=None,
                            csc2csr=csc2csr, is_sorted=True)
    return src.from_storage(storage)


def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    """Select rows, columns, or dense value slices of a sparse tensor.

    Args:
        src: The sparse tensor to index.
        dim: Dimension to index.  ``0`` selects rows, ``1`` selects columns,
            and any larger dimension indexes the value tensor (dimension
            ``dim`` of ``src`` is dimension ``dim - 1`` of its values).
            Negative values count back from ``src.dim()``.
        idx: One-dimensional index tensor.  For ``dim`` 0/1 the indices may
            be unsorted and may repeat; slice ``i`` of the result is
            ``src``'s slice ``idx[i]``.

    Returns:
        A new ``SparseTensor``: built with ``src.from_storage`` for
        ``dim`` 0/1, or with ``src.set_value(..., layout='coo')`` otherwise.

    Raises:
        ValueError: If ``dim`` addresses a value dimension but ``src`` has
            no stored values.
    """
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1

    if dim == 0:
        return _index_select_rows(src, idx)
    if dim == 1:
        return _index_select_cols(src, idx)

    value = src.storage.value()
    if value is None:
        raise ValueError(
            f"Cannot index dimension {dim} of a SparseTensor without stored "
            f"values: dimensions >= 2 index the value tensor.")
    # The value tensor's leading dimension enumerates the non-zeros, so
    # sparse-tensor dimension ``dim`` is value dimension ``dim - 1``.
    return src.set_value(value.index_select(dim - 1, idx), layout='coo')
rusty1s's avatar
rusty1s committed
79
80


rusty1s's avatar
rusty1s committed
81
82
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    """Keep only the selected non-zero entries of a sparse tensor.

    Args:
        src: The sparse tensor whose non-zeros are selected.
        idx: One-dimensional tensor of positions into the non-zeros of
            ``src``.  The order may be arbitrary.
        layout: Order in which ``idx`` enumerates the non-zeros, resolved
            via ``get_layout``.  If it resolves to ``'csc'``, ``idx`` refers
            to column-major positions.  Otherwise it refers to positions in
            ``src.coo()`` order.

    Returns:
        A new ``SparseTensor`` with the same sparse sizes as ``src``,
        holding only the selected entries.
    """
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        # Translate column-major positions into row-major ones.  With the
        # storage convention ``csc_arr = csr_arr[csr2csc]``, ``csr2csc[j]`` is
        # the row-major position of the j-th column-major entry.  Its inverse
        # ``csc2csr`` maps in the opposite direction and would pick the wrong
        # entries.
        idx = src.storage.csr2csc()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if value is not None:
        value = value[idx]

    # ``idx`` need not be sorted, and CSC-translated positions generally are
    # not.  The gathered entries are therefore not guaranteed to be
    # row-major, so do not claim ``is_sorted=True``.
    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=False)

rusty1s's avatar
rusty1s committed
97

rusty1s's avatar
rusty1s committed
98
99
100
101
# Expose the functions above as ``SparseTensor`` methods.  The lambdas look up
# the module-level functions at call time and forward ``self`` as ``src``.
SparseTensor.index_select = lambda self, dim, idx: index_select(self, dim, idx)
SparseTensor.index_select_nnz = (
    lambda self, idx, layout=None: index_select_nnz(self, idx, layout))