Commit a8ec223f authored by rusty1s's avatar rusty1s
Browse files

clean up

parent 1488dee0
......@@ -5,6 +5,7 @@ import os.path as osp
import torch
__version__ = '2.0.3'
expected_torch_version = '1.4'
try:
......@@ -17,8 +18,6 @@ except OSError as e:
'{}.{}.'.format(torch_version, major, minor))
raise OSError(e)
cuda_version = torch.ops.torch_scatter.cuda_version()
from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
scatter_max, scatter)
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
......@@ -30,6 +29,7 @@ from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
scatter_log_softmax)
cuda_version = torch.ops.torch_scatter.cuda_version()
if cuda_version != -1 and torch.version.cuda is not None: # pragma: no cover
if cuda_version < 10000:
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
......@@ -45,10 +45,6 @@ if cuda_version != -1 and torch.version.cuda is not None: # pragma: no cover
'torch_scatter has CUDA version={}.{}. Please reinstall the '
'torch_scatter that matches your PyTorch install.'.format(
t_major, t_minor, major, minor))
else:
cuda_version = None
__version__ = '2.0.3'
__all__ = [
'scatter_sum',
......@@ -76,6 +72,5 @@ __all__ = [
'scatter_softmax',
'scatter_log_softmax',
'torch_scatter',
'cuda_version',
'__version__',
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment