narrow.py 2.22 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch


def narrow(src, dim, start, length):
    """Return the slice ``[start, start + length)`` of `src` along `dim`.

    Negative `dim` and `start` are wrapped around Python-style
    (``dim += src.dim()`` and ``start += src.size(dim)``).

    * ``dim == 0`` narrows the rows by slicing the CSR representation.
    * ``dim == 1`` narrows the columns by masking the COO representation.
    * Any other ``dim`` narrows only the stored values, along value
      dimension ``dim - 1`` (presumably value dimension 0 indexes the
      non-zero entries -- TODO confirm against the storage class).

    Args:
        src: The sparse tensor to narrow. It must provide ``dim()``,
            ``size()``, ``csr()``, ``coo()``, ``has_value()``,
            ``sparse_size()``, ``storage`` and ``from_storage()``, as
            used below.
        dim (int): The dimension to narrow.
        start (int): First index of the slice; may be negative.
        length (int): Number of indices to keep.

    Returns:
        A new sparse tensor created with ``src.from_storage(...)``.
    """
    dim = src.dim() + dim if dim < 0 else dim
    start = src.size(dim) + start if start < 0 else start

    if dim == 0:
        rowptr, col, value = src.csr()

        # Maintain `rowcount`...
        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        # `length` rows need `length + 1` row pointers. The offsets are
        # converted to Python ints so that the `narrow` calls below receive
        # plain integers rather than 0-dim tensors.
        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        row_start = int(rowptr[0])
        rowptr = rowptr - row_start
        row_length = int(rowptr[-1])

        # The selected rows occupy the contiguous entry range
        # [row_start, row_start + row_length), so one slice of each entry
        # array picks them out; row indices are shifted to start at zero.
        row = src.storage._row
        if row is not None:
            row = row.narrow(0, row_start, row_length) - start
        col = col.narrow(0, row_start, row_length)

        if src.has_value():
            value = value.narrow(0, row_start, row_length)

        sparse_size = torch.Size([length, src.sparse_size(1)])

        # A contiguous slice keeps the relative order of the entries.
        storage = src.storage.__class__(row=row, rowptr=rowptr, col=col,
                                        value=value, sparse_size=sparse_size,
                                        rowcount=rowcount, is_sorted=True)

    elif dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()
        mask = (col >= start) & (col < start + length)

        # Boolean masking keeps surviving entries in their original order;
        # column indices are shifted to start at zero.
        row, col = row[mask], col[mask] - start

        # Maintain `colcount`...
        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        # Maintain `colptr`...
        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        if src.has_value():
            value = value[mask]

        sparse_size = torch.Size([src.sparse_size(0), length])

        # Row-based caches (`rowptr`, `rowcount`) are not forwarded here;
        # presumably the storage class recomputes them on demand -- confirm.
        storage = src.storage.__class__(row=row, col=col, value=value,
                                        sparse_size=sparse_size, colptr=colptr,
                                        colcount=colcount, is_sorted=True)

    else:
        # Dense dimensions live in the value tensor, offset by one.
        storage = src.storage.apply_value(
            lambda x: x.narrow(dim - 1, start, length))

    return src.from_storage(storage)