# utils.py: loads the torch_sparse compiled operator libraries and provides
# `ext` for picking the CPU or CUDA operator namespace under `torch.ops`.
import torch

# Load the compiled CPU operator libraries. These loads are not guarded, so
# any failure propagates and aborts the import.
# NOTE(review): the library paths are relative, so they are resolved against
# the current working directory rather than this file's location. This
# presumably only works when running from the repository root; TODO confirm.
torch.ops.load_library('torch_sparse/convert_cpu.so')
torch.ops.load_library('torch_sparse/diag_cpu.so')

# The CUDA libraries are treated as optional: a failure to load them (OSError)
# is only an error when a CUDA device is actually available.
try:
    torch.ops.load_library('torch_sparse/convert_cuda.so')
    torch.ops.load_library('torch_sparse/diag_cuda.so')
except OSError:
    if torch.cuda.is_available():
        # Bare ``raise`` re-raises the active OSError with its original
        # traceback.
        raise


def ext(is_cuda):
    """Return the ``torch.ops`` namespace holding torch_sparse's operators.

    Args:
        is_cuda: When truthy, select ``torch.ops.torch_sparse_cuda``;
            otherwise select ``torch.ops.torch_sparse_cpu``.

    Returns:
        The attribute of ``torch.ops`` with the selected namespace name.
    """
    device = 'cuda' if is_cuda else 'cpu'
    return getattr(torch.ops, 'torch_sparse_' + device)