Commit fa27e1fd authored by rusty1s's avatar rusty1s
Browse files

windows fix

parent abb2e1dc
......@@ -103,12 +103,12 @@ install:
- pip3 install numpy
- pip3 install torch==${TORCH} -f https://download.pytorch.org/whl/torch_stable.html
- pip3 install flake8
- python3 setup.py install
- python3 setup.py install || python setup.py install
script:
- flake8 .
- python3 setup.py test
# - python3 setup.py bdist_wheel
- python3 setup.py test || python setup.py test
- python3 setup.py bdist_wheel || python setup.py bdist_wheel
# - ls dist
notifications:
email: false
......@@ -16,7 +16,20 @@ from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_version.so'))
cuda_version = torch.ops.torch_scatter.cuda_version()
print(cuda_version)
if cuda_version != 1 and torch.version.cuda is not None:
if cuda_version < 10000:
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
else:
major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')]
if t_major != major or t_minor != minor:
raise RuntimeError(
'Detected that PyTorch and torch_scatter were compiled with '
'different CUDA versions. PyTorch has CUDA version={}.{} and '
'torch_scatter has CUDA version={}.{}. Please reinstall the '
'torch_scatter that matches your PyTorch install.'.format(
t_major, t_minor, major, minor))
__version__ = '2.0.3'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment