"src/array/vscode:/vscode.git/clone" did not exist on "75ec58260aae12b963bbfa2f3eba49bd0eb964cf"
Commit 18411366 authored by rusty1s's avatar rusty1s
Browse files

fix cuda context

parent 67d63167
...@@ -18,7 +18,7 @@ for library in ['_version', '_basis', '_weighting']: ...@@ -18,7 +18,7 @@ for library in ['_version', '_basis', '_weighting']:
f"{osp.dirname(__file__)}") f"{osp.dirname(__file__)}")
cuda_version = torch.ops.torch_spline_conv.cuda_version() cuda_version = torch.ops.torch_spline_conv.cuda_version()
if torch.cuda.is_available() and cuda_version != -1: # pragma: no cover if torch.version.cuda is not None and cuda_version != -1: # pragma: no cover
if cuda_version < 10000: if cuda_version < 10000:
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2]) major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
else: else:
...@@ -34,8 +34,8 @@ if torch.cuda.is_available() and cuda_version != -1: # pragma: no cover ...@@ -34,8 +34,8 @@ if torch.cuda.is_available() and cuda_version != -1: # pragma: no cover
f'matches your PyTorch install.') f'matches your PyTorch install.')
from .basis import spline_basis # noqa from .basis import spline_basis # noqa
from .weighting import spline_weighting # noqa
from .conv import spline_conv # noqa from .conv import spline_conv # noqa
from .weighting import spline_weighting # noqa
__all__ = [ __all__ = [
'spline_basis', 'spline_basis',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment