Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
69d73030
"vscode:/vscode.git/clone" did not exist on "1ad65879a1af25941abc13a43566ebdf92073e6c"
Commit
69d73030
authored
Mar 12, 2018
by
rusty1s
Browse files
basis backward boilerplate
parent
d7a83c01
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
56 additions
and
24 deletions
+56
-24
test/test_basis.py
test/test_basis.py
+1
-0
torch_spline_conv/functions/ffi.py
torch_spline_conv/functions/ffi.py
+17
-7
torch_spline_conv/functions/spline_conv.py
torch_spline_conv/functions/spline_conv.py
+1
-0
torch_spline_conv/functions/spline_weighting.py
torch_spline_conv/functions/spline_weighting.py
+13
-7
torch_spline_conv/src/cpu.h
torch_spline_conv/src/cpu.h
+15
-10
torch_spline_conv/src/generic/cpu.c
torch_spline_conv/src/generic/cpu.c
+9
-0
No files found.
test/test_basis.py
View file @
69d73030
...
@@ -17,6 +17,7 @@ f.close()
...
@@ -17,6 +17,7 @@ f.close()
def
test_spline_basis_cpu
(
tensor
,
i
):
def
test_spline_basis_cpu
(
tensor
,
i
):
degree
=
data
[
i
].
get
(
'degree'
)
degree
=
data
[
i
].
get
(
'degree'
)
pseudo
=
Tensor
(
tensor
,
data
[
i
][
'pseudo'
])
pseudo
=
Tensor
(
tensor
,
data
[
i
][
'pseudo'
])
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
kernel_size
=
torch
.
LongTensor
(
data
[
i
][
'kernel_size'
])
kernel_size
=
torch
.
LongTensor
(
data
[
i
][
'kernel_size'
])
is_open_spline
=
torch
.
ByteTensor
(
data
[
i
][
'is_open_spline'
])
is_open_spline
=
torch
.
ByteTensor
(
data
[
i
][
'is_open_spline'
])
K
=
kernel_size
.
prod
()
K
=
kernel_size
.
prod
()
...
...
torch_spline_conv/functions/ffi.py
View file @
69d73030
...
@@ -3,6 +3,13 @@ from .._ext import ffi as ext
...
@@ -3,6 +3,13 @@ from .._ext import ffi as ext
implemented_degrees
=
{
1
:
'linear'
,
2
:
'quadratic'
,
3
:
'cubic'
}
implemented_degrees
=
{
1
:
'linear'
,
2
:
'quadratic'
,
3
:
'cubic'
}
def
get_degree_str
(
degree
):
degree
=
implemented_degrees
.
get
(
degree
)
assert
degree
is
not
None
,
(
'No implementation found for specified B-spline degree'
)
return
degree
def
get_func
(
name
,
tensor
):
def
get_func
(
name
,
tensor
):
typename
=
type
(
tensor
).
__name__
.
replace
(
'Tensor'
,
''
)
typename
=
type
(
tensor
).
__name__
.
replace
(
'Tensor'
,
''
)
cuda
=
'cuda_'
if
tensor
.
is_cuda
else
''
cuda
=
'cuda_'
if
tensor
.
is_cuda
else
''
...
@@ -12,19 +19,22 @@ def get_func(name, tensor):
...
@@ -12,19 +19,22 @@ def get_func(name, tensor):
def
spline_basis_forward
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
,
K
):
def
spline_basis_forward
(
degree
,
pseudo
,
kernel_size
,
is_open_spline
,
K
):
s
=
(
degree
+
1
)
**
kernel_size
.
size
(
0
)
s
=
(
degree
+
1
)
**
kernel_size
.
size
(
0
)
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
basis
=
pseudo
.
new
(
pseudo
.
size
(
0
),
s
)
basis
=
pseudo
.
new
(
pseudo
.
size
(
0
),
s
)
weight_index
=
kernel_size
.
new
(
pseudo
.
size
(
0
),
s
)
weight_index
=
kernel_size
.
new
(
pseudo
.
size
(
0
),
s
)
func
=
get_func
(
'{}_basis_forward'
.
format
(
get_degree_str
(
degree
)),
pseudo
)
degree
=
implemented_degrees
.
get
(
degree
)
assert
degree
is
not
None
,
(
'Basis computation not implemented for specified B-spline degree'
)
func
=
get_func
(
'{}_basis_forward'
.
format
(
degree
),
pseudo
)
func
(
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
,
K
)
func
(
basis
,
weight_index
,
pseudo
,
kernel_size
,
is_open_spline
,
K
)
return
basis
,
weight_index
return
basis
,
weight_index
# pragma: no cover
def
spline_basis_backward
(
degree
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
):
grad_pseudo
=
pseudo
.
new
(
pseudo
.
size
())
func
=
get_func
(
'{}_basis_backward'
.
format
(
get_degree_str
(
degree
)),
pseudo
)
func
(
grad_pseudo
,
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
)
return
grad_pseudo
def
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
):
output
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
2
))
output
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
2
))
func
=
get_func
(
'weighting_forward'
,
x
)
func
=
get_func
(
'weighting_forward'
,
x
)
...
...
torch_spline_conv/functions/spline_conv.py
View file @
69d73030
...
@@ -42,6 +42,7 @@ def basic_spline_conv(x, edge_index, pseudo, weight, kernel_size,
...
@@ -42,6 +42,7 @@ def basic_spline_conv(x, edge_index, pseudo, weight, kernel_size,
n
,
e
,
m_out
=
x
.
size
(
0
),
edge_index
.
size
(
1
),
weight
.
size
(
2
)
n
,
e
,
m_out
=
x
.
size
(
0
),
edge_index
.
size
(
1
),
weight
.
size
(
2
)
x
=
x
.
unsqueeze
(
-
1
)
if
x
.
dim
()
==
1
else
x
x
=
x
.
unsqueeze
(
-
1
)
if
x
.
dim
()
==
1
else
x
pseudo
=
pseudo
.
unsqueeze
(
-
1
)
if
pseudo
.
dim
()
==
1
else
pseudo
# Weight gathered features based on B-spline bases and trainable weights.
# Weight gathered features based on B-spline bases and trainable weights.
output
=
spline_weighting
(
x
[
edge_index
[
1
]],
pseudo
,
weight
,
kernel_size
,
output
=
spline_weighting
(
x
[
edge_index
[
1
]],
pseudo
,
weight
,
kernel_size
,
...
...
torch_spline_conv/functions/spline_weighting.py
View file @
69d73030
import
torch
import
torch
from
torch.autograd
import
Function
from
torch.autograd
import
Function
from
.ffi
import
(
spline_basis_forward
,
spline_weighting_forward
,
from
.ffi
import
(
spline_basis_forward
,
spline_basis_backward
,
spline_weighting_forward
,
spline_weighting_backward_input
,
spline_weighting_backward_input
,
spline_weighting_backward_basis
,
spline_weighting_backward_basis
,
spline_weighting_backward_weight
)
spline_weighting_backward_weight
,
)
class
SplineWeighting
(
Function
):
class
SplineWeighting
(
Function
):
...
@@ -20,13 +24,13 @@ class SplineWeighting(Function):
...
@@ -20,13 +24,13 @@ class SplineWeighting(Function):
self
.
degree
,
pseudo
,
self
.
kernel_size
,
self
.
is_open_spline
,
K
)
self
.
degree
,
pseudo
,
self
.
kernel_size
,
self
.
is_open_spline
,
K
)
output
=
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
)
output
=
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
)
self
.
save_for_backward
(
x
,
weight
)
self
.
save_for_backward
(
x
,
pseudo
,
weight
)
self
.
basis
,
self
.
weight_index
=
basis
,
weight_index
self
.
basis
,
self
.
weight_index
=
basis
,
weight_index
return
output
return
output
def
backward
(
self
,
grad_output
):
# pragma: no cover
def
backward
(
self
,
grad_output
):
# pragma: no cover
x
,
weight
=
self
.
saved_tensors
x
,
pseudo
,
weight
=
self
.
saved_tensors
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
grad_input
,
grad_pseudo
,
grad_weight
=
None
,
None
,
None
grad_input
,
grad_pseudo
,
grad_weight
=
None
,
None
,
None
...
@@ -37,7 +41,9 @@ class SplineWeighting(Function):
...
@@ -37,7 +41,9 @@ class SplineWeighting(Function):
if
self
.
needs_input_grad
[
1
]:
if
self
.
needs_input_grad
[
1
]:
grad_basis
=
spline_weighting_backward_basis
(
grad_basis
=
spline_weighting_backward_basis
(
grad_output
,
x
,
weight
,
weight_index
)
grad_output
,
x
,
weight
,
weight_index
)
print
(
'pseudo needs grad'
)
grad_pseudo
=
spline_basis_backward
(
self
.
degree
,
grad_basis
,
pseudo
,
self
.
kernel_size
,
self
.
is_open_spline
)
if
self
.
needs_input_grad
[
2
]:
if
self
.
needs_input_grad
[
2
]:
K
=
weight
.
size
(
0
)
K
=
weight
.
size
(
0
)
...
...
torch_spline_conv/src/cpu.h
View file @
69d73030
void
spline_linear_basis_forward_Float
(
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_linear_basis_forward_Float
(
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_linear_basis_forward_Double
(
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_linear_basis_forward_Double
(
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_quadratic_basis_forward_Float
(
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_quadratic_basis_forward_Float
(
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_quadratic_basis_forward_Double
(
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_quadratic_basis_forward_Double
(
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_cubic_basis_forward_Float
(
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_cubic_basis_forward_Float
(
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_cubic_basis_forward_Double
(
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_cubic_basis_forward_Double
(
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_weighting_forward_Float
(
THFloatTensor
*
output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_linear_basis_backward_Float
(
THFloatTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
);
void
spline_linear_basis_backward_Double
(
THDoubleTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
);
void
spline_quadratic_basis_backward_Float
(
THFloatTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
);
void
spline_quadratic_basis_backward_Double
(
THDoubleTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
);
void
spline_cubic_basis_backward_Float
(
THFloatTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
);
void
spline_cubic_basis_backward_Double
(
THDoubleTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
);
void
spline_weighting_forward_Float
(
THFloatTensor
*
output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_forward_Double
(
THDoubleTensor
*
output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_forward_Double
(
THDoubleTensor
*
output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_input_Float
(
THFloatTensor
*
grad_input
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_input_Float
(
THFloatTensor
*
grad_input
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_input_Double
(
THDoubleTensor
*
grad_input
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_input_Double
(
THDoubleTensor
*
grad_input
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_basis_Float
(
THFloatTensor
*
grad_basis
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_basis_Float
(
THFloatTensor
*
grad_basis
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_basis_Double
(
THDoubleTensor
*
grad_basis
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_basis_Double
(
THDoubleTensor
*
grad_basis
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_Float
(
THFloatTensor
*
grad_weight
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_Float
(
THFloatTensor
*
grad_weight
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_Double
(
THDoubleTensor
*
grad_weight
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_weighting_backward_weight_Double
(
THDoubleTensor
*
grad_weight
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
torch_spline_conv/src/generic/cpu.c
View file @
69d73030
...
@@ -25,6 +25,15 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T
...
@@ -25,6 +25,15 @@ void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, T
)
)
}
}
void
spline_
(
linear_basis_backward
)(
THTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
)
{
}
void
spline_
(
quadratic_basis_backward
)(
THTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
)
{
}
void
spline_
(
cubic_basis_backward
)(
THTensor
*
grad_pseudo
,
THLongTensor
*
grad_basis
,
THTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
)
{
}
void
spline_
(
weighting_forward
)(
THTensor
*
output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_forward
)(
THTensor
*
output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
int64_t
M_out
=
THTensor_
(
size
)(
output
,
1
);
int64_t
M_out
=
THTensor_
(
size
)(
output
,
1
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment