narrow.py 2.65 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import torch
rusty1s's avatar
rusty1s committed
2
3
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
rusty1s's avatar
rusty1s committed
4
5


rusty1s's avatar
rusty1s committed
6
@torch.jit.script
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    """Return `src` restricted to `length` indices along `dim`, beginning
    at index `start`.

    Args:
        src: The sparse tensor to slice.
        dim: Dimension to narrow. A negative value is offset by
            `src.dim()`.
        start: First index to keep along `dim`. A negative value is offset
            by `src.size(dim)`.
        length: Number of consecutive indices to keep along `dim`.

    Returns:
        A `SparseTensor` created through `src.from_storage` (for `dim` 0
        and 1) or through `src.set_value` (for any other `dim`).

    Raises:
        ValueError: If `dim` is neither 0 nor 1 and `src` holds no values.
    """
    # Resolve negative indices. `dim` has to be resolved first because the
    # resolution of `start` looks up the size of that dimension.
    if dim < 0:
        dim = src.dim() + dim
    if start < 0:
        start = src.size(dim) + start

    if dim == 0:
        rowptr, col, value = src.csr()

        # Keep the `length + 1` row pointers that belong to the selected
        # rows and re-base them so that the first pointer becomes zero.
        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        row_start = rowptr[0]
        rowptr = rowptr - row_start
        # After re-basing, the last pointer is the number of entries kept;
        # together with `row_start` it selects the slice of `col`/`value`.
        row_length = rowptr[-1]

        # Reuse the cached row indices only if they already exist, shifting
        # them so that the first kept row gets index zero.
        row = src.storage._row
        if row is not None:
            row = row.narrow(0, row_start, row_length) - start

        col = col.narrow(0, row_start, row_length)

        if value is not None:
            value = value.narrow(0, row_start, row_length)

        sparse_sizes = torch.Size([length, src.sparse_size(1)])

        # The cached per-row counts can be sliced directly.
        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        # Column-based caches and the CSR<->CSC permutations are not carried
        # over to the new storage.
        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()
        # Select the entries whose column lies in `[start, start + length)`.
        mask = (col >= start) & (col < start + length)

        row = row[mask]
        # Shift the columns so that the first kept column gets index zero.
        col = col[mask] - start

        if value is not None:
            value = value[mask]

        sparse_sizes = torch.Size([src.sparse_size(0), length])

        # Reuse the cached column pointers (re-based to start at zero) and
        # column counts only if they already exist.
        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        # Row-based caches and the CSR<->CSC permutations are not carried
        # over to the new storage.
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    else:
        # Any other dimension is narrowed on the value tensor itself, where
        # it is addressed as dimension `dim - 1`.
        value = src.storage.value()
        if value is not None:
            return src.set_value(value.narrow(dim - 1, start, length),
                                 layout='coo')
        else:
            raise ValueError(
                'Narrowing along a dimension other than 0 or 1 requires '
                'the SparseTensor to have values')

rusty1s's avatar
rusty1s committed
77

rusty1s's avatar
rusty1s committed
78
79
def _narrow_method(self, dim, start, length):
    """Forward a `SparseTensor.narrow(...)` method call to `narrow`."""
    return narrow(self, dim, start, length)


# Make `narrow` available as a method on `SparseTensor` instances.
SparseTensor.narrow = _narrow_method