"tools/python/vscode:/vscode.git/clone" did not exist on "89ffae61263e2110116282737a221a3cc33aa39f"
Commit d305ecc0 authored by rusty1s's avatar rusty1s
Browse files

nested extensions

parent 15afee0d
......@@ -4,14 +4,14 @@ from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
ext_modules = [
CppExtension(
'scatter_cpu', ['cpu/scatter.cpp'],
'torch_scatter.scatter_cpu', ['cpu/scatter.cpp'],
extra_compile_args=['-Wno-unused-variable'])
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if CUDA_HOME is not None:
ext_modules += [
CUDAExtension('scatter_cuda',
CUDAExtension('torch_scatter.scatter_cuda',
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
]
......
import torch
import scatter_cpu
import torch_scatter.scatter_cpu
if torch.cuda.is_available():
import scatter_cuda
import torch_scatter.scatter_cuda
def get_func(name, tensor):
module = scatter_cuda if tensor.is_cuda else scatter_cpu
if tensor.is_cuda:
module = torch_scatter.scatter_cuda
else:
module = torch_scatter.scatter_cpu
return getattr(module, name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment