Commit eb9f24b5 authored by rusty1s
Browse files

build separate cpu and cuda images

parent f2790b22
import os import os
import os.path as osp
import glob import glob
import os.path as osp
from itertools import product
from setuptools import setup, find_packages from setuptools import setup, find_packages
import torch import torch
...@@ -8,34 +9,34 @@ from torch.utils.cpp_extension import BuildExtension ...@@ -8,34 +9,34 @@ from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
if os.getenv('FORCE_CUDA', '0') == '1': if os.getenv('FORCE_CUDA', '0') == '1':
WITH_CUDA = True suffices = ['cuda']
if os.getenv('FORCE_CPU', '0') == '1': if os.getenv('FORCE_CPU', '0') == '1':
WITH_CUDA = False suffices = ['cpu']
BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions(): def get_extensions():
Extension = CppExtension extensions = []
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
for main, suffix in product(main_files, suffices):
define_macros = [] define_macros = []
extra_compile_args = {'cxx': ['-O2']} extra_compile_args = {'cxx': ['-O2']}
extra_link_args = ['-s'] extra_link_args = ['-s']
if WITH_CUDA: if suffix == 'cuda':
Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr', '-O2'] nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr', '-O2']
extra_compile_args['nvcc'] = nvcc_flags extra_compile_args['nvcc'] = nvcc_flags
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
extensions = []
for main in main_files:
name = main.split(os.sep)[-1][:-4] name = main.split(os.sep)[-1][:-4]
sources = [main] sources = [main]
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp') path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
...@@ -43,11 +44,12 @@ def get_extensions(): ...@@ -43,11 +44,12 @@ def get_extensions():
sources += [path] sources += [path]
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu') path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
if WITH_CUDA and osp.exists(path): if suffix == 'cuda' and osp.exists(path):
sources += [path] sources += [path]
Extension = CppExtension if suffix == 'cpu' else CUDAExtension
extension = Extension( extension = Extension(
'torch_spline_conv._' + name, f'torch_spline_conv._{name}_{suffix}',
sources, sources,
include_dirs=[extensions_dir], include_dirs=[extensions_dir],
define_macros=define_macros, define_macros=define_macros,
...@@ -84,8 +86,7 @@ setup( ...@@ -84,8 +86,7 @@ setup(
tests_require=tests_require, tests_require=tests_require,
ext_modules=get_extensions() if not BUILD_DOCS else [], ext_modules=get_extensions() if not BUILD_DOCS else [],
cmdclass={ cmdclass={
'build_ext': 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
}, },
packages=find_packages(), packages=find_packages(),
) )
...@@ -5,9 +5,11 @@ import torch ...@@ -5,9 +5,11 @@ import torch
__version__ = '1.2.0' __version__ = '1.2.0'
suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
for library in ['_version', '_basis', '_weighting']: for library in ['_version', '_basis', '_weighting']:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin) f'{library}_{suffix}', [osp.dirname(__file__)]).origin)
if torch.version.cuda is not None: # pragma: no cover if torch.version.cuda is not None: # pragma: no cover
cuda_version = torch.ops.torch_spline_conv.cuda_version() cuda_version = torch.ops.torch_spline_conv.cuda_version()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment