index_select.py 2.57 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
import torch

from torch_sparse.storage import get_layout
import torch_sparse.arange_interleave_cpu as arange_interleave_cpu


rusty1s's avatar
typo  
rusty1s committed
7
def arange_interleave(start, repeat):
    """Concatenate the integer ranges ``[start[i], start[i] + repeat[i])``.

    Example: ``start=[5, 0, 10]`` and ``repeat=[2, 0, 3]`` give
    ``[5, 6, 10, 11, 12]``. ``index_select`` uses this to collect the
    positions of all non-zeros that belong to a set of rows/columns, given
    the pointer of each selected row/column and its entry count.

    Implemented with vectorized PyTorch ops only, so it works for CPU and
    CUDA tensors alike.

    Args:
        start (Tensor): 1-D integer tensor of range starts.
        repeat (LongTensor): 1-D tensor of (non-negative) range lengths, on
            the same device and with the same number of elements as
            ``start``.

    Returns:
        LongTensor: 1-D tensor with ``repeat.sum()`` elements.
    """
    assert start.device == repeat.device
    assert repeat.dtype == torch.long
    assert start.dim() == 1
    assert repeat.dim() == 1
    assert start.numel() == repeat.numel()

    # Exclusive prefix sum: output offset at which the i-th range begins.
    offsets = repeat.cumsum(0) - repeat
    total = int(repeat.sum())
    # For output position k inside range i: value = start[i] + (k - offsets[i]).
    shift = (start - offsets).repeat_interleave(repeat)
    return torch.arange(total, dtype=torch.long, device=start.device) + shift


def _index_select_rows(src, idx):
    """Build the storage of ``src`` restricted to the rows listed in ``idx``."""
    (_, old_col), value = src.coo()
    counts = src.storage.rowcount
    ptr = src.storage.rowptr

    counts = counts[idx]
    num_rows = counts.size(0)
    new_row = torch.arange(num_rows, device=counts.device)
    new_row = new_row.repeat_interleave(counts)
    # Positions, in src's entry order, of every non-zero lying in one of
    # the selected rows (row-by-row, following the order of ``idx``).
    perm = arange_interleave(ptr[idx], counts)
    index = torch.stack([new_row, old_col[perm]], dim=0)

    if src.has_value():
        value = value[perm]

    size = torch.Size([num_rows, src.sparse_size(1)])
    # New row ids ascend and each row keeps the column order it had in src
    # (assumed sorted), hence ``is_sorted=True``.
    return src.storage.__class__(
        index, value, size, rowcount=counts, is_sorted=True)


def _index_select_cols(src, idx):
    """Build the storage of ``src`` restricted to the columns listed in ``idx``."""
    colptr, old_row, value = src.csc()
    counts = src.storage.colcount

    counts = counts[idx]
    num_cols = counts.size(0)
    new_col = torch.arange(num_cols, device=old_row.device)
    new_col = new_col.repeat_interleave(counts)
    # Positions, in the column-major order returned by ``src.csc()``, of
    # every non-zero lying in one of the selected columns.
    perm = arange_interleave(colptr[idx], counts)
    new_row = old_row[perm]
    # The gathered entries are column-major; this permutation re-orders
    # them row-major (new_col < num_cols, so the key orders by (row, col)).
    csc2csr = (num_cols * new_row + new_col).argsort()
    index = torch.stack([new_row, new_col], dim=0)[:, csc2csr]

    if src.has_value():
        value = value[perm][csc2csr]

    size = torch.Size([src.sparse_size(0), num_cols])
    return src.storage.__class__(
        index,
        value,
        size,
        colcount=counts,
        csc2csr=csc2csr,
        is_sorted=True)


def index_select(src, dim, idx):
    """Select slices of the sparse tensor ``src`` along ``dim`` at ``idx``.

    ``dim == 0`` selects rows, ``dim == 1`` selects columns, and any larger
    dimension selects within the dense value tensor attached to each
    non-zero.

    Args:
        src: sparse tensor (project type) exposing ``coo()``, ``csc()``,
            ``storage``, ``has_value()``, ``sparse_size()`` and
            ``from_storage``.
        dim (int): dimension to select along; negative values count back
            from ``src.dim()``.
        idx (Tensor): 1-D index tensor; it is moved to ``src.device``.

    Returns:
        A new sparse tensor created through ``src.from_storage``.
    """
    if dim < 0:
        dim = src.dim() + dim

    assert idx.dim() == 1
    idx = idx.to(src.device)

    if dim == 0:
        storage = _index_select_rows(src, idx)
    elif dim == 1:
        storage = _index_select_cols(src, idx)
    else:
        # Beyond the two sparse dims we index the value tensor, whose first
        # axis enumerates non-zeros -- hence the shift by one.
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def index_select_nnz(src, idx, layout=None):
    """Keep only the non-zero entries of ``src`` selected by ``idx``.

    Args:
        src: sparse tensor (project type) providing ``coo()``, ``storage``,
            ``has_value()``, ``sparse_size()`` and ``from_storage``.
        idx (Tensor): 1-D selector over the non-zeros. Either a boolean
            (or legacy uint8) mask with one flag per non-zero, or integer
            positions of the non-zeros to keep (negative positions count
            from the end). Moved to ``src.device``.
        layout (str, optional): order in which ``idx`` enumerates the
            non-zeros. ``'csc'`` means column-major order; otherwise the
            order of ``src.coo()`` is assumed (presumably row-major / CSR
            -- resolved by ``get_layout``, confirm).

    Returns:
        A sparse tensor with the same sparse size as ``src`` that holds only
        the selected entries, kept in ``src.coo()`` order.
    """
    assert idx.dim() == 1
    idx = idx.to(src.device)
    is_mask = idx.dtype in (torch.bool, torch.uint8)

    if get_layout(layout) == 'csc':
        csc2csr = src.storage.csc2csr
        if is_mask:
            # A per-entry mask is itself per-entry data: permute it from
            # CSC into CSR order just like values are.
            idx = idx[csc2csr]
        else:
            # csr_array = csc_array[csc2csr], so a position counted in CSC
            # order maps to its CSR position through the inverse permutation.
            # (Indexing ``idx`` by csc2csr would only be valid for masks.)
            idx = csc2csr.argsort()[idx]

    index, value = src.coo()

    if not is_mask:
        # Normalise negative positions and sort, so the gathered entries stay
        # in src's (sorted) order and ``is_sorted=True`` below is truthful.
        idx = torch.where(idx < 0, idx + index.size(1), idx)
        idx = idx.sort()[0]

    index = index[:, idx]
    if src.has_value():
        value = value[idx]

    # There is no other information we can maintain...
    storage = src.storage.__class__(
        index, value, src.sparse_size(), is_sorted=True)

    return src.from_storage(storage)