index_select.py 3.31 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch_scatter import gather_csr
rusty1s's avatar
rusty1s committed
5
6
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
10
11
@torch.jit.script
def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    """Select rows, columns or value slices of ``src`` along ``dim``.

    Args:
        src: The sparse tensor to index.
        dim: Dimension to select along. ``0`` selects rows, ``1`` selects
            columns, and any larger value selects along dimension
            ``dim - 1`` of the stored ``value`` tensor. Negative values are
            counted from ``src.dim()``.
        idx: One-dimensional index tensor. For the row/column cases,
            negative entries are wrapped around (``-1`` is the last
            row/column).

    Returns:
        A new ``SparseTensor`` whose selected dimension has size
        ``idx.size(0)``.

    Raises:
        ValueError: If ``dim >= 2`` but ``src`` stores no values.
    """
    dim = src.dim() + dim if dim < 0 else dim

    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount()

        # ``old_rowptr`` holds ``num_rows + 1`` offsets, so indexing it with
        # a negative value would not address the matching row start (while
        # ``rowcount[idx]`` would). Wrap negative indices up front.
        idx = torch.where(idx < 0, idx + src.sparse_size(0), idx)

        rowcount = rowcount[idx]

        # New row pointer: exclusive prefix sum of the selected row counts.
        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        # perm[k] = position of the k-th result entry in the original CSR
        # arrays: its running index plus (old row start - new row start),
        # broadcast to every entry of the row by ``gather_csr``.
        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)

        col = col[perm]

        if value is not None:
            value = value[perm]

        sparse_sizes = torch.Size([idx.size(0), src.sparse_size(1)])

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        old_colptr, row, value = src.csc()
        colcount = src.storage.colcount()

        # Same reasoning as for rows: ``old_colptr`` has one extra entry,
        # so negative indices must be wrapped before indexing it.
        idx = torch.where(idx < 0, idx + src.sparse_size(1), idx)

        colcount = colcount[idx]

        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        # Positions of the selected entries in the original CSC arrays.
        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)

        row = row[perm]

        # The gathered entries are in column-major order. ``idx.size(0)`` is
        # the new number of columns, so ``num_cols * row + col`` is a
        # row-major key; sorting by it yields the CSR ordering.
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            # Single gather, equivalent to ``value[perm][csc2csr]``.
            value = value[perm[csc2csr]]

        sparse_sizes = torch.Size([src.sparse_size(0), idx.size(0)])

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dimensions >= 2 live in the value tensor, whose dim 0 is the
        # non-zero axis; hence the ``dim - 1`` shift.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError(
                'index_select: cannot select along a value dimension of a '
                'sparse tensor that has no values')
rusty1s's avatar
rusty1s committed
80
81


rusty1s's avatar
rusty1s committed
82
83
84
@torch.jit.script
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    """Keep only the non-zero entries of ``src`` selected by ``idx``.

    Args:
        src: The sparse tensor whose non-zero entries are selected.
        idx: One-dimensional tensor of positions into the non-zero entries.
            The positions refer to row-major (CSR) order, unless ``layout``
            resolves to ``'csc'`` via ``get_layout``, in which case they
            refer to column-major (CSC) order.
        layout: Optional layout name passed to ``get_layout``.

    Returns:
        A new ``SparseTensor`` with the same sparse sizes as ``src`` that
        contains only the selected entries.
    """
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        # ``csc2csr`` gathers CSC-ordered data into CSR order
        # (``x_csr = x_csc[csc2csr]``), i.e. ``csc2csr[j]`` is the CSC
        # position of CSR entry ``j``. Translating CSC positions into CSR
        # positions needs the inverse permutation, ``csr2csc``.
        idx = src.storage.csr2csc()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if value is not None:
        value = value[idx]

    # An arbitrary ``idx`` (and in general any CSC-ordered one) does not
    # produce row-major ordered entries, so the result must not be flagged
    # as sorted.
    # NOTE(review): assumes SparseStorage restores the order when
    # ``is_sorted=False``; confirm against the storage constructor.
    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=False)

rusty1s's avatar
rusty1s committed
99

rusty1s's avatar
rusty1s committed
100
101
102
103
def _index_select_method(self, dim, idx):
    # Method form: ``sparse_tensor.index_select(dim, idx)``.
    return index_select(self, dim, idx)


def tmp(self, idx, layout=None):
    # Method form: ``sparse_tensor.index_select_nnz(idx, layout=None)``.
    return index_select_nnz(self, idx, layout)


# Attach the functional implementations above as SparseTensor methods.
SparseTensor.index_select = _index_select_method
SparseTensor.index_select_nnz = tmp