Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
3a07cc5e
Commit
3a07cc5e
authored
Mar 09, 2018
by
rusty1s
Browse files
added python autograd function
parent
eb8f32d7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
13 deletions
+24
-13
torch_spline_conv/functions/utils.py
torch_spline_conv/functions/utils.py
+18
-7
torch_spline_conv/src/cpu.h
torch_spline_conv/src/cpu.h
+4
-4
torch_spline_conv/src/generic/cpu.c
torch_spline_conv/src/generic/cpu.c
+2
-2
No files found.
torch_spline_conv/functions/utils.py
View file @
3a07cc5e
...
@@ -29,12 +29,19 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
...
@@ -29,12 +29,19 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
return
basis
,
weight_index
return
basis
,
weight_index
def
spline_weighting_forward
(
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting_fw
(
x
,
weight
,
basis
,
weight_index
):
pass
output
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
2
))
func
=
get_func
(
'spline_weighting_fw'
,
x
)
func
(
output
,
x
,
weight
,
basis
,
weight_index
)
return
output
def
spline_weighting_backward
(
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting_bw
(
grad_output
,
x
,
weight
,
basis
,
weight_index
):
pass
grad_input
=
x
.
new
(
x
.
size
(
0
),
weight
.
size
(
1
))
grad_weight
=
x
.
new
(
weight
)
func
=
get_func
(
'spline_weighting_bw'
,
x
)
func
(
grad_input
,
grad_weight
,
grad_output
,
x
,
weight
,
basis
,
weight_index
)
return
grad_input
,
grad_weight
class
SplineWeighting
(
Function
):
class
SplineWeighting
(
Function
):
...
@@ -44,14 +51,18 @@ class SplineWeighting(Function):
...
@@ -44,14 +51,18 @@ class SplineWeighting(Function):
self
.
weight_index
=
weight_index
self
.
weight_index
=
weight_index
def
forward
(
self
,
x
,
weight
):
def
forward
(
self
,
x
,
weight
):
pass
self
.
save_for_backward
(
x
,
weight
)
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
return
spline_weighting_fw
(
x
,
weight
,
basis
,
weight_index
)
def
backward
(
self
,
grad_output
):
def
backward
(
self
,
grad_output
):
pass
x
,
weight
=
self
.
saved_tensors
basis
,
weight_index
=
self
.
basis
,
self
.
weight_index
return
spline_weighting_bw
(
grad_output
,
x
,
weight
,
basis
,
weight_index
)
def
spline_weighting
(
x
,
weight
,
basis
,
weight_index
):
def
spline_weighting
(
x
,
weight
,
basis
,
weight_index
):
if
torch
.
is_tensor
(
x
):
if
torch
.
is_tensor
(
x
):
return
spline_weighting_f
orward
(
x
,
weight
,
basis
,
weight_index
)
return
spline_weighting_f
w
(
x
,
weight
,
basis
,
weight_index
)
else
:
else
:
return
SplineWeighting
(
basis
,
weight_index
)(
x
,
weight
)
return
SplineWeighting
(
basis
,
weight_index
)(
x
,
weight
)
torch_spline_conv/src/cpu.h
View file @
3a07cc5e
...
@@ -7,8 +7,8 @@ void spline_basis_quadratic_Double(THDoubleTensor *basis, THLongTensor *weight_i
...
@@ -7,8 +7,8 @@ void spline_basis_quadratic_Double(THDoubleTensor *basis, THLongTensor *weight_i
void
spline_basis_cubic_Float
(
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_basis_cubic_Float
(
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
,
THFloatTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_basis_cubic_Double
(
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_basis_cubic_Double
(
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
,
THDoubleTensor
*
pseudo
,
THLongTensor
*
kernel_size
,
THByteTensor
*
is_open_spline
,
int
K
);
void
spline_edgewise_f
orward
_Float
(
THFloatTensor
*
output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_edgewise_f
w
_Float
(
THFloatTensor
*
output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_edgewise_f
orward
_Double
(
THDoubleTensor
*
output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_edgewise_f
w
_Double
(
THDoubleTensor
*
output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_
edgewise_backward
_Float
(
THFloatTensor
*
grad_input
,
THFloatTensor
*
grad_weight
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_
weighting_bw
_Float
(
THFloatTensor
*
grad_input
,
THFloatTensor
*
grad_weight
,
THFloatTensor
*
grad_output
,
THFloatTensor
*
input
,
THFloatTensor
*
weight
,
THFloatTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_
edgewise_backward
_Double
(
THDoubleTensor
*
grad_input
,
THDoubleTensor
*
grad_weight
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
void
spline_
weighting_bw
_Double
(
THDoubleTensor
*
grad_input
,
THDoubleTensor
*
grad_weight
,
THDoubleTensor
*
grad_output
,
THDoubleTensor
*
input
,
THDoubleTensor
*
weight
,
THDoubleTensor
*
basis
,
THLongTensor
*
weight_index
);
torch_spline_conv/src/generic/cpu.c
View file @
3a07cc5e
...
@@ -50,10 +50,10 @@ void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor
...
@@ -50,10 +50,10 @@ void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor
)
)
}
}
void
spline_
(
edgewise_forward
)(
THTensor
*
output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_fw
)(
THTensor
*
output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
}
}
void
spline_
(
edgewise_backward
)(
THTensor
*
grad_input
,
THTensor
*
grad_weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_bw
)(
THTensor
*
grad_input
,
THTensor
*
grad_weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
}
}
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment