# padding.py (810 bytes) — last committed by rusty1s.
from typing import Tuple, List

import torch
from torch_sparse.tensor import SparseTensor


def padded_index(src: SparseTensor, binptr: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                            torch.Tensor, List[int], List[int]]:
    """Compute padded-layout index data for ``src`` via the native op.

    Reads the ``rowptr``, ``col`` and ``rowcount`` tensors from
    ``src.storage`` and forwards them, together with ``binptr``, to
    ``torch.ops.torch_sparse.padded_index``.

    Args:
        src: Sparse matrix whose storage supplies the index tensors.
        binptr: Tensor passed through unchanged to the native op;
            presumably bin boundaries over row counts (TODO confirm
            against the C++ implementation).

    Returns:
        The native op's result, annotated as four tensors followed by
        two lists of ints. The meaning of each element is defined by
        the native op and is not visible in this module.
    """
    storage = src.storage
    # Keep the original evaluation order: rowptr, then col, then rowcount.
    rowptr = storage.rowptr()
    col = storage.col()
    rowcount = storage.rowcount()
    return torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)


def padded_index_select(src: torch.Tensor, index: torch.Tensor,
                        fill_value: float = 0.) -> torch.Tensor:
    """Gather from ``src`` at ``index`` using the native padded-select op.

    Args:
        src: Source tensor. Its dtype is also used for the fill value
            handed to the op.
        index: Index tensor passed through unchanged to the native op.
        fill_value: Python scalar that is wrapped in a 0-dim tensor before
            the call. Presumably it is written into padded slots (TODO
            confirm against the C++ kernel).

    Returns:
        The tensor produced by ``torch.ops.torch_sparse.padded_index_select``.
    """
    # The op is called with the fill value as a tensor, not a Python scalar.
    # No device is passed, so the tensor is created on torch's default
    # device rather than on src's device.
    fill_tensor = torch.tensor(fill_value, dtype=src.dtype)
    return torch.ops.torch_sparse.padded_index_select(src, index, fill_tensor)


# Attach padded_index to SparseTensor as a method, so callers can write
# ``src.padded_index(binptr)`` instead of ``padded_index(src, binptr)``.
setattr(SparseTensor, 'padded_index', padded_index)