OpenDAS / torch-spline-conv

Commit 04ae443a, authored Jan 24, 2019 by rusty1s
Parent: cc0a7284

    year up, restricted coverage, nested extensions
Showing 7 changed files with 23 additions and 17 deletions (+23 / -17):

    .coveragerc                       +2  -0
    LICENSE                           +1  -1
    cuda/basis.cpp                    +1  -1
    cuda/weighting.cpp                +1  -1
    setup.py                          +6  -6
    torch_spline_conv/basis.py        +6  -4
    torch_spline_conv/weighting.py    +6  -4
.coveragerc (view file @ 04ae443a)

    [run]
    source=torch_spline_conv

    [report]
    exclude_lines =
        pragma: no cover
    ...
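The "restricted coverage" part of the commit message refers to this file: `source=torch_spline_conv` limits coverage measurement to the package itself, and `exclude_lines` drops explicitly marked lines from the report. A minimal sketch of how that marker is used in practice (the function below is hypothetical, not from the repository):

```python
def _debug_dump(tensor):  # pragma: no cover
    # Placing the marker on the def line excludes the whole function body
    # from the coverage report, per the exclude_lines entry in .coveragerc.
    print(tensor.shape, tensor.dtype)
```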
LICENSE (view file @ 04ae443a)

    -Copyright (c) 2018 Matthias Fey <matthias.fey@tu-dortmund.de>
    +Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>

     Permission is hereby granted, free of charge, to any person obtaining a copy
     of this software and associated documentation files (the "Software"), to deal
    ...
cuda/basis.cpp (view file @ 04ae443a)

    -#include <torch/torch.h>
    +#include <torch/extension.h>

     #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
    ...
cuda/weighting.cpp (view file @ 04ae443a)

    -#include <torch/torch.h>
    +#include <torch/extension.h>

     #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
    ...
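Both CUDA bindings switch from `<torch/torch.h>` to `<torch/extension.h>`, the header PyTorch's C++ extension docs recommend since it pulls in both ATen and the pybind11 binding machinery. As a hedged illustration only (not part of this commit), the same sources could also be JIT-compiled from Python with `torch.utils.cpp_extension.load`; setup.py below does the equivalent ahead of time:

```python
# Sketch: JIT-compile the CUDA binding from a local checkout, assuming a
# working CUDA toolchain is available on the machine.
from torch.utils.cpp_extension import load

basis_cuda = load(
    name='basis_cuda',
    sources=['cuda/basis.cpp', 'cuda/basis_kernel.cu'],
    verbose=True,
)
```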
setup.py (view file @ 04ae443a)

    @@ -3,16 +3,16 @@ import torch
     from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

     ext_modules = [
    -    CppExtension('basis_cpu', ['cpu/basis.cpp']),
    +    CppExtension('torch_spline_conv.basis_cpu', ['cpu/basis.cpp']),
    -    CppExtension('weighting_cpu', ['cpu/weighting.cpp']),
    +    CppExtension('torch_spline_conv.weighting_cpu', ['cpu/weighting.cpp']),
     ]
     cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}

     if CUDA_HOME is not None:
         ext_modules += [
    -        CUDAExtension('basis_cuda', ['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
    +        CUDAExtension('torch_spline_conv.basis_cuda', ['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
    -        CUDAExtension('weighting_cuda', ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
    +        CUDAExtension('torch_spline_conv.weighting_cuda', ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
         ]
    ...

    @@ -26,8 +26,8 @@ tests_require = ['pytest', 'pytest-cov']
     setup(
         name='torch_spline_conv',
         version=__version__,
    -    description='Implementation of the Spline-Based Convolution'
    -    'Operator of SplineCNN in PyTorch',
    +    description=('Implementation of the Spline-Based Convolution Operator of '
    +                 'SplineCNN in PyTorch'),
         author='Matthias Fey',
         author_email='matthias.fey@tu-dortmund.de',
         url=url,
    ...
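This is the "nested extensions" part of the commit: prefixing each extension name with `torch_spline_conv.` makes setuptools install the compiled modules inside the package directory rather than at the top level of site-packages, so they no longer occupy global module names like `basis_cpu`. A hedged way to confirm where the artifact lands after installation (assuming the package has been built and installed):

```python
# Sketch: locate the compiled extension after `pip install .`. The exact file
# name depends on the interpreter's extension suffix (EXT_SUFFIX).
import importlib.util

spec = importlib.util.find_spec('torch_spline_conv.basis_cpu')
# e.g. .../site-packages/torch_spline_conv/basis_cpu.cpython-37m-x86_64-linux-gnu.so
print(spec.origin)
```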
torch_spline_conv/basis.py (view file @ 04ae443a)

     import torch

    -import basis_cpu
    +import torch_spline_conv.basis_cpu

     if torch.cuda.is_available():
    -    import basis_cuda
    +    import torch_spline_conv.basis_cuda

     implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}


     def get_func(name, tensor):
    -    module = basis_cuda if tensor.is_cuda else basis_cpu
    -    return getattr(module, name)
    +    if tensor.is_cuda:
    +        return getattr(torch_spline_conv.basis_cuda, name)
    +    else:
    +        return getattr(torch_spline_conv.basis_cpu, name)


     class SplineBasis(torch.autograd.Function):
    ...
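The rewritten `get_func` keeps the same behaviour, picking the CUDA or CPU binding based on where the input tensor lives, but spells the two branches out against the nested module paths. A self-contained sketch of that dispatch pattern, using stand-in backends so it runs without the compiled extensions (none of the names below come from the repository):

```python
# Sketch of the dispatch pattern only; _cpu_backend / _cuda_backend stand in
# for torch_spline_conv.basis_cpu / torch_spline_conv.basis_cuda.
import torch


class _cpu_backend:
    @staticmethod
    def forward(x):
        return x * 2


class _cuda_backend:
    @staticmethod
    def forward(x):
        return x * 2


def get_func(name, tensor):
    if tensor.is_cuda:
        return getattr(_cuda_backend, name)
    else:
        return getattr(_cpu_backend, name)


x = torch.rand(4, 3)
print(get_func('forward', x)(x).shape)  # dispatches to the CPU stand-in
```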
torch_spline_conv/weighting.py (view file @ 04ae443a)

     import torch

    -import weighting_cpu
    +import torch_spline_conv.weighting_cpu

     if torch.cuda.is_available():
    -    import weighting_cuda
    +    import torch_spline_conv.weighting_cuda


     def get_func(name, tensor):
    -    module = weighting_cuda if tensor.is_cuda else weighting_cpu
    -    return getattr(module, name)
    +    if tensor.is_cuda:
    +        return getattr(torch_spline_conv.weighting_cuda, name)
    +    else:
    +        return getattr(torch_spline_conv.weighting_cpu, name)


     class SplineWeighting(torch.autograd.Function):
    ...
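Both touched Python files end in a `torch.autograd.Function` subclass (`SplineBasis`, `SplineWeighting`) whose bodies are collapsed in this view. As a hedged reminder of the pattern such classes follow, and explicitly not the repository's actual implementation, a minimal custom Function looks like:

```python
# Minimal custom autograd Function, shown only to illustrate the pattern that
# SplineBasis / SplineWeighting follow; this is NOT their implementation.
import torch


class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # Gradient w.r.t. x; the plain-Python factor argument gets no gradient.
        return grad_out * ctx.factor, None


x = torch.rand(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # a tensor of twos
```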