Commit 04ae443a authored by rusty1s

year up, restricted coverage, nested extensions

parent cc0a7284
+[run]
+source=torch_spline_conv
 [report]
 exclude_lines =
     pragma: no cover
...
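The new [run] block restricts coverage measurement to the torch_spline_conv package itself, so only files inside the package count toward the report; with pytest-cov (listed in tests_require below) the same effect comes from passing --cov=torch_spline_conv. A minimal sketch of the equivalent restriction expressed through the coverage.py API, assuming the package is importable:

# Minimal sketch: measure only files belonging to torch_spline_conv,
# mirroring the `source=torch_spline_conv` setting above.
import coverage

cov = coverage.Coverage(source=['torch_spline_conv'])
cov.start()

import torch_spline_conv  # exercise the package (or run the test suite) here

cov.stop()
cov.report()  # reports statement coverage for package files only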
-Copyright (c) 2018 Matthias Fey <matthias.fey@tu-dortmund.de>
+Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
...
-#include <torch/torch.h>
+#include <torch/extension.h>
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
...
-#include <torch/torch.h>
+#include <torch/extension.h>
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
...
@@ -3,16 +3,16 @@ import torch
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
 ext_modules = [
-    CppExtension('basis_cpu', ['cpu/basis.cpp']),
-    CppExtension('weighting_cpu', ['cpu/weighting.cpp']),
+    CppExtension('torch_spline_conv.basis_cpu', ['cpu/basis.cpp']),
+    CppExtension('torch_spline_conv.weighting_cpu', ['cpu/weighting.cpp']),
 ]
 cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
 if CUDA_HOME is not None:
     ext_modules += [
-        CUDAExtension('basis_cuda',
-                      ['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
-        CUDAExtension('weighting_cuda',
-                      ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
+        CUDAExtension('torch_spline_conv.basis_cuda',
+                      ['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
+        CUDAExtension('torch_spline_conv.weighting_cuda',
+                      ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
     ]
@@ -26,8 +26,8 @@ tests_require = ['pytest', 'pytest-cov']
 setup(
     name='torch_spline_conv',
     version=__version__,
-    description='Implementation of the Spline-Based Convolution'
-                'Operator of SplineCNN in PyTorch',
+    description=('Implementation of the Spline-Based Convolution Operator of'
+                 'SplineCNN in PyTorch'),
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url=url,
...
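Prefixing each extension name with torch_spline_conv makes setuptools build the compiled modules as submodules of the package rather than as top-level modules sitting directly in site-packages. A short sketch of the import paths this changes, matching the imports in the Python sources below:

# Sketch: import paths before and after nesting the extensions.
# Before this commit the compiled modules were top-level:
#     import basis_cpu
#     import weighting_cpu
# After it, they live inside the package:
import torch

import torch_spline_conv.basis_cpu
import torch_spline_conv.weighting_cpu

# The CUDA variants are only built when CUDA_HOME is set, so they are
# imported behind the same guard the package itself uses:
if torch.cuda.is_available():
    import torch_spline_conv.basis_cuda
    import torch_spline_conv.weighting_cuda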
 import torch
-import basis_cpu
+import torch_spline_conv.basis_cpu
 if torch.cuda.is_available():
-    import basis_cuda
+    import torch_spline_conv.basis_cuda
 implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
 def get_func(name, tensor):
-    module = basis_cuda if tensor.is_cuda else basis_cpu
-    return getattr(module, name)
+    if tensor.is_cuda:
+        return getattr(torch_spline_conv.basis_cuda, name)
+    else:
+        return getattr(torch_spline_conv.basis_cpu, name)
 class SplineBasis(torch.autograd.Function):
...
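get_func now resolves kernels by name on the fully qualified extension modules and dispatches on the device of the tensor it receives, so the surrounding autograd code stays identical for CPU and CUDA inputs. A self-contained sketch of that dispatch pattern, using plain Python namespaces as stand-ins for the compiled extensions (the 'scale' op is purely illustrative, not a symbol the extensions export):

# Dispatch-by-device sketch; cpu_ops/cuda_ops stand in for the compiled
# basis_cpu/basis_cuda extension modules.
import types

import torch

cpu_ops = types.SimpleNamespace(scale=lambda x: x * 2)
cuda_ops = types.SimpleNamespace(scale=lambda x: x * 2)

def get_func(name, tensor):
    # Pick the implementation matching the tensor's device.
    if tensor.is_cuda:
        return getattr(cuda_ops, name)
    else:
        return getattr(cpu_ops, name)

x = torch.rand(4)              # CPU tensor
out = get_func('scale', x)(x)  # resolves to cpu_ops.scale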
 import torch
-import weighting_cpu
+import torch_spline_conv.weighting_cpu
 if torch.cuda.is_available():
-    import weighting_cuda
+    import torch_spline_conv.weighting_cuda
 def get_func(name, tensor):
-    module = weighting_cuda if tensor.is_cuda else weighting_cpu
-    return getattr(module, name)
+    if tensor.is_cuda:
+        return getattr(torch_spline_conv.weighting_cuda, name)
+    else:
+        return getattr(torch_spline_conv.weighting_cpu, name)
 class SplineWeighting(torch.autograd.Function):
...