"""Entry point of the ``torch_spline_conv`` package.

Importing this package loads its compiled operator libraries through
``torch.ops.load_library``. It then checks that the extension and PyTorch agree
on the CUDA version, and re-exports ``spline_basis``, ``spline_weighting`` and
``spline_conv``.
"""

# ``import importlib`` alone does not guarantee that the ``machinery``
# submodule is loaded, so import it explicitly instead of relying on some
# other module having imported it first.
import importlib.machinery
import os.path as osp

import torch

__version__ = '1.2.0'
# (major, minor) PyTorch release this package targets. It is not referenced by
# the code in this module -- presumably consumed elsewhere; verify before
# removing.
expected_torch_version = (1, 4)

# Load the native operator libraries shipped next to this file. Loading them
# registers custom operators with ``torch.ops``. One example is
# ``torch.ops.torch_spline_conv.cuda_version``, which is called below and is
# presumably provided by ``_version``.
_lib_dir = osp.dirname(__file__)
for library in ['_version', '_basis', '_weighting']:
    _spec = importlib.machinery.PathFinder().find_spec(library, [_lib_dir])
    # ``find_spec`` returns None when the compiled module is absent (e.g. a
    # broken or partial install). Fail with a clear ImportError instead of
    # an AttributeError on ``None.origin``.
    if _spec is None or _spec.origin is None:  # pragma: no cover
        raise ImportError(
            f"Could not find compiled module '{library}' in {_lib_dir}")
    torch.ops.load_library(_spec.origin)

if torch.version.cuda is not None:  # pragma: no cover
    # Integer CUDA version the extension reports. -1 is a sentinel, presumably
    # meaning the extension was compiled without CUDA support.
    cuda_version = torch.ops.torch_spline_conv.cuda_version()

    # A CPU-only build of the extension has no CUDA toolkit that could
    # disagree with PyTorch's, so only compare versions for CUDA builds.
    if cuda_version != -1:
        # Encoded as 1000 * major + 10 * minor (e.g. 9020 -> 9.2,
        # 10020 -> 10.2). This is presumably the CUDA_VERSION macro; confirm
        # against the C++ ``cuda_version`` implementation.
        major, minor = cuda_version // 1000, (cuda_version % 1000) // 10
        # PyTorch reports e.g. '10.2'. Ignore any trailing patch/build
        # components so the unpacking below cannot fail on them.
        t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')[:2]]

        if t_major != major or t_minor != minor:
            raise RuntimeError(
                f'Detected that PyTorch and torch_spline_conv were compiled '
                f'with different CUDA versions. PyTorch has CUDA version '
                f'{t_major}.{t_minor} and torch_spline_conv has CUDA '
                f'version {major}.{minor}. Please reinstall the '
                f'torch_spline_conv that matches your PyTorch install.')

from .basis import spline_basis  # noqa
from .weighting import spline_weighting  # noqa
from .conv import spline_conv  # noqa

# Names exported by ``from torch_spline_conv import *``: the three operator
# entry points plus the package version string.
__all__ = ['spline_basis', 'spline_weighting', 'spline_conv', '__version__']