import torch
from torch_scatter import gather_csr

from torch_sparse.storage import get_layout


def index_select(src, dim, idx):
    """Select rows, columns or dense value slices of ``src`` by ``idx``.

    Args:
        src: The sparse tensor to index (presumably a torch_sparse
            ``SparseTensor``; it must provide ``csr()``, ``csc()``,
            ``storage``, ``has_value()``, ``sparse_size()`` and
            ``from_storage``).
        dim: Dimension to index. ``0`` selects rows, ``1`` selects columns,
            and any larger dimension indexes the dense trailing dimensions of
            the value tensor (value dimension ``dim - 1``). Negative values
            are counted from ``src.dim()``.
        idx: One-dimensional index tensor. Entries may repeat and may be
            negative (counted from the end of the indexed dimension).

    Returns:
        A new sparse tensor built via ``src.from_storage``. For ``dim`` 0 or 1
        the indexed dimension has size ``idx.size(0)`` and its i-th slice is
        slice ``idx[i]`` of ``src``.

    Raises:
        AssertionError: If ``idx`` is not one-dimensional.
    """
    dim = src.dim() + dim if dim < 0 else dim
    assert idx.dim() == 1

    if dim == 0:
        old_rowptr, col, value = src.csr()
        rowcount = src.storage.rowcount[idx]

        # Row pointer of the result: prefix sum of the selected row lengths.
        rowptr = col.new_zeros(idx.size(0) + 1)
        torch.cumsum(rowcount, dim=0, out=rowptr[1:])

        # New row index of every output non-zero.
        row = torch.arange(idx.size(0),
                           device=col.device).repeat_interleave(rowcount)

        # perm[k] = position of output non-zero k in the original CSR arrays:
        # its offset inside the new row plus the start of the selected old
        # row (gather_csr broadcasts each per-row shift over that row).
        # ``old_rowptr[:-1]`` holds only row starts, so negative indices pick
        # the same rows as ``rowcount[idx]`` (``old_rowptr[-1]`` is the nnz).
        perm = torch.arange(row.size(0), device=row.device)
        perm += gather_csr(old_rowptr[:-1][idx] - rowptr[:-1], rowptr)

        col = col[perm]
        if src.has_value():
            value = value[perm]

        sparse_size = torch.Size([idx.size(0), src.sparse_size(1)])

        # Output rows are emitted in order and each keeps its original CSR
        # column order (assumed sorted by the source storage), hence sorted.
        storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
                                        value=value, sparse_size=sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        old_colptr, row, value = src.csc()
        colcount = src.storage.colcount[idx]

        # New column index of every output non-zero (still in CSC order).
        col = torch.arange(idx.size(0),
                           device=row.device).repeat_interleave(colcount)

        colptr = row.new_zeros(idx.size(0) + 1)
        torch.cumsum(colcount, dim=0, out=colptr[1:])

        # Same construction as for rows, over the CSC arrays; see the dim == 0
        # branch for why the trailing pointer entry is dropped.
        perm = torch.arange(col.size(0), device=col.device)
        perm += gather_csr(old_colptr[:-1][idx] - colptr[:-1], colptr)

        row = row[perm]

        # Re-order the gathered entries from column-major to row-major order
        # (key = row * num_cols + col). ``csc2csr`` is handed to the storage
        # together with colptr/colcount, presumably so the CSC view need not
        # be recomputed.
        csc2csr = (idx.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if src.has_value():
            value = value[perm][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), idx.size(0)])

        storage = src.storage.__class__(row=row, col=col, value=value,
                                        sparse_size=sparse_size, colptr=colptr,
                                        colcount=colcount, csc2csr=csc2csr,
                                        is_sorted=True)

    else:
        # Dense trailing dimension: value has a leading nnz axis, so sparse
        # tensor dimension ``dim`` is value dimension ``dim - 1``.
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def index_select_nnz(src, idx, layout=None):
    """Keep only a selection of the stored non-zero entries of ``src``.

    Args:
        src: The sparse tensor to index (presumably a torch_sparse
            ``SparseTensor``).
        idx: One-dimensional tensor addressing non-zero entries. Integer
            entries are positions in the ``layout`` order; a boolean tensor is
            treated as a mask over all non-zeros in that order.
        layout: Order in which ``idx`` addresses the non-zeros, resolved via
            ``get_layout``. ``'csc'`` means column-major order; otherwise the
            order of ``src.coo()`` is used.

    Returns:
        A new sparse tensor with the same sparse size as ``src`` holding only
        the selected entries (and their values, if ``src`` has values).

    Raises:
        AssertionError: If ``idx`` is not one-dimensional.
    """
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        # ``csc2csr`` converts CSC-ordered data to CSR order
        # (csr[p] == csc[csc2csr[p]]), as ``index_select`` builds it.
        csc2csr = src.storage.csc2csr
        if idx.dtype == torch.bool:
            # A mask in CSC order is re-ordered element-wise into CSR order.
            idx = idx[csc2csr]
        else:
            # Positions in CSC order must be *mapped*: CSC position j lives
            # at CSR position inverse(csc2csr)[j]. Indexing ``idx`` itself
            # (the old behaviour) fails whenever len(idx) < nnz.
            idx = csc2csr.argsort()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if src.has_value():
        value = value[idx]

    # An arbitrary or CSC-ordered selection need not be row-major sorted, so
    # only claim sortedness when the row-major key is non-decreasing.
    key = row * src.sparse_size(1) + col
    is_sorted = bool((key[1:] >= key[:-1]).all())

    # There is no other information (rowptr, counts, ...) we can maintain.
    storage = src.storage.__class__(row=row, col=col, value=value,
                                    sparse_size=src.sparse_size(),
                                    is_sorted=is_sorted)

    return src.from_storage(storage)