Commit 19df6430 authored by Ken Leidal's avatar Ken Leidal
Browse files

build both cpu and gpu binaries so same package can run on both CPU and GPU machines

parent 981731f0
......@@ -7,7 +7,7 @@ import torch
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
WITH_CUDA = CUDA_HOME is not None
if os.getenv('FORCE_CUDA', '0') == '1':
WITH_CUDA = True
if os.getenv('FORCE_CPU', '0') == '1':
......@@ -17,11 +17,18 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions():
extensions = []
for with_cuda, supername in [
(False, "cpu"),
(True, "gpu"),
]:
if with_cuda and not WITH_CUDA:
continue
Extension = CppExtension
define_macros = []
extra_compile_args = {'cxx': []}
if WITH_CUDA:
if with_cuda:
Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
......@@ -31,7 +38,6 @@ def get_extensions():
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
extensions = []
for main in main_files:
name = main.split(os.sep)[-1][:-4]
......@@ -42,11 +48,11 @@ def get_extensions():
sources += [path]
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
if WITH_CUDA and osp.exists(path):
if with_cuda and osp.exists(path):
sources += [path]
extension = Extension(
'torch_scatter._' + name,
'torch_scatter._%s_%s' % (name, supername),
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
......
......@@ -6,8 +6,14 @@ import torch
__version__ = '2.0.5'
if torch.cuda.is_available():
sublib = "gpu"
else:
sublib = "cpu"
try:
for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
library = "%s_%s" % (library, sublib)
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
except AttributeError as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment