__init__.py 1.94 KB
Newer Older
quyuanhao123's avatar
quyuanhao123 committed
1
2
3
4
5
6
7
8
9
10
11
import importlib
import importlib.machinery
import os.path as osp

import torch

# Package version string; also exported through ``__all__``.
__version__ = "1.6.0"

# Load each native op library that ships next to this file. A library may be
# present as a CUDA build (``<name>_cuda``) and/or a CPU build
# (``<name>_cpu``); the CUDA build is preferred when it exists.
_library_dir = osp.dirname(__file__)  # loop-invariant search directory
for library in [
        '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest',
        '_knn', '_radius'
]:
    # ``PathFinder`` only defines class methods, so no instance is needed.
    spec = importlib.machinery.PathFinder.find_spec(
        f'{library}_cuda', [_library_dir])
    if spec is None:
        # No CUDA build found; fall back to the CPU build.
        spec = importlib.machinery.PathFinder.find_spec(
            f'{library}_cpu', [_library_dir])
    if spec is None:  # pragma: no cover
        raise ImportError(f"Could not find module '{library}_cpu' in "
                          f"{_library_dir}")
    # Registers the library's custom ops under ``torch.ops``.
    torch.ops.load_library(spec.origin)

# Check that PyTorch and the compiled torch_cluster ops agree on the CUDA
# major version. ``cuda_version`` is the integer reported by the native
# ``_version`` library; -1 is treated as a "no CUDA build" sentinel and skips
# the check.
cuda_version = torch.ops.torch_cluster.cuda_version()
# ``torch.version.cuda`` is None for builds without CUDA (e.g. ROCm/HIP
# builds, where ``torch.cuda.is_available()`` can still be True). Guard it
# here, because it is split as a string below.
if (torch.cuda.is_available()  # pragma: no cover
        and torch.version.cuda is not None and cuda_version != -1):
    # Assumes the CUDA_VERSION encoding 1000 * major + 10 * minor
    # (e.g. 9020 -> 9.2, 11080 -> 11.8). Decoding arithmetically instead of
    # by string position also handles multi-digit components.
    major, minor = cuda_version // 1000, (cuda_version % 1000) // 10
    # Use only the leading 'MAJOR.MINOR' components, so an extra patch level
    # in the version string cannot break the unpacking.
    t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')[:2]]

    if t_major != major:
        raise RuntimeError(
            f'Detected that PyTorch and torch_cluster were compiled with '
            f'different CUDA versions. PyTorch has CUDA version '
            f'{t_major}.{t_minor} and torch_cluster has CUDA version '
            f'{major}.{minor}. Please reinstall the torch_cluster that '
            f'matches your PyTorch install.')
quyuanhao123's avatar
quyuanhao123 committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

from .fps import fps  # noqa
from .graclus import graclus_cluster  # noqa
from .grid import grid_cluster  # noqa
from .knn import knn, knn_graph  # noqa
from .nearest import nearest  # noqa
from .radius import radius, radius_graph  # noqa
from .rw import random_walk  # noqa
from .sampler import neighbor_sampler  # noqa

# Public API: the functions re-exported from this package's submodules,
# plus the package version string.
__all__ = [
    'graclus_cluster', 'grid_cluster', 'fps', 'nearest',
    'knn', 'knn_graph', 'radius', 'radius_graph',
    'random_walk', 'neighbor_sampler', '__version__',
]