"megatron/vscode:/vscode.git/clone" did not exist on "064bdc46a91aaf3e5f2d17b22cbcfb1788db0680"
select.py 272 Bytes
Newer Older
quyuanhao123's avatar
quyuanhao123 committed
1
2
3
4
5
6
7
8
9
from torch_sparse.tensor import SparseTensor
from torch_sparse.narrow import narrow


def select(src: SparseTensor, dim: int, idx: int) -> SparseTensor:
    """Slice out the single entry at position ``idx`` along ``dim``.

    This is a thin wrapper over ``narrow`` using a window of width one.
    Assuming ``narrow`` mirrors ``torch.narrow``, the result keeps ``dim``
    with size 1 rather than dropping it the way ``torch.Tensor.select``
    does. ``narrow`` alone decides how negative or out-of-range ``dim`` and
    ``idx`` values are handled.

    Args:
        src: Sparse tensor to slice.
        dim: Dimension along which to select.
        idx: Position within ``dim`` to keep.

    Returns:
        The ``SparseTensor`` returned by ``narrow``.
    """
    window = 1  # width of the narrowed window: exactly one index
    return narrow(src, dim, start=idx, length=window)


# Attach ``select`` as a method so callers can write ``tensor.select(dim, idx)``.
# The lambda looks up the module-level ``select`` each time it is called,
# instead of capturing the function object once at import time.
setattr(SparseTensor, "select",
        lambda self, dim, idx: select(self, dim, idx))