Commit d305ecc0 authored by rusty1s's avatar rusty1s
Browse files

nested extensions

parent 15afee0d
...@@ -4,14 +4,14 @@ from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME ...@@ -4,14 +4,14 @@ from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
ext_modules = [ ext_modules = [
CppExtension( CppExtension(
'scatter_cpu', ['cpu/scatter.cpp'], 'torch_scatter.scatter_cpu', ['cpu/scatter.cpp'],
extra_compile_args=['-Wno-unused-variable']) extra_compile_args=['-Wno-unused-variable'])
] ]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension} cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if CUDA_HOME is not None: if CUDA_HOME is not None:
ext_modules += [ ext_modules += [
CUDAExtension('scatter_cuda', CUDAExtension('torch_scatter.scatter_cuda',
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu']) ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
] ]
......
import torch import torch
import scatter_cpu import torch_scatter.scatter_cpu
if torch.cuda.is_available(): if torch.cuda.is_available():
import scatter_cuda import torch_scatter.scatter_cuda
def get_func(name, tensor): def get_func(name, tensor):
module = scatter_cuda if tensor.is_cuda else scatter_cpu if tensor.is_cuda:
module = torch_scatter.scatter_cuda
else:
module = torch_scatter.scatter_cpu
return getattr(module, name) return getattr(module, name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment