utils.py 684 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
import torch

# Load the compiled CPU extension libraries through torch.ops.load_library.
# NOTE(review): the paths are relative, so they are presumably resolved
# against the current working directory -- confirm against packaging.
torch.ops.load_library('torch_sparse/convert_cpu.so')
torch.ops.load_library('torch_sparse/diag_cpu.so')
torch.ops.load_library('torch_sparse/spmm_cpu.so')

# The CUDA libraries are optional: an OSError while loading them is ignored
# unless torch reports that CUDA is available, in which case it is re-raised.
try:
    torch.ops.load_library('torch_sparse/convert_cuda.so')
    torch.ops.load_library('torch_sparse/diag_cuda.so')
    torch.ops.load_library('torch_sparse/spmm_cuda.so')
    torch.ops.load_library('torch_sparse/spspmm_cuda.so')
except OSError:
    if torch.cuda.is_available():
        # Bare raise re-raises the active OSError with its traceback intact.
        raise
rusty1s's avatar
rusty1s committed
15
16
17
18
19


def ext(is_cuda):
    """Return the ``torch.ops`` namespace matching the requested device.

    Args:
        is_cuda: Truthy to select ``torch.ops.torch_sparse_cuda``, falsy to
            select ``torch.ops.torch_sparse_cpu``.

    Returns:
        The attribute of ``torch.ops`` with the selected name.
    """
    if is_cuda:
        namespace = 'torch_sparse_cuda'
    else:
        namespace = 'torch_sparse_cpu'
    return getattr(torch.ops, namespace)
rusty1s's avatar
rusty1s committed
20
21
22
23


def is_scalar(other):
    """Return ``True`` if ``other`` is an ``int`` or ``float`` instance.

    Subclasses count as well, so ``bool`` values are reported as scalars.
    """
    scalar_types = (int, float)
    return isinstance(other, scalar_types)