"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "1c976222d54162a03fa75be8a35be3d644b87bae"
Commit 19df6430 authored by Ken Leidal's avatar Ken Leidal
Browse files

build both cpu and gpu binaries so same package can run on both CPU and GPU machines

parent 981731f0
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None WITH_CUDA = CUDA_HOME is not None
if os.getenv('FORCE_CUDA', '0') == '1': if os.getenv('FORCE_CUDA', '0') == '1':
WITH_CUDA = True WITH_CUDA = True
if os.getenv('FORCE_CPU', '0') == '1': if os.getenv('FORCE_CPU', '0') == '1':
...@@ -17,42 +17,48 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' ...@@ -17,42 +17,48 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions(): def get_extensions():
Extension = CppExtension extensions = []
define_macros = [] for with_cuda, supername in [
extra_compile_args = {'cxx': []} (False, "cpu"),
(True, "gpu"),
]:
if with_cuda and not WITH_CUDA:
continue
Extension = CppExtension
define_macros = []
extra_compile_args = {'cxx': []}
if WITH_CUDA: if with_cuda:
Extension = CUDAExtension Extension = CUDAExtension
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '') nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ') nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr'] nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']
extra_compile_args['nvcc'] = nvcc_flags extra_compile_args['nvcc'] = nvcc_flags
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc') extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp')) main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
extensions = [] for main in main_files:
for main in main_files: name = main.split(os.sep)[-1][:-4]
name = main.split(os.sep)[-1][:-4]
sources = [main] sources = [main]
path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp') path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
if osp.exists(path): if osp.exists(path):
sources += [path] sources += [path]
path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu') path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
if WITH_CUDA and osp.exists(path): if with_cuda and osp.exists(path):
sources += [path] sources += [path]
extension = Extension( extension = Extension(
'torch_scatter._' + name, 'torch_scatter._%s_%s' % (name, supername),
sources, sources,
include_dirs=[extensions_dir], include_dirs=[extensions_dir],
define_macros=define_macros, define_macros=define_macros,
extra_compile_args=extra_compile_args, extra_compile_args=extra_compile_args,
) )
extensions += [extension] extensions += [extension]
return extensions return extensions
......
...@@ -6,8 +6,14 @@ import torch ...@@ -6,8 +6,14 @@ import torch
__version__ = '2.0.5' __version__ = '2.0.5'
if torch.cuda.is_available():
sublib = "gpu"
else:
sublib = "cpu"
try: try:
for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']: for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
library = "%s_%s" % (library, sublib)
torch.ops.load_library(importlib.machinery.PathFinder().find_spec( torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin) library, [osp.dirname(__file__)]).origin)
except AttributeError as e: except AttributeError as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment