import importlib
import importlib.machinery
import os.path as osp

import torch
# Release version of this package.
__version__ = '1.5.0'

# (major, minor) PyTorch release the compiled extensions are built for; it is
# compared against the installed ``torch.__version__`` when loading fails.
expected_torch_version = (1, 4)

try:
    # Load every compiled extension shipped alongside this file.
    # ``torch.ops.load_library`` registers the operators each library
    # defines, making them reachable through ``torch.ops``.
    for library in [
            '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler',
            '_nearest', '_knn', '_radius'
    ]:
        spec = importlib.machinery.PathFinder().find_spec(
            library, [osp.dirname(__file__)])
        if spec is None:
            # ``find_spec`` returns ``None`` when no matching file exists.
            # Without this check the failure would surface as an opaque
            # ``AttributeError`` on ``None.origin``.
            raise ImportError(f"Could not find module '{library}' in "
                              f'{osp.dirname(__file__)}')
        torch.ops.load_library(spec.origin)
except OSError as e:
    # Loading can fail when the extensions were built against a different
    # PyTorch release, so report that mismatch explicitly when it applies.
    major, minor = [int(x) for x in torch.__version__.split('.')[:2]]
    if (major, minor) != tuple(expected_torch_version):
        t_major, t_minor = expected_torch_version
        raise RuntimeError(
            f'Expected PyTorch version {t_major}.{t_minor} but found '
            f'version {major}.{minor}.') from e
    # Versions match: re-raise the original error with its traceback intact.
    raise

if torch.version.cuda is not None:  # pragma: no cover
    cuda_version = torch.ops.torch_sparse.cuda_version()

    if cuda_version == -1:
        major = minor = 0
    elif cuda_version < 10000:
        major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
    else:
        major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
    t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')]

    if t_major != major or t_minor != minor:
        raise RuntimeError(
            f'Detected that PyTorch and torch_cluster were compiled with '
            f'different CUDA versions. PyTorch has CUDA version '
            f'{t_major}.{t_minor} and torch_cluster has CUDA version '
            f'{major}.{minor}. Please reinstall the torch_cluster that '
            f'matches your PyTorch install.')

from .graclus import graclus_cluster  # noqa
from .grid import grid_cluster  # noqa
from .fps import fps  # noqa
from .nearest import nearest  # noqa
from .knn import knn, knn_graph  # noqa
from .radius import radius, radius_graph  # noqa
from .rw import random_walk  # noqa
from .sampler import neighbor_sampler  # noqa
# Public names exported by ``from torch_cluster import *``; each entry is
# imported from its submodule above, except ``__version__``.
__all__ = [
    'graclus_cluster',
    'grid_cluster',
    'fps',
    'nearest',
    'knn',
    'knn_graph',
    'radius',
    'radius_graph',
    'random_walk',
    'neighbor_sampler',
    '__version__',
]