Commit 04ae443a authored by rusty1s's avatar rusty1s
Browse files

year up, restricted coverage, nested extensions

parent cc0a7284
[run]
source=torch_spline_conv
[report]
exclude_lines =
pragma: no cover
......
Copyright (c) 2018 Matthias Fey <matthias.fey@tu-dortmund.de>
Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
......
#include <torch/torch.h>
#include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
......
#include <torch/torch.h>
#include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
......
......@@ -3,16 +3,16 @@ import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
ext_modules = [
CppExtension('basis_cpu', ['cpu/basis.cpp']),
CppExtension('weighting_cpu', ['cpu/weighting.cpp']),
CppExtension('torch_spline_conv.basis_cpu', ['cpu/basis.cpp']),
CppExtension('torch_spline_conv.weighting_cpu', ['cpu/weighting.cpp']),
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if CUDA_HOME is not None:
ext_modules += [
CUDAExtension('basis_cuda',
CUDAExtension('torch_spline_conv.basis_cuda',
['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
CUDAExtension('weighting_cuda',
CUDAExtension('torch_spline_conv.weighting_cuda',
['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
]
......@@ -26,8 +26,8 @@ tests_require = ['pytest', 'pytest-cov']
setup(
name='torch_spline_conv',
version=__version__,
description='Implementation of the Spline-Based Convolution'
'Operator of SplineCNN in PyTorch',
description=('Implementation of the Spline-Based Convolution Operator of'
'SplineCNN in PyTorch'),
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url=url,
......
import torch
import basis_cpu
import torch_spline_conv.basis_cpu
if torch.cuda.is_available():
import basis_cuda
import torch_spline_conv.basis_cuda
implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
def get_func(name, tensor):
module = basis_cuda if tensor.is_cuda else basis_cpu
return getattr(module, name)
if tensor.is_cuda:
return getattr(torch_spline_conv.basis_cuda, name)
else:
return getattr(torch_spline_conv.basis_cpu, name)
class SplineBasis(torch.autograd.Function):
......
import torch
import weighting_cpu
import torch_spline_conv.weighting_cpu
if torch.cuda.is_available():
import weighting_cuda
import torch_spline_conv.weighting_cuda
def get_func(name, tensor):
module = weighting_cuda if tensor.is_cuda else weighting_cpu
return getattr(module, name)
if tensor.is_cuda:
return getattr(torch_spline_conv.weighting_cuda, name)
else:
return getattr(torch_spline_conv.weighting_cpu, name)
class SplineWeighting(torch.autograd.Function):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment