index_select.py 2.64 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch

from torch_sparse.storage import get_layout
import torch_sparse.arange_interleave_cpu as arange_interleave_cpu


def _arange_interleave_torch(start, repeat):
    """Device-agnostic ``__arange_interleave__`` composed from torch ops."""
    # Exclusive prefix sum of ``repeat``: output slot where segment i begins.
    offset = repeat.cumsum(0) - repeat
    # Spread ``start[i] - offset[i]`` over segment i, then add the global
    # output position so the k-th slot of segment i holds ``start[i] + k``.
    out = (start - offset).repeat_interleave(repeat)
    out += torch.arange(out.numel(), dtype=out.dtype, device=out.device)
    return out


def __arange_interleave__(start, repeat):
    """Concatenate ``arange(start[i], start[i] + repeat[i])`` over all ``i``.

    ``index_select`` uses this to gather the storage positions of the
    selected rows/columns (``start`` = pointer, ``repeat`` = count). CPU
    inputs are delegated to the compiled ``arange_interleave_cpu`` kernel,
    whose contract is assumed to be the one above. CUDA inputs, for which
    no kernel exists, use an equivalent composition of torch operations.

    Args:
        start: 1-D tensor of range starts.
        repeat: 1-D ``torch.long`` tensor of range lengths, same size and
            device as ``start``.

    Returns:
        1-D tensor holding all ranges back to back.
    """
    assert start.device == repeat.device
    assert repeat.dtype == torch.long
    assert start.dim() == 1
    assert repeat.dim() == 1
    assert start.numel() == repeat.numel()
    if start.is_cuda:
        # Previously unsupported (NotImplementedError); fall back to torch ops.
        return _arange_interleave_torch(start, repeat)
    return arange_interleave_cpu.arange_interleave(start, repeat)


def index_select(src, dim, idx):
    """Return the slices of ``src`` at positions ``idx`` along ``dim``.

    Args:
        src: Sparse tensor to select from.
        dim: Dimension to index. Negative values count from the end, as in
            :meth:`torch.Tensor.index_select`. ``0`` selects rows, ``1``
            selects columns, and any higher dimension indexes the value
            tensor along ``dim - 1``.
        idx: 1-D tensor of positions to keep. It is moved to ``src``'s
            device before use.

    Returns:
        A new sparse tensor created through ``src.from_storage``.

    Raises:
        IndexError: If ``dim`` is not in ``[-src.dim(), src.dim() - 1]``.
    """
    ndim = src.dim()
    # A negative dim counts from the end, so the rank is *added*
    # (e.g. -1 -> ndim - 1). Subtracting it pointed past the last dim.
    norm_dim = ndim + dim if dim < 0 else dim
    if not 0 <= norm_dim < ndim:
        raise IndexError('Dimension out of range (expected to be in range of '
                         '[{}, {}], but got {})'.format(-ndim, ndim - 1, dim))
    dim = norm_dim

    assert idx.dim() == 1
    idx = idx.to(src.device)

    if dim == 0:
        (_, col), value = src.coo()
        rowcount = src.storage.rowcount
        rowptr = src.storage.rowptr

        # Non-zero count of every selected row. The i-th selected row
        # becomes row i of the result.
        rowcount = rowcount[idx]
        tmp = torch.arange(rowcount.size(0), device=rowcount.device)
        row = tmp.repeat_interleave(rowcount)
        # Storage positions of the selected rows' entries, expected to be
        # the concatenated ranges rowptr[i] : rowptr[i] + rowcount[i].
        perm = __arange_interleave__(rowptr[idx], rowcount)
        col = col[perm]
        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value[perm]

        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])

        # New rows are numbered in order and each keeps the column order of
        # the source storage, hence ``is_sorted=True``.
        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        colptr, row, value = src.csc()
        colcount = src.storage.colcount

        # Non-zero count of every selected column. The i-th selected column
        # becomes column i of the result.
        colcount = colcount[idx]
        tmp = torch.arange(colcount.size(0), device=row.device)
        col = tmp.repeat_interleave(colcount)
        # Positions (in column-major order) of the selected columns' entries.
        perm = __arange_interleave__(colptr[idx], colcount)
        row = row[perm]
        # Entries are gathered column by column. Sorting by the row-major
        # key row * num_cols + col (unique because col < num_cols) yields
        # the permutation that brings them into row-major order.
        csc2csr = (colcount.size(0) * row + col).argsort()
        index = torch.stack([row, col], dim=0)[:, csc2csr]

        if src.has_value():
            value = value[perm][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])

        # ``csc2csr`` is handed to the storage as well, presumably so it is
        # not recomputed when column-major access is requested later.
        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, csc2csr=csc2csr,
                                        is_sorted=True)

    else:
        # Dense dimension: index the value tensor, whose axis ``dim - 1``
        # corresponds to sparse-tensor dimension ``dim``.
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def index_select_nnz(src, idx, layout=None):
    """Keep only the non-zero entries of ``src`` at the given positions.

    Args:
        src: Sparse tensor to select from.
        idx: 1-D tensor of non-zero positions, or a 1-D boolean mask over
            all non-zeros. The result is flagged as sorted, so integer
            positions should be ascending when they refer to row-major
            order.
        layout: Ordering ``idx`` refers to, resolved through ``get_layout``.
            When it resolves to ``'csc'``, ``idx`` addresses non-zeros in
            column-major order and is translated to the row-major order of
            the storage.

    Returns:
        A new sparse tensor with the same sparse size holding only the
        selected entries. No cached layout information is carried over.
    """
    assert idx.dim() == 1
    idx = idx.to(src.device)

    if get_layout(layout) == 'csc':
        # For each row-major position, csc2csr holds the column-major
        # position of the same entry (row_major = col_major[csc2csr]).
        csc2csr = src.storage.csc2csr
        if idx.dtype == torch.bool:
            # A mask is reordered element-wise into row-major order.
            idx = idx[csc2csr]
        else:
            # Integer positions need the inverse permutation (csr2csc):
            # column-major position p lives at row-major position csr2csc[p].
            # Indexing idx by csc2csr (as done for masks) is only valid for
            # length-nnz masks.
            csr2csc = torch.empty_like(csc2csr)
            csr2csc[csc2csr] = torch.arange(csc2csr.numel(),
                                            dtype=csc2csr.dtype,
                                            device=csc2csr.device)
            # Translated positions follow column-major order. Sort them so
            # the selected entries stay row-major, as is_sorted=True claims.
            idx, _ = csr2csc[idx].sort()

    index, value = src.coo()

    index = index[:, idx]
    if src.has_value():
        value = value[idx]

    # There is no other information we can maintain...
    storage = src.storage.__class__(index, value, src.sparse_size(),
                                    is_sorted=True)

    return src.from_storage(storage)