index_select.py 2.66 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
import torch

from torch_sparse.storage import get_layout


def _segment_gather(count, old_ptr, idx):
    """Gather whole compressed segments (rows or columns) selected by ``idx``.

    ``count`` holds per-segment entry counts (e.g. ``rowcount``) and
    ``old_ptr`` the matching compressed pointer (e.g. ``rowptr``), so segment
    ``i`` occupies positions ``old_ptr[i]:old_ptr[i] + count[i]`` of the
    compressed arrays.

    Returns:
        A tuple ``(count, group, ptr, perm)`` describing the selection:
        ``count`` is the entry count of each selected segment, ``group[k]`` is
        the new segment id of output entry ``k``, ``ptr`` is the new
        compressed pointer (length ``len(idx) + 1``) and ``perm[k]`` is the
        position of output entry ``k`` in the original compressed arrays.
    """
    count = count[idx]
    group = torch.arange(count.size(0), device=count.device)
    group = group.repeat_interleave(count)

    ptr = torch.cat([group.new_zeros(1), count.cumsum(0)], dim=0)

    # "arange interleave": a running counter over the output entries, shifted
    # per segment from the start of its new segment (ptr[:-1]) to the start
    # of the corresponding old segment (old_ptr[idx]).
    perm = torch.arange(group.size(0), device=group.device)
    perm += (old_ptr[idx] - ptr[:-1])[group]

    return count, group, ptr, perm


def index_select(src, dim, idx):
    """Select slices of a sparse tensor along ``dim`` at positions ``idx``.

    Args:
        src: The sparse tensor to select from. For ``dim == 0`` its storage
            must provide ``rowcount``/``rowptr``. For ``dim == 1`` it must
            provide ``colcount`` and ``src.csc()``.
        dim (int): Dimension to select along. Negative values count from the
            end. ``0`` selects rows and ``1`` selects columns. Any higher
            dimension is applied to the value tensor at ``dim - 1``
            (presumably its trailing dense dimensions).
        idx (Tensor): 1-D tensor of indices into ``dim``.

    Returns:
        A new sparse tensor built with ``src.from_storage``.

    Raises:
        IndexError: If ``dim`` is outside ``[-src.dim(), src.dim() - 1]``.
        AssertionError: If ``idx`` is not 1-D.
    """
    ndim = src.dim()
    # Reject out-of-range dims explicitly. Otherwise e.g. ``-(ndim + 1)``
    # normalizes to ``-1`` and silently indexes the nnz axis of the values.
    if not -ndim <= dim < ndim:
        raise IndexError(
            'Dimension out of range (expected to be in range of '
            '[{}, {}], but got {})'.format(-ndim, ndim - 1, dim))
    dim = dim + ndim if dim < 0 else dim

    assert idx.dim() == 1

    if dim == 0:
        (_, col), value = src.coo()
        rowcount, row, rowptr, perm = _segment_gather(
            src.storage.rowcount, src.storage.rowptr, idx)

        col = col[perm]
        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value[perm]

        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])

        # Rows are emitted in increasing new-row order, so the result is
        # flagged as sorted.
        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, rowptr=rowptr,
                                        is_sorted=True)

    elif dim == 1:
        old_colptr, row, value = src.csc()
        colcount, col, colptr, perm = _segment_gather(
            src.storage.colcount, old_colptr, idx)

        row = row[perm]

        # The gathered entries are grouped by (new) column. Sort them
        # row-major so the storage ends up in CSR order. The key is unique
        # because col < colcount.size(0). csc2csr[csr_pos] is the entry's
        # position in the column-major order.
        csc2csr = (colcount.size(0) * row + col).argsort()
        index = torch.stack([row, col], dim=0)[:, csc2csr]

        if src.has_value():
            value = value[perm][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])

        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, colptr=colptr,
                                        csc2csr=csc2csr, is_sorted=True)

    else:
        # Dense dimension: the two sparse dimensions share value dim 0, so
        # tensor dim ``dim`` is value dim ``dim - 1``.
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def index_select_nnz(src, idx, layout=None):
    """Keep only the non-zero entries of ``src`` addressed by ``idx``.

    Args:
        src: The sparse tensor to select from.
        idx (Tensor): 1-D tensor addressing non-zero entries. It is either
            integer positions or a boolean mask over all non-zeros.
        layout: Ordering in which ``idx`` addresses the non-zeros, resolved
            by ``get_layout``. When it resolves to ``'csc'``, ``idx`` refers
            to column-major (CSC) order and is translated to the row-major
            order returned by ``src.coo()``.

    Returns:
        A new sparse tensor with the same sparse size, built with
        ``src.from_storage``.

    Raises:
        AssertionError: If ``idx`` is not 1-D.
    """
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        csc2csr = src.storage.csc2csr  # csc2csr[csr_pos] == csc_pos
        if idx.dtype == torch.bool:
            # A mask over CSC order: reorder its entries into CSR order.
            idx = idx[csc2csr]
        else:
            # Integer CSC positions: map each one to its CSR position via
            # the inverse permutation. Indexing ``idx`` by ``csc2csr``
            # (mask-style) would be wrong here.
            idx = csc2csr.argsort()[idx]

    index, value = src.coo()

    index = index[:, idx]
    if src.has_value():
        value = value[idx]

    # There is no other information we can maintain...
    # NOTE(review): ``is_sorted=True`` only holds when the selected positions
    # are in increasing CSR order (e.g. a mask, or sorted CSR indices).
    storage = src.storage.__class__(index, value, src.sparse_size(),
                                    is_sorted=True)

    return src.from_storage(storage)