OpenDAS / torch-spline-conv · Commits

Commit dff93289
Authored Apr 09, 2018 by rusty1s

backward boilerplate

Parent: 0b777f0d
Showing 6 changed files with 81 additions and 0 deletions (+81, −0)
aten/TH/THBasis.h               +7  −0
aten/TH/generic/THBasis.c       +12 −0
aten/THCC/THCCBasis.h           +7  −0
aten/THCC/generic/THCCBasis.c   +15 −0
torch_spline_conv/basis.py      +33 −0
torch_spline_conv/utils/ffi.py  +7  −0
aten/TH/THBasis.h
@@ -4,3 +4,10 @@ void THFloatTensor_quadraticBasisForward( THFloatTensor *basis, THLongTensor *w
void THDoubleTensor_quadraticBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_cubicBasisForward(THFloatTensor *basis, THLongTensor *weightIndex, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_cubicBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);

void THFloatTensor_linearBasisBackward(THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_linearBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_quadraticBasisBackward(THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_quadraticBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_cubicBasisBackward(THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_cubicBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
aten/TH/generic/THBasis.c
@@ -40,4 +40,16 @@ void THTensor_(cubicBasisForward)(THTensor *basis, THLongTensor *weightIndex, TH
  )
}

void THTensor_(linearBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}

void THTensor_(quadraticBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}

void THTensor_(cubicBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
}

#endif // TH_GENERIC_FILE
aten/THCC/THCCBasis.h
@@ -4,3 +4,10 @@ void THCCFloatTensor_quadraticBasisForward( THCudaTensor *basis, THCudaLon
void THCCDoubleTensor_quadraticBasisForward(THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_cubicBasisForward(THCudaTensor *basis, THCudaLongTensor *weightIndex, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_cubicBasisForward(THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);

void THCCFloatTensor_linearBasisBackward(THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_linearBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_quadraticBasisBackward(THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_quadraticBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCFloatTensor_cubicBasisBackward(THCudaTensor *self, THCudaTensor *gradBasis, THCudaTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
void THCCDoubleTensor_cubicBasisBackward(THCudaDoubleTensor *self, THCudaDoubleTensor *gradBasis, THCudaDoubleTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline);
aten/THCC/generic/THCCBasis.c
@@ -20,4 +20,19 @@ void THCCTensor_(cubicBasisForward)(THCTensor *basis, THCudaLongTensor *weightIn
  THCTensor_(cubicBasisForward)(state, basis, weightIndex, pseudo, kernelSize, isOpenSpline);
}

void THCCTensor_(linearBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline) {
}

void THCCTensor_(quadraticBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline) {
}

void THCCTensor_(cubicBasisBackward)(THCTensor *self, THCTensor *gradBasis, THCTensor *pseudo, THCudaLongTensor *kernelSize, THCudaByteTensor *isOpenSpline) {
}

#endif // THC_GENERIC_FILE
torch_spline_conv/basis.py
import torch
from torch.autograd import Function

from .utils.ffi import basis_forward as ffi_basis_forward
from .utils.ffi import basis_backward as ffi_basis_backward


def basis_forward(degree, pseudo, kernel_size, is_open_spline):
...
@@ -9,3 +13,32 @@ def basis_forward(degree, pseudo, kernel_size, is_open_spline):
    ffi_basis_forward(degree, basis, weight_index, pseudo, kernel_size, is_open_spline)
    return basis, weight_index


def basis_backward(degree, grad_basis, pseudo, kernel_size, is_open_spline):
    # Allocate the gradient w.r.t. the pseudo-coordinates and let the
    # C/CUDA kernel (still an empty stub in this commit) fill it in.
    grad_pseudo = pseudo.new(pseudo.size())
    ffi_basis_backward(degree, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline)
    return grad_pseudo


class Basis(Function):
    def __init__(self, degree, kernel_size, is_open_spline):
        super(Basis, self).__init__()
        self.degree = degree
        self.kernel_size = kernel_size
        self.is_open_spline = is_open_spline

    def forward(self, pseudo):
        self.save_for_backward(pseudo)
        return basis_forward(self.degree, pseudo, self.kernel_size, self.is_open_spline)

    def backward(self, grad_basis, grad_weight_index):
        pass


def basis(degree, pseudo, kernel_size, is_open_spline):
    if torch.is_tensor(pseudo):
        return basis_forward(degree, pseudo, kernel_size, is_open_spline)
    else:
        return Basis(degree, kernel_size, is_open_spline)(pseudo)
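Basis.backward is deliberately left as a stub in this commit. For orientation only, here is a minimal sketch of how it could later be wired to the basis_backward helper defined above, assuming the legacy torch.autograd.Function API used in this file (saved inputs retrieved via self.saved_tensors, one gradient returned per forward input); this completion is not part of the commit.

    # Hypothetical completion of Basis.backward (illustration only, not in this commit).
    def backward(self, grad_basis, grad_weight_index):
        pseudo, = self.saved_tensors
        # weight_index is an integer tensor, so grad_weight_index carries no
        # useful gradient and is ignored here.
        grad_pseudo = basis_backward(self.degree, grad_basis, pseudo,
                                     self.kernel_size, self.is_open_spline)
        return grad_pseudo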
torch_spline_conv/utils/ffi.py
@@ -21,3 +21,10 @@ def basis_forward(degree, basis, weight_index, pseudo, kernel_size,
    name = '{}BasisForward'.format(get_degree_str(degree))
    func = get_func(name, basis.is_cuda, basis)
    func(basis, weight_index, pseudo, kernel_size, is_open_spline)


def basis_backward(degree, self, grad_basis, pseudo, kernel_size, is_open_spline):
    name = '{}BasisBackward'.format(get_degree_str(degree))
    func = get_func(name, self.is_cuda, self)
    func(self, grad_basis, pseudo, kernel_size, is_open_spline)
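The get_degree_str and get_func helpers live earlier in ffi.py and are outside this diff, so the exact dispatch is not shown here. Judging from the format string above and the declarations in the headers, the resolved entry points are presumably the typed C symbols such as THFloatTensor_linearBasisBackward (CPU) or THCCFloatTensor_linearBasisBackward (CUDA). A hypothetical sketch of that naming convention follows; DEGREE_TO_STR and resolve_symbol are invented purely for illustration and may differ from the real helpers.

# Illustration of the naming scheme implied by the headers above; the real
# get_degree_str/get_func may compose names differently.
DEGREE_TO_STR = {1: 'linear', 2: 'quadratic', 3: 'cubic'}

def resolve_symbol(degree, is_cuda, dtype='Float'):
    # e.g. (1, False, 'Float') -> 'THFloatTensor_linearBasisBackward'
    #      (1, True,  'Float') -> 'THCCFloatTensor_linearBasisBackward'
    prefix = 'THCC' if is_cuda else 'TH'
    name = '{}BasisBackward'.format(DEGREE_TO_STR[degree])
    return '{}{}Tensor_{}'.format(prefix, dtype, name)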