# select.py (blame attribution: rusty1s)
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow


@torch.jit.script
def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
    """Select the single slice ``idx`` along dimension ``dim`` of ``src``.

    Implemented as ``narrow(src, dim, start=idx, length=1)``, i.e. a
    narrow of length one starting at ``idx``.

    NOTE(review): because this goes through ``narrow`` with ``length=1``,
    the selected dimension is presumably kept with size 1 (unlike
    ``torch.Tensor.select``, which drops it) — confirm against
    ``torch_sparse.narrow``. Whether a negative ``idx`` is supported
    likewise depends on ``narrow``; TODO confirm.

    Args:
        src: The sparse tensor to slice.
        dim: Dimension along which to select.
        idx: Index of the slice to keep along ``dim``.

    Returns:
        The ``SparseTensor`` returned by ``narrow`` for that one-wide window.
    """
    return narrow(src, dim, start=idx, length=1)


def _select_method(self, dim: int, idx: int) -> SparseTensor:
    """Method form of the module-level ``select``: ``t.select(dim, idx)``.

    Delegates to ``select(self, dim, idx)``. A named function is used
    instead of a lambda so the bound method has a meaningful ``__name__``
    and docstring; the parameter names (``dim``, ``idx``) are unchanged so
    keyword calls keep working.
    """
    return select(self, dim, idx)


SparseTensor.select = _select_method