__init__.py 4.1 KB
Newer Older
rusty1s's avatar
typo  
rusty1s committed
1
import os
rusty1s's avatar
reset  
rusty1s committed
2
import importlib
rusty1s's avatar
typo  
rusty1s committed
3
4
5
6
import os.path as osp

import torch

rusty1s's avatar
rusty1s committed
7
# Version string of this package (also listed in ``__all__``).
__version__ = "2.0.4"

# (major, minor) PyTorch version this build expects; it is compared against
# ``torch.__version__`` when loading the compiled libraries raises OSError.
expected_torch_version = (1, 4)
rusty1s's avatar
rusty1s committed
9
10

# ``import importlib`` alone does not guarantee that the ``machinery``
# submodule is loaded. Import it explicitly so that a missing module attribute
# cannot be mistaken for a missing compiled library by the
# ``except AttributeError`` branch below.
import importlib.machinery  # noqa

# Load the compiled operator libraries that live next to this file.
# ``PathFinder.find_spec`` returns ``None`` when a library cannot be located;
# the subsequent ``.origin`` access then raises ``AttributeError``, which is
# handled below.
try:
    for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
        torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
            library, [osp.dirname(__file__)]).origin)
except OSError as e:
    # A library was located but failed to load. One possible cause is a
    # PyTorch version mismatch, so report that explicitly first.
    major, minor = [int(x) for x in torch.__version__.split('.')[:2]]
    t_major, t_minor = expected_torch_version
    if (major, minor) != (t_major, t_minor):
        raise RuntimeError(
            f'Expected PyTorch version {t_major}.{t_minor} but found '
            f'version {major}.{minor}.') from e
    # Re-raise the original error unchanged (keeps its exact OSError
    # subclass and traceback instead of wrapping it in a new OSError).
    raise
except AttributeError:
    # A library could not be located. This is a hard error unless
    # BUILD_DOCS=1 (presumably set when building the documentation), in which
    # case placeholder callables are registered so the package still imports.
    if os.getenv('BUILD_DOCS', '0') != '1':
        raise

    from .placeholder import cuda_version_placeholder
    torch.ops.torch_scatter.cuda_version = cuda_version_placeholder

    # scatter_min / scatter_max share one placeholder.
    from .placeholder import scatter_arg_placeholder
    torch.ops.torch_scatter.scatter_min = scatter_arg_placeholder
    torch.ops.torch_scatter.scatter_max = scatter_arg_placeholder

    # CSR-based segment and gather ops.
    from .placeholder import segment_csr_placeholder
    from .placeholder import segment_csr_arg_placeholder
    from .placeholder import gather_csr_placeholder
    torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
    torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
    torch.ops.torch_scatter.segment_min_csr = segment_csr_arg_placeholder
    torch.ops.torch_scatter.segment_max_csr = segment_csr_arg_placeholder
    torch.ops.torch_scatter.gather_csr = gather_csr_placeholder

    # COO-based segment and gather ops.
    from .placeholder import segment_coo_placeholder
    from .placeholder import segment_coo_arg_placeholder
    from .placeholder import gather_coo_placeholder
    torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
    torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
    torch.ops.torch_scatter.segment_min_coo = segment_coo_arg_placeholder
    torch.ops.torch_scatter.segment_max_coo = segment_coo_arg_placeholder
    torch.ops.torch_scatter.gather_coo = gather_coo_placeholder

rusty1s's avatar
doc fix  
rusty1s committed
51
52
53
54
55
56
# Verify that torch_scatter and PyTorch were built against the same CUDA
# (major, minor) version; a mismatch raises RuntimeError at import time.
if torch.version.cuda is not None:  # pragma: no cover
    cuda_version = torch.ops.torch_scatter.cuda_version()

    if cuda_version == -1:
        # NOTE(review): -1 presumably means the extension was built without
        # CUDA; it is reported as version 0.0 and therefore always mismatches.
        major = minor = 0
    else:
        # The integer encodes ``1000 * major + 10 * minor`` (e.g. 9020 -> 9.2,
        # 10020 -> 10.2). Decode it arithmetically rather than by string
        # slicing, which breaks for minor versions >= 10 and values < 1000.
        major, minor = cuda_version // 1000, (cuda_version % 1000) // 10

    # Only compare (major, minor); ignore any further version components.
    t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')[:2]]

    if t_major != major or t_minor != minor:
        raise RuntimeError(
            'Detected that PyTorch and torch_scatter were compiled with '
            'different CUDA versions. PyTorch has CUDA version '
            f'{t_major}.{t_minor} and torch_scatter has CUDA version '
            f'{major}.{minor}. Please reinstall the torch_scatter that '
            'matches your PyTorch install.')
rusty1s's avatar
cleaner  
rusty1s committed
69

rusty1s's avatar
typo  
rusty1s committed
70
from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
rusty1s's avatar
typos  
rusty1s committed
71
                      scatter_max, scatter)  # noqa
rusty1s's avatar
typo  
rusty1s committed
72
73
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
                          segment_min_csr, segment_max_csr, segment_csr,
rusty1s's avatar
typos  
rusty1s committed
74
                          gather_csr)  # noqa
rusty1s's avatar
typo  
rusty1s committed
75
76
from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
                          segment_min_coo, segment_max_coo, segment_coo,
rusty1s's avatar
typos  
rusty1s committed
77
                          gather_coo)  # noqa
rusty1s's avatar
typo  
rusty1s committed
78
from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
rusty1s's avatar
typos  
rusty1s committed
79
                        scatter_log_softmax)  # noqa
rusty1s's avatar
typo  
rusty1s committed
80

rusty1s's avatar
rusty1s committed
81
# Public API exported by ``from torch_scatter import *``. Every entry must be
# a name bound in this module: star-import fetches each listed name as a
# module attribute and raises AttributeError for any that is missing (which is
# why the package's own name, ``'torch_scatter'``, must not appear here).
__all__ = [
    # scatter reductions
    'scatter_sum',
    'scatter_add',
    'scatter_mean',
    'scatter_min',
    'scatter_max',
    'scatter',
    # CSR segment reductions
    'segment_sum_csr',
    'segment_add_csr',
    'segment_mean_csr',
    'segment_min_csr',
    'segment_max_csr',
    'segment_csr',
    'gather_csr',
    # COO segment reductions
    'segment_sum_coo',
    'segment_add_coo',
    'segment_mean_coo',
    'segment_min_coo',
    'segment_max_coo',
    'segment_coo',
    'gather_coo',
    # composite ops
    'scatter_std',
    'scatter_logsumexp',
    'scatter_softmax',
    'scatter_log_softmax',
    '__version__',
]