utils.py 713 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import os.path as osp
from typing import Any

rusty1s's avatar
rusty1s committed
3
4
import torch

rusty1s's avatar
rusty1s committed
5
6
7
8
9
try:
    from typing_extensions import Final  # noqa
except ImportError:
    from torch.jit import Final  # noqa

rusty1s's avatar
rusty1s committed
10
# Directory that contains the ``torch_sparse`` package folder. Resolving the
# compiled extensions against it (instead of the current working directory)
# lets the package be imported from anywhere, not only from the repo root.
# NOTE(review): assumes this file lives at ``torch_sparse/utils.py`` so the
# ``.so`` files sit next to it — confirm against the build layout.
_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__)))


def _load_library(name: str) -> None:
    """Load the compiled extension ``torch_sparse/<name>.so`` into ``torch.ops``.

    Args:
        name: Base file name of the extension, without the ``.so`` suffix.

    Raises:
        OSError: If the shared library cannot be found or loaded.
    """
    torch.ops.load_library(osp.join(_ROOT, 'torch_sparse', name + '.so'))


# CPU kernels are mandatory: any load failure propagates.
for _name in ['convert_cpu', 'diag_cpu', 'spmm_cpu']:
    _load_library(_name)

# CUDA kernels are optional. A missing/unloadable CUDA build is tolerated on
# machines without a GPU, but is an error when CUDA is actually available.
try:
    for _name in ['convert_cuda', 'diag_cuda', 'spmm_cuda', 'spspmm_cuda']:
        _load_library(_name)
except OSError:
    if torch.cuda.is_available():
        raise
rusty1s's avatar
rusty1s committed
22
23


rusty1s's avatar
rusty1s committed
24
def is_scalar(other: Any) -> bool:
    """Report whether ``other`` is a plain Python number.

    Args:
        other: The value to inspect; may be of any type.

    Returns:
        ``True`` when ``other`` is an instance of ``int`` or ``float``
        (subclasses such as ``bool`` included), otherwise ``False``.
    """
    # Two separate isinstance checks (rather than a tuple of types) keep the
    # same form as before in case this helper is compiled by TorchScript.
    if isinstance(other, int):
        return True
    return isinstance(other, float)