narrow.py 4.52 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
from typing import Tuple

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
5
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
6
7


rusty1s's avatar
rusty1s committed
8
@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Return ``src`` narrowed to ``length`` indices of ``dim`` from ``start``.

    Mirrors :meth:`torch.Tensor.narrow` for sparse tensors.

    Args:
        src: The sparse tensor to narrow.
        dim: Dimension to narrow. ``0`` selects a range of rows, ``1`` a
            range of columns, and ``dim >= 2`` narrows dimension ``dim - 1``
            of the stored value tensor. Negative values count from the end.
        start: First index kept along ``dim``. Negative values count from the
            end of that dimension.
        length: Number of indices kept along ``dim``.

    Returns:
        A new :class:`SparseTensor`, built via ``src.from_storage`` for
        ``dim`` 0/1 or via ``src.set_value(..., layout='coo')`` otherwise.

    Raises:
        IndexError: if ``dim == 1`` and ``start``/``length`` do not describe
            a range inside the column dimension.
        ValueError: if ``dim >= 2`` but ``src`` stores no values.
    """
    if dim < 0:
        dim = src.dim() + dim

    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        rowptr, col, value = src.csr()

        # `length` rows are delimited by `length + 1` row pointers.
        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        # The selected rows own the contiguous CSR slice
        # [row_start, row_start + row_length) of `row`/`col`/`value`.
        row_start = int(rowptr[0])
        rowptr = rowptr - row_start
        row_length = int(rowptr[-1])

        row = src.storage._row
        if row is not None:
            row = row.narrow(0, row_start, row_length) - start

        col = col.narrow(0, row_start, row_length)

        if value is not None:
            value = value.narrow(0, row_start, row_length)

        sparse_sizes = torch.Size([length, src.sparse_size(1)])

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        # Column-major metadata cannot be sliced by row, so it is not
        # carried over to the narrowed storage.
        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # Columns are selected by masking, which would silently accept an
        # out-of-range request (unlike the `rowptr.narrow` in the `dim=0`
        # case), so validate the range explicitly.
        if start < 0 or length < 0 or start + length > src.sparse_size(1):
            raise IndexError('narrow(): `start` and `length` are out of '
                             'range for dimension 1')

        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()
        mask = (col >= start) & (col < start + length)

        # Boolean masking keeps the remaining entries in their original
        # (row-major) order, hence `is_sorted=True` below.
        row = row[mask]
        col = col[mask] - start

        if value is not None:
            value = value[mask]

        sparse_sizes = torch.Size([src.sparse_size(0), length])

        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Dimensions >= 2 live in `value`, whose leading dimension indexes
        # the non-zeros (see the `dim=0` branch), hence `dim - 1`.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.narrow(dim - 1, start, length),
                                 layout='coo')
        else:
            raise ValueError('narrow(): narrowing a dimension >= 2 requires '
                             'a SparseTensor with values, but `src` has none')

rusty1s's avatar
rusty1s committed
82

rusty1s's avatar
rusty1s committed
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
@torch.jit.script
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    # Inverse of `cat_diag`: cut out the diagonal block whose top-left corner
    # is `start` and whose size is `length`. Only meaningful for
    # *diagonally stacked* sparse matrices. The code assumes the block's
    # non-zeros form one contiguous slice of the row-major arrays and of the
    # column-major metadata, beginning at the same offset.
    row_off, col_off = start
    num_rows, num_cols = length

    rowptr, col, value = src.csr()

    # Row pointers of the block, rebased so the block starts at offset 0.
    ptr = rowptr.narrow(0, start=row_off, length=num_rows + 1)
    nnz_off = ptr[0]
    ptr = ptr - nnz_off
    nnz = ptr[-1]

    row = src.storage._row
    if row is not None:
        row = row.narrow(0, nnz_off, nnz) - row_off

    col = col.narrow(0, nnz_off, nnz) - col_off

    if value is not None:
        value = value.narrow(0, nnz_off, nnz)

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, row_off, num_rows)

    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, col_off, num_cols + 1)
        # For a diagonal block the first column pointer equals `nnz_off`.
        colptr = colptr - colptr[0]

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, col_off, num_cols)

    # Both permutations are restricted to the block's slice and rebased.
    csr2csc = src.storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, nnz_off, nnz) - nnz_off

    csc2csr = src.storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, nnz_off, nnz) - nnz_off

    return src.from_storage(SparseStorage(
        row=row, rowptr=ptr, col=col, value=value, sparse_sizes=length,
        rowcount=rowcount, colptr=colptr, colcount=colcount,
        csr2csc=csr2csc, csc2csr=csc2csr, is_sorted=True))


rusty1s's avatar
rusty1s committed
135
136
def _sparse_tensor_narrow(self, dim, start, length):
    # Method form of `narrow`: `src.narrow(dim, start, length)`.
    return narrow(self, dim, start, length)


def _sparse_tensor_narrow_diag(self, start, length):
    # Method form of `__narrow_diag__`: `src.__narrow_diag__(start, length)`.
    return __narrow_diag__(self, start, length)


# Plain Python functions are used as the wrappers because they bind `self`
# when accessed through an instance.
SparseTensor.narrow = _sparse_tensor_narrow
SparseTensor.__narrow_diag__ = _sparse_tensor_narrow_diag