Commit eb9f24b5 authored by rusty1s

build separate cpu and cuda images

parent f2790b22
setup.py

 import os
-import os.path as osp
 import glob
+import os.path as osp
+from itertools import product
 from setuptools import setup, find_packages

 import torch
@@ -8,34 +9,34 @@ from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

 WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
+suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
 if os.getenv('FORCE_CUDA', '0') == '1':
-    WITH_CUDA = True
+    suffices = ['cuda']
 if os.getenv('FORCE_CPU', '0') == '1':
-    WITH_CUDA = False
+    suffices = ['cpu']

 BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'


 def get_extensions():
-    Extension = CppExtension
+    extensions = []
+    extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
+    main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+    for main, suffix in product(main_files, suffices):
         define_macros = []
         extra_compile_args = {'cxx': ['-O2']}
         extra_link_args = ['-s']

-    if WITH_CUDA:
-        Extension = CUDAExtension
+        if suffix == 'cuda':
             define_macros += [('WITH_CUDA', None)]
             nvcc_flags = os.getenv('NVCC_FLAGS', '')
             nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
             nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr', '-O2']
             extra_compile_args['nvcc'] = nvcc_flags

-    extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
-    main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
-    extensions = []
-    for main in main_files:
         name = main.split(os.sep)[-1][:-4]
         sources = [main]

         path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
@@ -43,11 +44,12 @@ def get_extensions():
             sources += [path]

         path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
-        if WITH_CUDA and osp.exists(path):
+        if suffix == 'cuda' and osp.exists(path):
             sources += [path]

+        Extension = CppExtension if suffix == 'cpu' else CUDAExtension
         extension = Extension(
-            'torch_spline_conv._' + name,
+            f'torch_spline_conv._{name}_{suffix}',
             sources,
             include_dirs=[extensions_dir],
             define_macros=define_macros,
@@ -84,8 +86,7 @@ setup(
     tests_require=tests_require,
     ext_modules=get_extensions() if not BUILD_DOCS else [],
     cmdclass={
-        'build_ext':
-        BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
+        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
     },
     packages=find_packages(),
 )
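For illustration only, a minimal sketch of what the new product() loop in get_extensions() generates: every csrc main file is paired with every active suffix, so one extension module is built per (file, suffix) pair. The file names below are assumptions inferred from the libraries loaded in __init__.py ('_version', '_basis', '_weighting'), not taken from the commit.

import os
import os.path as osp
from itertools import product

# Hypothetical stand-ins for the glob over csrc/*.cpp in setup.py.
main_files = [osp.join('csrc', f'{n}.cpp') for n in ['basis', 'weighting', 'version']]
suffices = ['cpu', 'cuda']  # the CUDA-enabled case

for main, suffix in product(main_files, suffices):
    name = main.split(os.sep)[-1][:-4]            # e.g. 'basis'
    print(f'torch_spline_conv._{name}_{suffix}')  # e.g. torch_spline_conv._basis_cuda

Under the old scheme the same sources produced a single torch_spline_conv._basis module whose contents depended on whether CUDA was available at build time; the suffixed names let one wheel ship both variants side by side.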
torch_spline_conv/__init__.py

@@ -5,9 +5,11 @@ import torch
 __version__ = '1.2.0'
+suffix = 'cuda' if torch.cuda.is_available() else 'cpu'

 for library in ['_version', '_basis', '_weighting']:
     torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
-        library, [osp.dirname(__file__)]).origin)
+        f'{library}_{suffix}', [osp.dirname(__file__)]).origin)

 if torch.version.cuda is not None:  # pragma: no cover
     cuda_version = torch.ops.torch_spline_conv.cuda_version()
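And a small sketch, likewise an assumption rather than part of the commit, of what the suffixed loader in __init__.py resolves at import time on a Linux install where the built artifacts are shared objects such as _basis_cpu.so / _basis_cuda.so:

import importlib.machinery
import os.path as osp

import torch
import torch_spline_conv  # assumed installed; its __init__ already runs this loop itself

suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
pkg_dir = osp.dirname(torch_spline_conv.__file__)

# Resolve the suffixed extension next to the package and register its operators.
spec = importlib.machinery.PathFinder().find_spec(f'_basis_{suffix}', [pkg_dir])
print(spec.origin)                   # e.g. .../torch_spline_conv/_basis_cuda.so
torch.ops.load_library(spec.origin)  # ops appear under torch.ops.torch_spline_conv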