"""``select`` for :class:`SparseTensor`, implemented on top of ``narrow``."""
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow


def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
    """Take the single slice at position ``idx`` along dimension ``dim``.

    This is implemented as a ``narrow`` of length one starting at ``idx``.
    NOTE(review): because it goes through ``narrow``, the selected
    dimension is presumably kept with size 1 rather than removed (unlike
    ``torch.Tensor.select``). Confirm against ``narrow``.

    Args:
        src: The sparse tensor to slice.
        dim: The dimension to slice along.
        idx: The position of the slice within ``dim``. Whether negative
            values are accepted depends on ``narrow`` (TODO confirm).

    Returns:
        The ``SparseTensor`` produced by ``narrow``.
    """
    window = 1  # exactly one element wide along `dim`
    return narrow(src, dim, start=idx, length=window)


def _sparse_tensor_select(self, dim, idx):
    """Method form of :func:`select`: ``tensor.select(dim, idx)``."""
    # Resolve the module-level `select` at call time, as the original lambda did.
    return select(self, dim, idx)


SparseTensor.select = _sparse_tensor_select