Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
57f5a26e
"vscode:/vscode.git/clone" did not exist on "56e707bfccb62ada836d21e431d6db0d10dd73a1"
Commit
57f5a26e
authored
Aug 12, 2018
by
rusty1s
Browse files
coverage fix
parent
cba43ac9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
17 deletions
+10
-17
torch_spline_conv/basis.py
torch_spline_conv/basis.py
+10
-17
No files found.
torch_spline_conv/basis.py
View file @
57f5a26e
...
@@ -12,34 +12,27 @@ def get_func(name, tensor):
...
@@ -12,34 +12,27 @@ def get_func(name, tensor):
return
getattr
(
module
,
name
)
return
getattr
(
module
,
name
)
def
fw
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
):
op
=
get_func
(
'{}_fw'
.
format
(
implemented_degrees
[
degree
]),
pseudo
)
basis
,
weight_index
=
op
(
pseudo
,
kernel_size
,
is_open_spline
)
return
basis
,
weight_index
def
bw
(
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
,
degree
):
op
=
get_func
(
'{}_bw'
.
format
(
implemented_degrees
[
degree
]),
pseudo
)
grad_pseudo
=
op
(
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
)
return
grad_pseudo
class
SplineBasis
(
torch
.
autograd
.
Function
):
class
SplineBasis
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
pseudo
,
kernel_size
,
is_open_spline
,
degree
):
def
forward
(
ctx
,
pseudo
,
kernel_size
,
is_open_spline
,
degree
):
ctx
.
save_for_backward
(
pseudo
)
ctx
.
save_for_backward
(
pseudo
)
ctx
.
kernel_size
=
kernel_size
ctx
.
kernel_size
,
ctx
.
is_open_spline
=
kernel_size
,
is_open_spline
ctx
.
is_open_spline
=
is_open_spline
ctx
.
degree
=
degree
ctx
.
degree
=
degree
return
fw
(
pseudo
,
kernel_size
,
is_open_spline
,
degree
)
op
=
get_func
(
'{}_fw'
.
format
(
implemented_degrees
[
degree
]),
pseudo
)
basis
,
weight_index
=
op
(
pseudo
,
kernel_size
,
is_open_spline
)
return
basis
,
weight_index
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_basis
,
grad_weight_index
):
def
backward
(
ctx
,
grad_basis
,
grad_weight_index
):
pseudo
,
=
ctx
.
saved_tensors
pseudo
,
=
ctx
.
saved_tensors
kernel_size
,
is_open_spline
=
ctx
.
kernel_size
,
ctx
.
is_open_spline
degree
=
ctx
.
degree
grad_pseudo
=
None
grad_pseudo
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
grad_pseudo
=
bw
(
grad_basis
,
pseudo
,
ctx
.
kernel_size
,
op
=
get_func
(
'{}_bw'
.
format
(
implemented_degrees
[
degree
]),
pseudo
)
ctx
.
is_open_spline
,
ctx
.
degree
)
grad_pseudo
=
op
(
grad_basis
,
pseudo
,
kernel_size
,
is_open_spline
)
return
grad_pseudo
,
None
,
None
,
None
return
grad_pseudo
,
None
,
None
,
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment