select.py 287 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow


@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int):
    """Take the slice of ``src`` at position ``idx`` along dimension ``dim``.

    Implemented as a width-1 ``narrow`` starting at ``idx``. The result is
    whatever ``narrow`` returns for that window.

    NOTE(review): a width-1 narrow presumably keeps ``dim`` with size 1
    instead of dropping it as ``torch.Tensor.select`` does. Confirm this is
    the intended contract. Handling of negative ``dim``/``idx`` is likewise
    delegated to ``narrow``.

    Args:
        src: Sparse tensor to slice.
        dim: Dimension to slice along.
        idx: Start position of the single-element window.
    """
    # A select is a narrow whose window covers exactly one position.
    width = 1
    return narrow(src, dim, start=idx, length=width)


def _sparse_tensor_select(self, dim, idx):
    """Method form of :func:`select`: ``tensor.select(dim, idx)``.

    Forwards the call unchanged to the module-level scripted ``select``.
    """
    return select(self, dim, idx)


# Attach the wrapper so every SparseTensor instance exposes ``.select``.
SparseTensor.select = _sparse_tensor_select