__init__.py 3.92 KB
Newer Older
rusty1s's avatar
typo  
rusty1s committed
1
import os
rusty1s's avatar
reset  
rusty1s committed
2
import importlib
rusty1s's avatar
typo  
rusty1s committed
3
4
5
6
import os.path as osp

import torch

rusty1s's avatar
update  
rusty1s committed
7
# Package version string (also listed in ``__all__``).
__version__ = '2.0.6'
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
update  
rusty1s committed
9
# Which compiled variant of the extension libraries gets loaded:
# ``<name>_cuda`` when PyTorch reports CUDA as available, ``<name>_cpu``
# otherwise.
suffix = 'cpu'
if torch.cuda.is_available():
    suffix = 'cuda'
10

rusty1s's avatar
rusty1s committed
11
try:
    # ``import importlib`` alone does not guarantee that the ``machinery``
    # submodule is bound on the package, so import it explicitly.
    import importlib.machinery

    # Load the compiled extension libraries that live next to this file,
    # picking the ``_cpu`` or ``_cuda`` build according to ``suffix``. When a
    # library is missing, ``find_spec`` returns ``None`` and the ``.origin``
    # access raises ``AttributeError``, which is handled below.
    for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
        torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
            f'{library}_{suffix}', [osp.dirname(__file__)]).origin)
except AttributeError:
    # Outside of documentation builds a missing library is fatal. Re-raise
    # the original exception so its message and traceback stay intact.
    if os.getenv('BUILD_DOCS', '0') != '1':
        raise

    # Documentation builds (``BUILD_DOCS=1``) may run without the compiled
    # libraries: install Python placeholders on ``torch.ops.torch_scatter``
    # so the rest of the package can still be imported.
    from .placeholder import cuda_version_placeholder
    torch.ops.torch_scatter.cuda_version = cuda_version_placeholder

    from .placeholder import scatter_placeholder
    torch.ops.torch_scatter.scatter_mul = scatter_placeholder

    from .placeholder import scatter_arg_placeholder
    torch.ops.torch_scatter.scatter_min = scatter_arg_placeholder
    torch.ops.torch_scatter.scatter_max = scatter_arg_placeholder

    from .placeholder import segment_csr_placeholder
    from .placeholder import segment_csr_arg_placeholder
    from .placeholder import gather_csr_placeholder
    torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
    torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
    torch.ops.torch_scatter.segment_min_csr = segment_csr_arg_placeholder
    torch.ops.torch_scatter.segment_max_csr = segment_csr_arg_placeholder
    torch.ops.torch_scatter.gather_csr = gather_csr_placeholder

    from .placeholder import segment_coo_placeholder
    from .placeholder import segment_coo_arg_placeholder
    from .placeholder import gather_coo_placeholder
    torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
    torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
    torch.ops.torch_scatter.segment_min_coo = segment_coo_arg_placeholder
    torch.ops.torch_scatter.segment_max_coo = segment_coo_arg_placeholder
    torch.ops.torch_scatter.gather_coo = gather_coo_placeholder

rusty1s's avatar
update  
rusty1s committed
47
if torch.cuda.is_available():  # pragma: no cover
rusty1s's avatar
doc fix  
rusty1s committed
48
49
50
51
52
    cuda_version = torch.ops.torch_scatter.cuda_version()

    if cuda_version == -1:
        major = minor = 0
    elif cuda_version < 10000:
rusty1s's avatar
rusty1s committed
53
54
55
56
        major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
    else:
        major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
    t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')]
rusty1s's avatar
rusty1s committed
57

rusty1s's avatar
rusty1s committed
58
    if t_major != major:
rusty1s's avatar
rusty1s committed
59
        raise RuntimeError(
rusty1s's avatar
rusty1s committed
60
61
62
63
64
            f'Detected that PyTorch and torch_scatter were compiled with '
            f'different CUDA versions. PyTorch has CUDA version '
            f'{t_major}.{t_minor} and torch_scatter has CUDA version '
            f'{major}.{minor}. Please reinstall the torch_scatter that '
            f'matches your PyTorch install.')
rusty1s's avatar
cleaner  
rusty1s committed
65

rusty1s's avatar
rusty1s committed
66
67
68
69
70
71
72
73
74
from .scatter import scatter_sum, scatter_add, scatter_mul  # noqa
from .scatter import scatter_mean, scatter_min, scatter_max, scatter  # noqa
from .segment_csr import segment_sum_csr, segment_add_csr  # noqa
from .segment_csr import segment_mean_csr, segment_min_csr  # noqa
from .segment_csr import segment_max_csr, segment_csr, gather_csr  # noqa
from .segment_coo import segment_sum_coo, segment_add_coo  # noqa
from .segment_coo import segment_mean_coo, segment_min_coo  # noqa
from .segment_coo import segment_max_coo, segment_coo, gather_coo  # noqa
from .composite import scatter_std, scatter_logsumexp  # noqa
rusty1s's avatar
typo  
rusty1s committed
75
from .composite import scatter_softmax, scatter_log_softmax  # noqa
rusty1s's avatar
typo  
rusty1s committed
76

rusty1s's avatar
rusty1s committed
77
__all__ = [
rusty1s's avatar
rusty1s committed
78
79
    'scatter_sum',
    'scatter_add',
rusty1s's avatar
rusty1s committed
80
    'scatter_mul',
rusty1s's avatar
rusty1s committed
81
82
83
84
    'scatter_mean',
    'scatter_min',
    'scatter_max',
    'scatter',
rusty1s's avatar
rusty1s committed
85
86
87
88
89
    'segment_sum_csr',
    'segment_add_csr',
    'segment_mean_csr',
    'segment_min_csr',
    'segment_max_csr',
rusty1s's avatar
rusty1s committed
90
    'segment_csr',
rusty1s's avatar
rusty1s committed
91
    'gather_csr',
rusty1s's avatar
rusty1s committed
92
93
94
95
96
97
98
    'segment_sum_coo',
    'segment_add_coo',
    'segment_mean_coo',
    'segment_min_coo',
    'segment_max_coo',
    'segment_coo',
    'gather_coo',
rusty1s's avatar
rusty1s committed
99
100
101
102
    'scatter_std',
    'scatter_logsumexp',
    'scatter_softmax',
    'scatter_log_softmax',
103
    'torch_scatter',
rusty1s's avatar
rusty1s committed
104
    '__version__',
rusty1s's avatar
rusty1s committed
105
]