index_select.py 3.34 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Optional

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch_scatter import gather_csr
rusty1s's avatar
rusty1s committed
5
6
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
10
11
@torch.jit.script
def index_select(src: SparseTensor, dim: int,
                 idx: torch.Tensor) -> SparseTensor:
    r"""Select rows, columns, or dense value slices of ``src`` by ``idx``.

    Args:
        src: The sparse tensor to index.
        dim: Dimension to select along; negative values count from
            ``src.dim()``. ``0`` selects rows, ``1`` selects columns, and any
            larger dimension indexes the dense ``value`` tensor along
            ``dim - 1``.
        idx: One-dimensional index tensor. For ``dim`` 0/1 the ``i``-th output
            row/column is the ``idx[i]``-th input row/column, so ``idx`` may
            be unordered and may contain repeats.

    Returns:
        A new :class:`SparseTensor` created via ``src.from_storage`` (for
        ``dim`` 0/1) or ``src.set_value`` (for dense dimensions).

    Raises:
        ValueError: If a dense dimension (``dim >= 2``) is requested but
            ``src`` carries no values.
    """
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount()[idx]

        # Compressed row pointer of the result.
        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        # Output row id of every selected non-zero.
        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        # Position of each selected non-zero inside the original CSR arrays:
        # its offset within output row ``i`` plus the start of source row
        # ``idx[i]``. Without this shift ``perm`` would just pick the first
        # ``nnz`` entries, which is only right when ``idx`` is ``arange(k)``.
        perm = torch.arange(row.size(0), device=row.device)
        perm += (old_rowptr[idx] - rowptr[:-1]).repeat_interleave(rowcount)

        col = col[perm]

        if value is not None:
            value = value[perm]

        sparse_sizes = torch.Size([idx.size(0), src.sparse_size(1)])

        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        old_colptr, row, value = src.csc()
        colcount = src.storage.colcount()[idx]

        # Compressed column pointer of the result.
        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        # Output column id of every selected non-zero (in CSC order).
        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        # Same offset trick as for rows, applied to the CSC arrays.
        perm = torch.arange(col.size(0), device=col.device)
        perm += (old_colptr[idx] - colptr[:-1]).repeat_interleave(colcount)

        row = row[perm]

        # Reorder from column-major (CSC) to row-major (CSR) order. The key
        # ``row * num_cols + col`` is unique per (row, col) pair because
        # ``col < idx.size(0)``.
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            value = value[perm][csc2csr]

        sparse_sizes = torch.Size([src.sparse_size(0), idx.size(0)])

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dense dimensions live in ``value``; its dim 0 is the nnz axis, so
        # sparse-tensor dim ``d >= 2`` maps to value dim ``d - 1``.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError(
                'index_select: cannot select along a dense dimension of a '
                'SparseTensor that has no values')
rusty1s's avatar
rusty1s committed
82
83


rusty1s's avatar
rusty1s committed
84
85
86
@torch.jit.script
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    r"""Keep only the non-zero entries of ``src`` at positions ``idx``.

    Args:
        src: The sparse tensor to filter.
        idx: One-dimensional tensor of non-zero positions. Positions refer to
            the row-major (CSR/COO) ordering of the entries, unless
            ``layout`` resolves to ``'csc'``, in which case they refer to the
            column-major (CSC) ordering.
        layout: Optional layout name, resolved with ``get_layout``.

    Returns:
        A new :class:`SparseTensor` with the same sparse sizes as ``src``
        holding only the selected entries.
    """
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        # ``csc2csr`` permutes CSC-ordered *data* into CSR order
        # (``data_csr = data_csc[csc2csr]``). Translating CSC *positions*
        # into CSR positions needs the inverse permutation, ``csr2csc``.
        idx = src.storage.csr2csc()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if value is not None:
        value = value[idx]

    # ``idx`` may be in any order (and CSC positions mapped to CSR are
    # generally not ascending), so do not claim the entries are sorted.
    # NOTE(review): relies on SparseStorage sorting entries itself when
    # ``is_sorted=False`` — confirm against storage.py.
    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=False)

rusty1s's avatar
rusty1s committed
101

rusty1s's avatar
rusty1s committed
102
103
104
105
def _index_select_method(self: SparseTensor, dim: int,
                         idx: torch.Tensor) -> SparseTensor:
    """Method form of ``index_select``: ``src.index_select(dim, idx)``."""
    return index_select(self, dim, idx)


def _index_select_nnz_method(self: SparseTensor, idx: torch.Tensor,
                             layout: Optional[str] = None) -> SparseTensor:
    """Method form of ``index_select_nnz``: ``src.index_select_nnz(idx)``."""
    return index_select_nnz(self, idx, layout)


# Attach the scripted functional implementations as SparseTensor methods.
SparseTensor.index_select = _index_select_method
SparseTensor.index_select_nnz = _index_select_nnz_method