narrow.py 1.33 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
import torch


def _num_dims(src):
    """Return the total number of dimensions of ``src``.

    A sparse matrix has two sparse dimensions. Any trailing dimensions come
    from its value tensor, whose leading dimension enumerates the non-zero
    entries (sparse dim ``d >= 2`` corresponds to value dim ``d - 1``).
    """
    if not src.has_value():
        return 2
    _, value = src.coo()
    return value.dim() + 1


def narrow(src, dim, start, length):
    """Return the slice of ``src`` covering ``[start, start + length)`` along ``dim``.

    Analogous to :meth:`torch.Tensor.narrow`, applied to a sparse matrix.
    Along a sparse dimension, only the non-zero entries whose index falls in
    the window are kept. Those entries are re-indexed so that ``start``
    becomes ``0``.

    Args:
        src: Sparse matrix to slice. It must expose ``coo()``, ``csr()``,
            ``has_value()``, ``sparse_size(dim)`` and a ``_storage``
            attribute whose class can be constructed from
            ``(index, value, sparse_size[, rowptr], is_sorted=...)`` and
            which offers ``apply_value(func)``.
        dim (int): Dimension to narrow. ``0`` slices rows, ``1`` slices
            columns, and ``dim >= 2`` narrows dimension ``dim - 1`` of the
            dense value tensor. Negative values count from the last
            dimension.
        start (int): First index to keep. Negative values count from the end
            of ``dim``.
        length (int): Number of indices to keep.

    Returns:
        A new object of the same class as ``src``, built with
        ``src.__class__.from_storage``.
    """
    if dim < 0:
        dim += _num_dims(src)

    if dim in (0, 1) and start < 0:
        # The sparse dimensions are sliced by hand below, so resolve a
        # negative start here, as torch.Tensor.narrow would.
        start += src.sparse_size(dim)

    if dim == 0:
        (row, col), value = src.coo()
        rowptr, _, _ = src.csr()

        # Entries of row i occupy positions rowptr[i]:rowptr[i + 1], so the
        # `length + 1` pointers starting at `start` delimit the kept rows.
        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        row_start = int(rowptr[0])
        rowptr = rowptr - row_start  # Re-base so the new rowptr starts at 0.
        row_length = int(rowptr[-1])  # Number of kept non-zero entries.

        row = row.narrow(0, row_start, row_length) - start
        col = col.narrow(0, row_start, row_length)
        index = torch.stack([row, col], dim=0)
        if src.has_value():
            value = value.narrow(0, row_start, row_length)
        sparse_size = torch.Size([length, src.sparse_size(1)])

        storage = src._storage.__class__(
            index, value, sparse_size, rowptr, is_sorted=True)

    elif dim == 1:
        # This is faster than accessing `csc()` in analogy to the `dim=0` case.
        (row, col), value = src.coo()
        mask = (col >= start) & (col < start + length)

        # Masking keeps the original entry order, and shifting every column
        # by the same offset does not reorder them, so sortedness is kept.
        index = torch.stack([row, col - start], dim=0)[:, mask]
        if src.has_value():
            value = value[mask]
        sparse_size = torch.Size([src.sparse_size(0), length])

        storage = src._storage.__class__(
            index, value, sparse_size, is_sorted=True)

    else:
        # Dense value dimensions: value dim 0 enumerates the non-zeros, so
        # sparse dim `dim` is value dim `dim - 1`.
        storage = src._storage.apply_value(lambda x: x.narrow(
            dim - 1, start, length))

    return src.__class__.from_storage(storage)