Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
1f6189cd
Commit
1f6189cd
authored
Aug 07, 2018
by
rusty1s
Browse files
version up
parent
4e327acc
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
14 additions
and
107 deletions
+14
-107
MANIFEST.in
MANIFEST.in
+2
-6
build.py
build.py
+0
-40
build.sh
build.sh
+0
-11
setup.py
setup.py
+1
-1
torch_spline_conv/__init__.py
torch_spline_conv/__init__.py
+1
-1
torch_spline_conv/utils/ext.py
torch_spline_conv/utils/ext.py
+10
-0
torch_spline_conv/utils/ffi.py
torch_spline_conv/utils/ffi.py
+0
-48
No files found.
MANIFEST.in
View file @
1f6189cd
include LICENSE
include build.py
include build.sh
recursive-include aten *
recursive-exclude torch_spline_conv/_ext *
recursive-include cpu *
recursive-include cuda *
build.py
deleted
100644 → 0
View file @
4e327acc
import
os.path
as
osp
import
subprocess
import
torch
from
torch.utils.ffi
import
create_extension
files
=
[
'Basis'
,
'Weighting'
]
headers
=
[
'aten/TH/TH{}.h'
.
format
(
f
)
for
f
in
files
]
sources
=
[
'aten/TH/TH{}.c'
.
format
(
f
)
for
f
in
files
]
include_dirs
=
[
'aten/TH'
]
define_macros
=
[]
extra_objects
=
[]
extra_compile_args
=
[
'-std=c99'
]
with_cuda
=
False
if
torch
.
cuda
.
is_available
():
subprocess
.
call
([
'./build.sh'
,
osp
.
dirname
(
torch
.
__file__
)])
headers
+=
[
'aten/THCC/THCC{}.h'
.
format
(
f
)
for
f
in
files
]
sources
+=
[
'aten/THCC/THCC{}.c'
.
format
(
f
)
for
f
in
files
]
include_dirs
+=
[
'aten/THCC'
]
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
extra_objects
+=
[
'torch_spline_conv/_ext/THC.so'
]
with_cuda
=
True
ffi
=
create_extension
(
name
=
'torch_spline_conv._ext.ffi'
,
package
=
True
,
headers
=
headers
,
sources
=
sources
,
include_dirs
=
include_dirs
,
define_macros
=
define_macros
,
extra_objects
=
extra_objects
,
extra_compile_args
=
extra_compile_args
,
with_cuda
=
with_cuda
,
relative_to
=
__file__
)
if
__name__
==
'__main__'
:
ffi
.
build
()
build.sh
deleted
100755 → 0
View file @
4e327acc
#!/bin/sh
echo
"Compiling kernel..."
if
[
-z
"
$1
"
]
;
then
TORCH
=
$(
python
-c
"import os; import torch; print(os.path.dirname(torch.__file__))"
)
;
else
TORCH
=
"
$1
"
;
fi
SRC_DIR
=
aten/THC
BUILD_DIR
=
torch_spline_conv/_ext
mkdir
-p
"
$BUILD_DIR
"
$(
which nvcc
)
"-I
$TORCH
/lib/include"
"-I
$TORCH
/lib/include/TH"
"-I
$TORCH
/lib/include/THC"
"-I
$SRC_DIR
"
-c
"
$SRC_DIR
/THC.cu"
-o
"
$BUILD_DIR
/THC.so"
--compiler-options
'-fPIC'
-std
=
c++11
setup.py
View file @
1f6189cd
...
...
@@ -2,7 +2,7 @@ from os import path as osp
from
setuptools
import
setup
,
find_packages
__version__
=
'1.0.
3
'
__version__
=
'1.0.
4
'
url
=
'https://github.com/rusty1s/pytorch_spline_conv'
install_requires
=
[
'cffi'
]
...
...
torch_spline_conv/__init__.py
View file @
1f6189cd
...
...
@@ -2,6 +2,6 @@ from .basis import SplineBasis
from
.weighting
import
SplineWeighting
from
.conv
import
SplineConv
__version__
=
'1.0.
3
'
__version__
=
'1.0.
4
'
__all__
=
[
'SplineBasis'
,
'SplineWeighting'
,
'SplineConv'
,
'__version__'
]
torch_spline_conv/utils/ext.py
0 → 100644
View file @
1f6189cd
import
torch
import
spline_conv_cpu
if
torch
.
cuda
.
is_available
():
import
spline_conv_cuda
def
get_func
(
name
,
tensor
):
module
=
spline_conv_cuda
if
tensor
.
is_cuda
else
spline_conv_cpu
return
getattr
(
module
,
name
)
torch_spline_conv/utils/ffi.py
deleted
100644 → 0
View file @
4e327acc
from
.._ext
import
ffi
implemented_degrees
=
{
1
:
'linear'
,
2
:
'quadratic'
,
3
:
'cubic'
}
def
get_func
(
name
,
tensor
):
prefix
=
'THCC'
if
tensor
.
is_cuda
else
'TH'
prefix
+=
tensor
.
type
().
split
(
'.'
)[
-
1
]
return
getattr
(
ffi
,
'{}_{}'
.
format
(
prefix
,
name
))
def
get_degree_str
(
degree
):
degree
=
implemented_degrees
.
get
(
degree
)
assert
degree
is
not
None
,
(
'No implementation found for specified B-spline degree'
)
return
degree
def
fw_basis
(
degree
,
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
):
name
=
'{}BasisForward'
.
format
(
get_degree_str
(
degree
))
func
=
get_func
(
name
,
basis
)
func
(
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
)
def
bw_basis
(
degree
,
self
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
):
name
=
'{}BasisBackward'
.
format
(
get_degree_str
(
degree
))
func
=
get_func
(
name
,
self
)
func
(
self
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
)
def
fw_weighting
(
self
,
src
,
weight
,
basis
,
weight_index
):
func
=
get_func
(
'weightingForward'
,
self
)
func
(
self
,
src
,
weight
,
basis
,
weight_index
)
def
bw_weighting_src
(
self
,
grad_out
,
weight
,
basis
,
weight_index
):
func
=
get_func
(
'weightingBackwardSrc'
,
self
)
func
(
self
,
grad_out
,
weight
,
basis
,
weight_index
)
def
bw_weighting_weight
(
self
,
grad_out
,
src
,
basis
,
weight_index
):
func
=
get_func
(
'weightingBackwardWeight'
,
self
)
func
(
self
,
grad_out
,
src
,
basis
,
weight_index
)
def
bw_weighting_basis
(
self
,
grad_out
,
src
,
weight
,
weight_index
):
func
=
get_func
(
'weightingBackwardBasis'
,
self
)
func
(
self
,
grad_out
,
src
,
weight
,
weight_index
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment