".github/git@developer.sourcefind.cn:OpenDAS/nerfacc.git" did not exist on "143982293967696e5857679763ed3278f9225538"
Commit 19df6430 authored by Ken Leidal's avatar Ken Leidal
Browse files

build both cpu and gpu binaries so same package can run on both CPU and GPU machines

parent 981731f0
......@@ -7,7 +7,7 @@ import torch
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
WITH_CUDA = CUDA_HOME is not None
if os.getenv('FORCE_CUDA', '0') == '1':
WITH_CUDA = True
if os.getenv('FORCE_CPU', '0') == '1':
......@@ -17,42 +17,48 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions():
Extension = CppExtension
define_macros = []
extra_compile_args = {'cxx': []}
extensions = []
for with_cuda, supername in [
(False, "cpu"),
(True, "gpu"),
]:
if with_cuda and not WITH_CUDA:
continue
Extension = CppExtension
define_macros = []
extra_compile_args = {'cxx': []}
if WITH_CUDA:
Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
extra_compile_args['nvcc'] = nvcc_flags
if with_cuda:
Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
extra_compile_args['nvcc'] = nvcc_flags
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
extensions = []
for main in main_files:
name = main.split(os.sep)[-1][:-4]
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
for main in main_files:
name = main.split(os.sep)[-1][:-4]
sources = [main]
sources = [main]
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
if osp.exists(path):
sources += [path]
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
if osp.exists(path):
sources += [path]
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
if WITH_CUDA and osp.exists(path):
sources += [path]
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
if with_cuda and osp.exists(path):
sources += [path]
extension = Extension(
'torch_scatter._' + name,
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
extensions += [extension]
extension = Extension(
'torch_scatter._%s_%s' % (name, supername),
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
extensions += [extension]
return extensions
......
......@@ -6,8 +6,14 @@ import torch
__version__ = '2.0.5'
if torch.cuda.is_available():
sublib = "gpu"
else:
sublib = "cpu"
try:
for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
library = "%s_%s" % (library, sublib)
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
except AttributeError as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment