Commit 1f6189cd authored by rusty1s's avatar rusty1s
Browse files

version up

parent 4e327acc
include LICENSE
include build.py
include build.sh
recursive-include aten *
recursive-exclude torch_spline_conv/_ext *
recursive-include cpu *
recursive-include cuda *
import os.path as osp
import subprocess
import torch
from torch.utils.ffi import create_extension
files = ['Basis', 'Weighting']
headers = ['aten/TH/TH{}.h'.format(f) for f in files]
sources = ['aten/TH/TH{}.c'.format(f) for f in files]
include_dirs = ['aten/TH']
define_macros = []
extra_objects = []
extra_compile_args = ['-std=c99']
with_cuda = False
if torch.cuda.is_available():
subprocess.call(['./build.sh', osp.dirname(torch.__file__)])
headers += ['aten/THCC/THCC{}.h'.format(f) for f in files]
sources += ['aten/THCC/THCC{}.c'.format(f) for f in files]
include_dirs += ['aten/THCC']
define_macros += [('WITH_CUDA', None)]
extra_objects += ['torch_spline_conv/_ext/THC.so']
with_cuda = True
ffi = create_extension(
name='torch_spline_conv._ext.ffi',
package=True,
headers=headers,
sources=sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_objects=extra_objects,
extra_compile_args=extra_compile_args,
with_cuda=with_cuda,
relative_to=__file__)
if __name__ == '__main__':
ffi.build()
#!/bin/sh
echo "Compiling kernel..."
if [ -z "$1" ]; then TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))"); else TORCH="$1"; fi
SRC_DIR=aten/THC
BUILD_DIR=torch_spline_conv/_ext
mkdir -p "$BUILD_DIR"
$(which nvcc) "-I$TORCH/lib/include" "-I$TORCH/lib/include/TH" "-I$TORCH/lib/include/THC" "-I$SRC_DIR" -c "$SRC_DIR/THC.cu" -o "$BUILD_DIR/THC.so" --compiler-options '-fPIC' -std=c++11
......@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages
__version__ = '1.0.3'
__version__ = '1.0.4'
url = 'https://github.com/rusty1s/pytorch_spline_conv'
install_requires = ['cffi']
......
......@@ -2,6 +2,6 @@ from .basis import SplineBasis
from .weighting import SplineWeighting
from .conv import SplineConv
__version__ = '1.0.3'
__version__ = '1.0.4'
__all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']
import torch
import spline_conv_cpu
if torch.cuda.is_available():
import spline_conv_cuda
def get_func(name, tensor):
module = spline_conv_cuda if tensor.is_cuda else spline_conv_cpu
return getattr(module, name)
from .._ext import ffi
implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
def get_func(name, tensor):
prefix = 'THCC' if tensor.is_cuda else 'TH'
prefix += tensor.type().split('.')[-1]
return getattr(ffi, '{}_{}'.format(prefix, name))
def get_degree_str(degree):
degree = implemented_degrees.get(degree)
assert degree is not None, (
'No implementation found for specified B-spline degree')
return degree
def fw_basis(degree, basis, weight_index, pseudo, kernel_size, is_open_spline):
name = '{}BasisForward'.format(get_degree_str(degree))
func = get_func(name, basis)
func(basis, weight_index, pseudo, kernel_size, is_open_spline)
def bw_basis(degree, self, grad_basis, pseudo, kernel_size, is_open_spline):
name = '{}BasisBackward'.format(get_degree_str(degree))
func = get_func(name, self)
func(self, grad_basis, pseudo, kernel_size, is_open_spline)
def fw_weighting(self, src, weight, basis, weight_index):
func = get_func('weightingForward', self)
func(self, src, weight, basis, weight_index)
def bw_weighting_src(self, grad_out, weight, basis, weight_index):
func = get_func('weightingBackwardSrc', self)
func(self, grad_out, weight, basis, weight_index)
def bw_weighting_weight(self, grad_out, src, basis, weight_index):
func = get_func('weightingBackwardWeight', self)
func(self, grad_out, src, basis, weight_index)
def bw_weighting_basis(self, grad_out, src, weight, weight_index):
func = get_func('weightingBackwardBasis', self)
func(self, grad_out, src, weight, weight_index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment