narrow.py 4.61 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import copy
rusty1s's avatar
rusty1s committed
2
3
from typing import Tuple

rusty1s's avatar
rusty1s committed
4
import torch
rusty1s's avatar
rusty1s committed
5
6
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
7
8


rusty1s's avatar
rusty1s committed
9
@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Restrict :obj:`src` to :obj:`length` consecutive indices of dimension
    :obj:`dim`, beginning at index :obj:`start`.

    Args:
        src: The sparse tensor to slice.
        dim: The dimension to narrow. A negative value is shifted by
            ``src.dim()``.
        start: The first index to keep. A negative value is shifted by
            ``src.size(dim)``.
        length: The number of indices to keep.

    Returns:
        For ``dim`` 0 or 1, a sparse tensor created via ``src.from_storage``
        from the sliced storage. For any other ``dim``, the result of
        ``src.set_value`` applied to the narrowed value tensor.

    Raises:
        ValueError: If ``dim`` is neither 0 nor 1 and :obj:`src` has no value
            tensor.
    """
    if dim < 0:
        dim = dim + src.dim()

    if start < 0:
        start = start + src.size(dim)

    if dim == 0:
        rowptr, col, value = src.csr()

        # `length` rows are delimited by `length + 1` pointer entries.
        rowptr = rowptr.narrow(0, start, length + 1)
        nnz_offset = rowptr[0]
        # Re-base the pointers so that the first kept row starts at zero.
        rowptr = rowptr - nnz_offset
        nnz = rowptr[-1]

        col = col.narrow(0, nnz_offset, nnz)
        if value is not None:
            value = value.narrow(0, nnz_offset, nnz)

        # `_row` may be `None`; slice and re-base it only when it is present.
        row = src.storage._row
        if row is not None:
            row = row.narrow(0, nnz_offset, nnz) - start

        # Same for `_rowcount`, which is indexed per row rather than per entry.
        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start, length)

        sizes = torch.Size([length, src.sparse_size(1)])

        # Column-oriented fields are not carried over to the new storage.
        return src.from_storage(
            SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                          sparse_sizes=sizes, rowcount=rowcount, colptr=None,
                          colcount=None, csr2csc=None, csc2csr=None,
                          is_sorted=True))

    elif dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()

        # Select the entries whose column lies in `[start, start + length)`.
        stop = start + length
        keep = (col >= start) & (col < stop)

        row = row[keep]
        col = col[keep] - start
        if value is not None:
            value = value[keep]

        # `_colptr` / `_colcount` may be `None`; slice them only when present.
        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start, length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start, length)

        sizes = torch.Size([src.sparse_size(0), length])

        # Row-oriented pointer/count fields are not carried over here.
        return src.from_storage(
            SparseStorage(row=row, rowptr=None, col=col, value=value,
                          sparse_sizes=sizes, rowcount=None, colptr=colptr,
                          colcount=colcount, csr2csc=None, csc2csr=None,
                          is_sorted=True))

    else:
        # Any other dimension is narrowed on the value tensor, at `dim - 1`.
        value = src.storage.value()
        if value is None:
            raise ValueError
        else:
            return src.set_value(value.narrow(dim - 1, start, length),
                                 layout='coo')

rusty1s's avatar
rusty1s committed
83

rusty1s's avatar
rusty1s committed
84
85
86
87
88
@torch.jit.script
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    """Cut one block out of a diagonally stacked sparse matrix.

    This function builds the inverse operation of `cat_diag` and should hence
    only be used on *diagonally stacked* sparse matrices. That's the reason
    why this method is marked as *private*.

    Args:
        src: The diagonally stacked sparse matrix.
        start: The ``(row, column)`` index at which the block begins.
        length: The ``(rows, columns)`` extent of the block. It is passed on
            unchanged as the sparse sizes of the result.

    Returns:
        A sparse tensor created via ``src.from_storage`` from the sliced
        storage.
    """
    first_row, first_col = start[0], start[1]
    num_rows, num_cols = length[0], length[1]

    rowptr, col, value = src.csr()

    # `num_rows` rows are delimited by `num_rows + 1` pointer entries.
    rowptr = rowptr.narrow(0, first_row, num_rows + 1)
    nnz_offset = int(rowptr[0])
    # Re-base the pointers so that the first kept row starts at zero.
    rowptr = rowptr - nnz_offset
    nnz = int(rowptr[-1])

    col = col.narrow(0, nnz_offset, nnz) - first_col
    if value is not None:
        value = value.narrow(0, nnz_offset, nnz)

    # The remaining storage fields may be `None`; each one is sliced only
    # when it is present.
    row = src.storage._row
    if row is not None:
        row = row.narrow(0, nnz_offset, nnz) - first_row

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, first_row, num_rows)

    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, first_col, num_cols + 1)
        colptr = colptr - int(colptr[0])  # i.e. `nnz_offset`

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, first_col, num_cols)

    # Both permutations are sliced over the same entry range and re-based.
    csr2csc = src.storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, nnz_offset, nnz) - nnz_offset

    csc2csr = src.storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, nnz_offset, nnz) - nnz_offset

    return src.from_storage(
        SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                      sparse_sizes=length, rowcount=rowcount, colptr=colptr,
                      colcount=colcount, csr2csc=csr2csc, csc2csr=csc2csr,
                      is_sorted=True))


rusty1s's avatar
rusty1s committed
137
138
# Attach the module-level `narrow` function to `SparseTensor` so that it can
# be called as `tensor.narrow(dim, start, length)`.
SparseTensor.narrow = lambda self, dim, start, length: narrow(
    self, dim, start, length)
rusty1s's avatar
rusty1s committed
139
140
# Attach the module-level `__narrow_diag__` function to `SparseTensor` so that
# it can be called as `tensor.__narrow_diag__(start, length)`.
SparseTensor.__narrow_diag__ = lambda self, start, length: __narrow_diag__(
    self, start, length)