import importlib
import importlib.machinery
import os
import os.path as osp

import torch

# Version string of this package.
__version__ = '2.0.3'

# (major, minor) PyTorch release that the library-loading code below compares
# against `torch.__version__` when a load fails with "undefined symbol".
expected_torch_version = (1, 4)

# Load each compiled extension library located next to this file through
# `torch.ops.load_library`.
try:
    for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
        torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
            library, [osp.dirname(__file__)]).origin)
except OSError as e:
    # An "undefined symbol" failure while running a PyTorch release other
    # than `expected_torch_version` is reported as a version mismatch, which
    # is more actionable than the raw loader error.
    if 'undefined symbol' in str(e):
        major, minor = [int(x) for x in torch.__version__.split('.')[:2]]
        t_major, t_minor = expected_torch_version
        if (major, minor) != (t_major, t_minor):
            raise RuntimeError(
                f'Expected PyTorch version {t_major}.{t_minor} but found '
                f'version {major}.{minor}.') from e
    # Any other loader failure: re-raise the original exception untouched so
    # its concrete type, errno/filename and traceback are preserved.
    raise
except AttributeError:
    # Presumably reached when `find_spec` returned None (library not found),
    # so that `.origin` failed — TODO confirm. This is only tolerated when
    # the environment variable BUILD_DOCS is set to '1'.
    if os.getenv('BUILD_DOCS', '0') != '1':
        raise

    # Bind placeholder callables in place of the operators that the compiled
    # libraries would otherwise provide under `torch.ops.torch_scatter`.
    from .placeholder import cuda_version_placeholder
    torch.ops.torch_scatter.cuda_version = cuda_version_placeholder

    from .placeholder import scatter_arg_placeholder
    torch.ops.torch_scatter.scatter_min = scatter_arg_placeholder
    torch.ops.torch_scatter.scatter_max = scatter_arg_placeholder

    from .placeholder import segment_csr_placeholder
    from .placeholder import segment_csr_arg_placeholder
    from .placeholder import gather_csr_placeholder
    torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
    torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
    torch.ops.torch_scatter.segment_min_csr = segment_csr_arg_placeholder
    torch.ops.torch_scatter.segment_max_csr = segment_csr_arg_placeholder
    torch.ops.torch_scatter.gather_csr = gather_csr_placeholder

    from .placeholder import segment_coo_placeholder
    from .placeholder import segment_coo_arg_placeholder
    from .placeholder import gather_coo_placeholder
    torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
    torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
    torch.ops.torch_scatter.segment_min_coo = segment_coo_arg_placeholder
    torch.ops.torch_scatter.segment_max_coo = segment_coo_arg_placeholder
    torch.ops.torch_scatter.gather_coo = gather_coo_placeholder

# When PyTorch reports a CUDA version, require that it matches the (major,
# minor) CUDA version reported by torch_scatter's `cuda_version` operator.
if torch.version.cuda is not None:  # pragma: no cover
    cuda_version = torch.ops.torch_scatter.cuda_version()

    # NOTE(review): the value is presumably encoded as
    # 1000 * major + 10 * minor (e.g. 10010 for 10.1), with -1 used when no
    # CUDA version is available — confirm against the operator source.
    if cuda_version == -1:
        major = minor = 0
    else:
        # Integer arithmetic yields the same result as slicing the decimal
        # string for 4- and 5-digit values, without depending on the number
        # of digits.
        major, minor = cuda_version // 1000, cuda_version % 1000 // 10

    # Only the first two components are compared; slicing keeps strings
    # with an extra patch component (such as '10.1.243') from breaking the
    # two-name unpacking.
    t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')[:2]]

    if t_major != major or t_minor != minor:
        raise RuntimeError(
            f'Detected that PyTorch and torch_scatter were compiled with '
            f'different CUDA versions. PyTorch has CUDA version '
            f'{t_major}.{t_minor} and torch_scatter has CUDA version '
            f'{major}.{minor}. Please reinstall the torch_scatter that '
            f'matches your PyTorch install.')

# Re-export the public functions of the submodules at package level. These
# imports come after the library loading / placeholder setup above.
from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
                      scatter_max, scatter)  # noqa
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
                          segment_min_csr, segment_max_csr, segment_csr,
                          gather_csr)  # noqa
from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
                          segment_min_coo, segment_max_coo, segment_coo,
                          gather_coo)  # noqa
from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
                        scatter_log_softmax)  # noqa

# Names exported by `from <package> import *`.
__all__ = [
    # from .scatter
    'scatter_sum',
    'scatter_add',
    'scatter_mean',
    'scatter_min',
    'scatter_max',
    'scatter',
    # from .segment_csr
    'segment_sum_csr',
    'segment_add_csr',
    'segment_mean_csr',
    'segment_min_csr',
    'segment_max_csr',
    'segment_csr',
    'gather_csr',
    # from .segment_coo
    'segment_sum_coo',
    'segment_add_coo',
    'segment_mean_coo',
    'segment_min_coo',
    'segment_max_coo',
    'segment_coo',
    'gather_coo',
    # from .composite
    'scatter_std',
    'scatter_logsumexp',
    'scatter_softmax',
    'scatter_log_softmax',
    # NOTE(review): no statement visible in this file binds the name
    # 'torch_scatter'; confirm that an attribute or submodule of that name
    # exists, otherwise a star-import of this package raises AttributeError.
    'torch_scatter',
    '__version__',
]