Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
3fbeeabc
Commit
3fbeeabc
authored
Mar 10, 2018
by
rusty1s
Browse files
added backward
parent
5487e31a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
174 additions
and
14 deletions
+174
-14
torch_spline_conv/functions/utils.py
torch_spline_conv/functions/utils.py
+8
-7
torch_spline_conv/src/THTensorDimApply.h
torch_spline_conv/src/THTensorDimApply.h
+157
-0
torch_spline_conv/src/cpu.c
torch_spline_conv/src/cpu.c
+1
-1
torch_spline_conv/src/cpu.h
torch_spline_conv/src/cpu.h
+4
-4
torch_spline_conv/src/generic/cpu.c
torch_spline_conv/src/generic/cpu.c
+4
-2
No files found.
torch_spline_conv/functions/utils.py
View file @
3fbeeabc
...
...
@@ -28,17 +28,17 @@ def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
return
basis
,
weight_index
def spline_weighting_forward(x, weight, basis, weight_index):
    """Forward pass of B-spline weighting via the C extension.

    Args:
        x: input node features; first dim is the number of nodes.
        weight: kernel weights; dim 2 is the number of output channels.
        basis: precomputed spline basis values (see ``spline_basis``).
        weight_index: kernel indices matching ``basis``.

    Returns:
        Newly allocated output tensor of shape (x.size(0), weight.size(2)).
    """
    # NOTE(review): this span was reconstructed from a side-by-side diff
    # render that interleaved the pre-commit `spline_weighting_fw` and the
    # post-commit `spline_weighting_forward`; the post-commit version is kept.
    output = x.new(x.size(0), weight.size(2))
    # Dispatch to the dtype-specialized C kernel (Float/Double) by name.
    func = get_func('weighting_forward', x)
    # The kernel fills `output` in place.
    func(output, x, weight, basis, weight_index)
    return output
def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
    """Backward pass of B-spline weighting via the C extension.

    Args:
        grad_output: gradient w.r.t. the forward output.
        x: input node features from the forward pass.
        weight: kernel weights from the forward pass.
        basis: precomputed spline basis values.
        weight_index: kernel indices matching ``basis``.

    Returns:
        Tuple ``(grad_input, grad_weight)``.
    """
    # NOTE(review): reconstructed from a side-by-side diff render that
    # interleaved `spline_weighting_bw` (old) and
    # `spline_weighting_backward` (new); the post-commit version is kept.
    # grad_input has shape (num_nodes, in_channels) = (x.size(0), weight.size(1)).
    grad_input = x.new(x.size(0), weight.size(1))
    # x.new(weight) allocates grad_weight with x's type from `weight`
    # — presumably same-shaped as `weight`; TODO confirm old-TH `new` semantics.
    grad_weight = x.new(weight)
    func = get_func('weighting_backward', x)
    # The kernel fills grad_input and grad_weight in place.
    func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
    return grad_input, grad_weight
...
...
@@ -52,16 +52,17 @@ class SplineWeighting(Function):
def forward(self, x, weight):
    """autograd Function forward: stash tensors for backward, then delegate.

    Uses the old-style (pre-0.4) ``torch.autograd.Function`` API:
    ``save_for_backward`` on the instance, non-tensor state (``basis``,
    ``weight_index``) carried as instance attributes set by the constructor.
    """
    # NOTE(review): reconstructed from a diff render; the post-commit call to
    # `spline_weighting_forward` (renamed from `spline_weighting_fw`) is kept.
    self.save_for_backward(x, weight)
    basis, weight_index = self.basis, self.weight_index
    return spline_weighting_forward(x, weight, basis, weight_index)
def backward(self, grad_output):
    """autograd Function backward: recover saved tensors and delegate.

    Returns the ``(grad_input, grad_weight)`` pair produced by the C
    extension, matching the two tensor inputs of :meth:`forward`.
    """
    # NOTE(review): reconstructed from a diff render; the post-commit call to
    # `spline_weighting_backward` (renamed from `spline_weighting_bw`) is kept.
    x, weight = self.saved_tensors
    basis, weight_index = self.basis, self.weight_index
    return spline_weighting_backward(grad_output, x, weight, basis, weight_index)
def spline_weighting(x, weight, basis, weight_index):
    """Public entry point: pick the raw kernel or the autograd path.

    A plain tensor needs no gradient tracking, so it goes straight to the
    C-backed ``spline_weighting_forward``; a ``Variable`` (old-style
    autograd wrapper) is routed through the ``SplineWeighting`` Function so
    backward() works.
    """
    # NOTE(review): reconstructed from a diff render; the post-commit branch
    # calling `spline_weighting_forward` is kept.
    if torch.is_tensor(x):
        return spline_weighting_forward(x, weight, basis, weight_index)
    else:
        return SplineWeighting(basis, weight_index)(x, weight)
