# index_select.py: index selection for sparse tensors along sparse/dense
# dimensions and over the list of non-zero entries.
import torch

from torch_sparse.storage import get_layout
import torch_sparse.arange_interleave_cpu as arange_interleave_cpu


# Low-level helper used by index_select below.
def arange_interleave(start, repeat):
    """Concatenate the integer ranges ``[start[i], start[i] + repeat[i])``.

    For ``start = [3, 0, 10]`` and ``repeat = [2, 0, 3]`` the expected result
    is ``[3, 4, 10, 11, 12]``. CPU tensors are handed to the compiled
    ``arange_interleave_cpu`` extension. CUDA tensors, for which no kernel is
    available, use an equivalent pure-PyTorch implementation.

    Args:
        start (torch.Tensor): 1-D tensor of range starts.
        repeat (torch.Tensor): 1-D ``torch.long`` tensor of range lengths,
            same length and device as ``start``.

    Returns:
        torch.Tensor: 1-D tensor of length ``repeat.sum()``.

    Raises:
        AssertionError: if the inputs live on different devices, ``repeat``
            is not ``torch.long``, either input is not 1-D, or their lengths
            differ.
    """
    assert start.device == repeat.device
    assert repeat.dtype == torch.long
    assert start.dim() == 1
    assert repeat.dim() == 1
    assert start.numel() == repeat.numel()
    if start.is_cuda:
        # No dedicated CUDA kernel exists; compose the result from
        # device-agnostic tensor ops instead of refusing the input.
        return _arange_interleave_torch(start, repeat)
    return arange_interleave_cpu.arange_interleave(start, repeat)


def _arange_interleave_torch(start, repeat):
    """Device-agnostic implementation of :func:`arange_interleave`."""
    total = int(repeat.sum())
    # Exclusive prefix sum: index in the output where segment i begins.
    seg_begin = repeat.cumsum(0) - repeat
    pos = torch.arange(total, device=start.device)
    # Offset of every output element inside its own segment (0, 1, 2, ...).
    offset = pos - seg_begin.repeat_interleave(repeat)
    return start.repeat_interleave(repeat) + offset


def index_select(src, dim, idx):
    """Select slices of the sparse tensor ``src`` along dimension ``dim``.

    * ``dim == 0`` keeps the rows listed in ``idx``, in that order (new row
      ``i`` is old row ``idx[i]``). It uses the CSR row pointers/counts.
    * ``dim == 1`` keeps the columns listed in ``idx`` (new column ``j`` is
      old column ``idx[j]``). It uses the CSC view, then re-sorts the result
      into row-major order.
    * ``dim >= 2`` indexes the dense trailing dimensions of the non-zero
      values. The sparsity pattern is unchanged.

    Args:
        src: Sparse tensor (project type) exposing ``dim()``, ``coo()``,
            ``csc()``, ``has_value()``, ``sparse_size()``, ``storage`` and
            ``from_storage``.
        dim (int): Dimension to select along. Negative values count from the
            end, as in ``torch.index_select``.
        idx (torch.Tensor): 1-D index tensor. It is moved to ``src``'s device.

    Returns:
        A new sparse tensor created through ``src.from_storage``.

    Raises:
        AssertionError: if ``idx`` is not 1-D.
    """
    # Negative dims count from the end: -1 -> src.dim() - 1.
    dim = src.dim() + dim if dim < 0 else dim

    assert idx.dim() == 1
    idx = idx.to(src.device)

    if dim == 0:
        (_, col), value = src.coo()
        rowcount = src.storage.rowcount
        rowptr = src.storage.rowptr

        # New row i is old row idx[i] and keeps all of its entries.
        rowcount = rowcount[idx]
        tmp = torch.arange(rowcount.size(0), device=rowcount.device)
        row = tmp.repeat_interleave(rowcount)
        # Positions of the kept entries: for every selected row, the range
        # rowptr[idx[i]] .. rowptr[idx[i]] + rowcount[i], concatenated.
        perm = arange_interleave(rowptr[idx], rowcount)

        col = col[perm]
        index = torch.stack([row, col], dim=0)

        if src.has_value():
            value = value[perm]

        sparse_size = torch.Size([rowcount.size(0), src.sparse_size(1)])

        # New rows are emitted in increasing order and each row keeps the
        # column order it had in ``src`` (presumed sorted), so is_sorted=True.
        storage = src.storage.__class__(index, value, sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        colptr, row, value = src.csc()
        colcount = src.storage.colcount

        # New column j is old column idx[j] and keeps all of its entries.
        colcount = colcount[idx]
        tmp = torch.arange(colcount.size(0), device=row.device)
        col = tmp.repeat_interleave(colcount)
        # CSC positions of the kept entries, column by column.
        perm = arange_interleave(colptr[idx], colcount)

        row = row[perm]
        # The gathered entries are column-major. Sorting by the row-major
        # key ``row * num_new_cols + col`` gives the CSR order. The same
        # permutation is kept as ``csc2csr``.
        csc2csr = (colcount.size(0) * row + col).argsort()
        index = torch.stack([row, col], dim=0)[:, csc2csr]

        if src.has_value():
            value = value[perm][csc2csr]

        sparse_size = torch.Size([src.sparse_size(0), colcount.size(0)])

        storage = src.storage.__class__(index, value, sparse_size,
                                        colcount=colcount, csc2csr=csc2csr,
                                        is_sorted=True)

    else:
        # Dense trailing dims live in ``value``, whose dim 0 enumerates the
        # non-zeros. Sparse-tensor dim ``d`` is therefore value dim ``d - 1``.
        storage = src.storage.apply_value(
            lambda x: x.index_select(dim - 1, idx))

    return src.from_storage(storage)


def index_select_nnz(src, idx, layout=None):
    """Keep only the non-zero entries of ``src`` at the positions ``idx``.

    Args:
        src: Sparse tensor (project type) exposing ``coo()``, ``has_value()``,
            ``sparse_size()``, ``storage`` and ``from_storage``.
        idx (torch.Tensor): 1-D tensor of positions into the list of
            non-zero entries.
        layout: Ordering that ``idx`` refers to, resolved via ``get_layout``.
            For ``'csc'``, positions refer to column-major order and are
            translated to the row-major (COO/CSR) order used internally.

    Returns:
        A new sparse tensor with the same sparse size, created through
        ``src.from_storage``.

    Raises:
        AssertionError: if ``idx`` is not 1-D.
    """
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        # ``csc2csr[k]`` is the CSC position of the k-th CSR entry. Its
        # inverse maps CSC positions to CSR positions. Indexing ``idx`` by
        # ``csc2csr`` would instead permute ``idx`` itself, which is only
        # meaningful for a full-length mask.
        csr2csc = _invert_permutation(src.storage.csc2csr)
        idx = csr2csc[idx]

    index, value = src.coo()

    index = index[:, idx]
    if src.has_value():
        value = value[idx]

    # An arbitrary ``idx`` need not preserve row-major order, so only claim
    # sortedness when the selected entries really are sorted.
    key = index[0] * src.sparse_size(1) + index[1]
    is_sorted = bool((key[1:] >= key[:-1]).all())

    # There is no other information we can maintain...
    storage = src.storage.__class__(index, value, src.sparse_size(),
                                    is_sorted=is_sorted)

    return src.from_storage(storage)


def _invert_permutation(perm):
    """Return ``inv`` such that ``inv[perm[i]] == i`` for a 1-D permutation."""
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.numel(), dtype=perm.dtype,
                             device=perm.device)
    return inv