# narrow.py (764 bytes), last committed by rusty1s.
import torch
from torch_sparse.tensor import SparseTensor


def narrow(src, dim, start, length):
    """Narrow the sparse matrix ``src`` along ``dim``, like ``Tensor.narrow``.

    Keeps the rows (``dim=0``) or columns (``dim=1``) whose index lies in the
    half-open range ``[start, start + length)``.

    Args:
        src: Sparse matrix exposing a ``csr()`` method whose result is
            unpacked as ``(col, rowptr, value)``. NOTE(review): this order is
            taken from the original draft -- confirm against
            ``SparseTensor.csr``.
        dim: Dimension to narrow: ``0``/``-2`` for rows, ``1``/``-1`` for
            columns.
        start: First row/column index to keep; must be non-negative.
        length: Number of rows/columns to keep; must be non-negative.

    Returns:
        ``(col, rowptr, value)`` describing the narrowed matrix in CSR form,
        in the same order ``src.csr()`` is unpacked. Column indices are
        shifted so the kept range starts at 0 when ``dim`` selects columns.
        ``value`` stays ``None`` if ``src.csr()`` returned ``None`` for it.
        TODO: wrap the result back into a ``SparseTensor`` once a CSR-based
        constructor is available here.

    Raises:
        ValueError: If ``dim`` is not a matrix dimension, or ``start`` or
            ``length`` is negative.
        IndexError: If narrowing rows and ``start + length`` exceeds the
            number of rows. (The column count cannot be derived from CSR
            data alone, so column ranges are not bounds-checked.)
    """
    axis = dim + 2 if dim < 0 else dim
    if axis not in (0, 1):
        raise ValueError(f'dim must be in [-2, 1], got {dim}')
    if start < 0 or length < 0:
        raise ValueError(
            f'start and length must be non-negative, got start={start}, '
            f'length={length}')

    col, rowptr, value = src.csr()
    if axis == 0:
        return _narrow_rows(col, rowptr, value, start, length)
    return _narrow_cols(col, rowptr, value, start, length)


def _narrow_rows(col, rowptr, value, start, length):
    """Keep CSR rows ``[start, start + length)``; return ``(col, rowptr, value)``."""
    num_rows = rowptr.numel() - 1
    if start + length > num_rows:
        raise IndexError(
            f'row range [{start}, {start + length}) is out of bounds for '
            f'{num_rows} rows')

    # ``length`` rows are delimited by ``length + 1`` pointer entries.
    rowptr = rowptr.narrow(0, start, length + 1)
    nnz_start = int(rowptr[0])
    nnz_length = int(rowptr[-1]) - nnz_start

    # Entries of consecutive CSR rows are contiguous, so one narrow selects
    # them all; the returned col/value are views of the input storage.
    col = col.narrow(0, nnz_start, nnz_length)
    if value is not None:
        value = value.narrow(0, nnz_start, nnz_length)

    # Re-base the pointers so the first kept row starts at offset 0.
    return col, rowptr - nnz_start, value


def _narrow_cols(col, rowptr, value, start, length):
    """Keep CSR columns ``[start, start + length)``; return ``(col, rowptr, value)``."""
    mask = (col >= start) & (col < start + length)

    # kept[i] = number of kept entries among the first i entries, so
    # kept[rowptr] yields the new row pointers without expanding row ids.
    counts = mask.to(torch.long).cumsum(0)
    kept = torch.cat([counts.new_zeros(1), counts])
    new_rowptr = kept[rowptr.to(torch.long)].to(rowptr.dtype)

    new_col = col[mask] - start
    new_value = value[mask] if value is not None else None
    return new_col, new_rowptr, new_value


if __name__ == '__main__':
    # Manual smoke check: build a small sparse matrix from COO indices
    # (first row of ``index`` = row ids, second = column ids) and print it
    # in sparse and dense form, on the GPU when one is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    index = torch.tensor([[0, 0, 1, 1],
                          [1, 2, 0, 2]], device=device)
    sparse_mat = SparseTensor(index)
    print(sparse_mat)
    print(sparse_mat.to_dense())