torch_spline_conv/src/THTensorDimApply
4
.h
→
torch_spline_conv/src/THTensorDimApply.h
View file @
3fbeeabc
...
...
@@ -72,3 +72,86 @@
} \
THFree(TH_TENSOR_DIM_APPLY_counter); \
}
#define TH_TENSOR_DIM_APPLY5(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, TENSOR4, TYPE5, TENSOR5, DIMENSION, CODE) { \
TYPE1 *TENSOR1##_data = NULL; \
int64_t TENSOR1##_stride = 0, TENSOR1##_size = 0; \
TYPE2 *TENSOR2##_data = NULL; \
int64_t TENSOR2##_stride = 0, TENSOR2##_size = 0; \
TYPE3 *TENSOR3##_data = NULL; \
int64_t TENSOR3##_stride = 0, TENSOR3##_size = 0; \
TYPE4 *TENSOR4##_data = NULL; \
int64_t TENSOR4##_stride = 0, TENSOR4##_size = 0; \
TYPE5 *TENSOR5##_data = NULL; \
int64_t TENSOR5##_stride = 0, TENSOR5##_size = 0; \
\
int64_t *TH_TENSOR_DIM_APPLY_counter = NULL; \
int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
int TH_TENSOR_DIM_APPLY_i; \
\
TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR1->nDimension)); \
\
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) { \
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
} \
\
TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \
TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \
TENSOR1##_size = TENSOR1->size[DIMENSION]; \
\
TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \
TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \
TENSOR2##_size = TENSOR2->size[DIMENSION]; \
\
TENSOR3##_data = (TENSOR3)->storage->data+(TENSOR3)->storageOffset; \
TENSOR3##_stride = (TENSOR3)->stride[DIMENSION]; \
TENSOR3##_size = TENSOR3->size[DIMENSION]; \
\
TENSOR4##_data = (TENSOR4)->storage->data+(TENSOR4)->storageOffset; \
TENSOR4##_stride = (TENSOR4)->stride[DIMENSION]; \
TENSOR4##_size = TENSOR4->size[DIMENSION]; \
\
TENSOR5##_data = (TENSOR5)->storage->data+(TENSOR5)->storageOffset; \
TENSOR5##_stride = (TENSOR5)->stride[DIMENSION]; \
TENSOR5##_size = TENSOR5->size[DIMENSION]; \
\
while (!TH_TENSOR_DIM_APPLY_hasFinished) { \
CODE \
\
if (TENSOR1->nDimension == 1) break; \
\
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) { \
if (TH_TENSOR_DIM_APPLY_i == DIMENSION) { \
if (TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) { \
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
break; \
} \
continue; \
} \
\
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \
TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR3##_data += TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR4##_data += TENSOR4->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR5##_data += TENSOR5->stride[TH_TENSOR_DIM_APPLY_i]; \
\
if (TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) { \
if (TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) { \
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
break; \
} \
else { \
TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR3##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR4##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR4->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR5##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR5->stride[TH_TENSOR_DIM_APPLY_i]; \
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
} \
} \
else break; \
} \
} \
THFree(TH_TENSOR_DIM_APPLY_counter); \
}
torch_spline_conv/src/cpu.c
View file @
3fbeeabc
#include <TH/TH.h>
/* NOTE(review): the diff render contained both the old "THTensorDimApply4.h"
 * and the new "THTensorDimApply.h" include; the post-commit header is kept. */
#include "THTensorDimApply.h"

/* Expand to a real-type-suffixed symbol, e.g. spline_(basis) -> spline_basis_Float,
 * for use with TH's generic-file instantiation scheme. */
#define spline_(NAME) TH_CONCAT_4(spline_, NAME, _, Real)
...
...
torch_spline_conv/src/cpu.h
View file @
3fbeeabc
...
...
@@ -7,8 +7,8 @@ void spline_basis_quadratic_Double(THDoubleTensor *basis, THLongTensor *weight_i
/* Public C-extension entry points, instantiated for Float and Double.
 * NOTE(review): reconstructed from a side-by-side diff render that
 * interleaved the old `_fw_`/`_bw_` and new `_forward_`/`_backward_`
 * declarations; only the post-commit declarations are kept. */
void spline_basis_cubic_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_cubic_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);

/* Forward weighting: fills `output` from `input`, `weight`, `basis`, `weight_index`. */
void spline_weighting_forward_Float(THFloatTensor *output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_forward_Double(THDoubleTensor *output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);

/* Backward weighting: fills `grad_input` and `grad_weight` from `grad_output`
 * and the saved forward inputs. */
void spline_weighting_backward_Float(THFloatTensor *grad_input, THFloatTensor *grad_weight, THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weight_index);
void spline_weighting_backward_Double(THDoubleTensor *grad_input, THDoubleTensor *grad_weight, THDoubleTensor *grad_output, THDoubleTensor *input, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weight_index);
torch_spline_conv/src/generic/cpu.c
View file @
3fbeeabc
...
...
@@ -50,7 +50,7 @@ void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor
)
}
void
spline_
(
weighting_f
w
)(
THTensor
*
output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_f
orward
)(
THTensor
*
output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
int64_t
M_out
=
THTensor_
(
size
)(
output
,
1
);
int64_t
M_in
=
THTensor_
(
size
)(
input
,
1
);
...
...
@@ -72,7 +72,9 @@ void spline_(weighting_fw)(THTensor *output, THTensor *input, THTensor *weight,
)
}
void
spline_
(
weighting_bw
)(
THTensor
*
grad_input
,
THTensor
*
grad_weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
void
spline_
(
weighting_backward
)(
THTensor
*
grad_input
,
THTensor
*
grad_weight
,
THTensor
*
grad_output
,
THTensor
*
input
,
THTensor
*
weight
,
THTensor
*
basis
,
THLongTensor
*
weight_index
)
{
TH_TENSOR_DIM_APPLY5
(
real
,
grad_input
,
real
,
grad_output
,
real
,
input
,
real
,
basis
,
int64_t
,
weight_index
,
1
,
)
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment