Commit 19df6430 authored by Ken Leidal

build both cpu and gpu binaries so same package can run on both CPU and GPU machines

parent 981731f0
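The diffs below touch two files (judging by their contents, presumably setup.py and torch_scatter/__init__.py): the build script now compiles every extension in csrc/ twice, once per "supername" (cpu and gpu), and the package picks the matching flavor when it is imported. As a rough illustration only, not part of the commit, the resulting extension-module names would look like this (the op list is taken from the __init__ diff, the naming pattern from the setup.py diff):

# Illustration only: module names produced by this commit's naming scheme.
# Ops come from the __init__ diff, the '_%s_%s' pattern from the setup.py diff.
ops = ['_version', '_scatter', '_segment_csr', '_segment_coo']
for supername in ('cpu', 'gpu'):
    for op in ops:
        print('torch_scatter.%s_%s' % (op, supername))  # e.g. torch_scatter._scatter_gpu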
@@ -7,7 +7,7 @@ import torch
 from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
 
-WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
+WITH_CUDA = CUDA_HOME is not None
 if os.getenv('FORCE_CUDA', '0') == '1':
     WITH_CUDA = True
 if os.getenv('FORCE_CPU', '0') == '1':
@@ -17,11 +17,18 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
 
 
 def get_extensions():
+    extensions = []
+    for with_cuda, supername in [
+        (False, "cpu"),
+        (True, "gpu"),
+    ]:
+        if with_cuda and not WITH_CUDA:
+            continue
         Extension = CppExtension
         define_macros = []
         extra_compile_args = {'cxx': []}
 
-        if WITH_CUDA:
+        if with_cuda:
            Extension = CUDAExtension
            define_macros += [('WITH_CUDA', None)]
            nvcc_flags = os.getenv('NVCC_FLAGS', '')
@@ -31,7 +38,6 @@ def get_extensions():
 
         extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
         main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
 
-        extensions = []
         for main in main_files:
             name = main.split(os.sep)[-1][:-4]
@@ -42,11 +48,11 @@ def get_extensions():
                 sources += [path]
 
             path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
-            if WITH_CUDA and osp.exists(path):
+            if with_cuda and osp.exists(path):
                 sources += [path]
 
             extension = Extension(
-                'torch_scatter._' + name,
+                'torch_scatter._%s_%s' % (name, supername),
                 sources,
                 include_dirs=[extensions_dir],
                 define_macros=define_macros,
......
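Since the setup.py changes are spread over several hunks, here is a condensed sketch of the new get_extensions() flow under stated assumptions: it only reproduces the flavor/naming logic shown above, omits the sources, macros, and compile-flag handling of the real script, and get_extension_variants is a hypothetical helper name, not part of the commit.

import glob
import os.path as osp

from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

WITH_CUDA = CUDA_HOME is not None  # GPU flavor is built only when a CUDA toolkit is present


def get_extension_variants(extensions_dir='csrc'):
    # Sketch of the loop added above: every *.cpp in csrc/ yields a CPU build
    # and, when possible, a GPU build under a distinct module name.
    variants = []
    for with_cuda, supername in [(False, 'cpu'), (True, 'gpu')]:
        if with_cuda and not WITH_CUDA:
            continue  # no CUDA toolkit: ship only the CPU flavor
        Extension = CUDAExtension if with_cuda else CppExtension
        for main in glob.glob(osp.join(extensions_dir, '*.cpp')):
            name = osp.basename(main)[:-4]  # strip the .cpp suffix
            variants.append((Extension, 'torch_scatter._%s_%s' % (name, supername)))
    return variants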
@@ -6,8 +6,14 @@ import torch
 __version__ = '2.0.5'
 
+if torch.cuda.is_available():
+    sublib = "gpu"
+else:
+    sublib = "cpu"
+
 try:
     for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
+        library = "%s_%s" % (library, sublib)
         torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
             library, [osp.dirname(__file__)]).origin)
 except AttributeError as e:
......
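And the import side, as a standalone sketch of the selection logic added above (presumably torch_scatter/__init__.py). Note that when a flavor was not built, PathFinder().find_spec() returns None, so the .origin access raises AttributeError, which the existing except AttributeError handler shown above would catch.

import importlib.machinery
import os.path as osp

import torch

# Pick the flavor matching the current machine; the wheel ships both sets of binaries.
sublib = 'gpu' if torch.cuda.is_available() else 'cpu'

for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
    spec = importlib.machinery.PathFinder().find_spec(
        '%s_%s' % (library, sublib), [osp.dirname(__file__)])
    # spec is None if this flavor's binary is missing; spec.origin then raises AttributeError.
    torch.ops.load_library(spec.origin)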