Commit e8e5841d authored by rusty1s's avatar rusty1s
Browse files

added basis grad for all degrees

parent 6100aa77
......@@ -3,7 +3,7 @@ import torch
from torch.autograd import Variable, gradcheck
from torch_spline_conv import spline_conv
from torch_spline_conv.functions.spline_weighting import SplineWeighting
# from torch_spline_conv.functions.ffi import implemented_degrees
from torch_spline_conv.functions.ffi import implemented_degrees
from .utils import tensors, Tensor
......@@ -49,7 +49,7 @@ def test_spline_conv_cpu(tensor):
def test_spline_weighting_backward_cpu():
for degree in [1]:
for degree in implemented_degrees.keys():
kernel_size = torch.LongTensor([5, 5])
is_open_spline = torch.ByteTensor([1, 1])
op = SplineWeighting(kernel_size, is_open_spline, degree)
......
......@@ -42,6 +42,8 @@
quotient = pow(M + 1, d); \
for (s = 0; s < S; s++) { \
k_mod = (s / quotient) % (M + 1); \
value = *(pseudo_data + d * pseudo_stride) * (kernel_size_data[d] - M * is_open_spline_data[d]); \
value -= floor(value); \
GRAD_CODE \
g = value; \
\
......
......@@ -4,7 +4,7 @@
void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS_FORWARD(1, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
value = (1 - k_mod) * (1 - value) + k_mod * value;
value = 1 - value - k_mod + 2 * value * k_mod;
)
}
......@@ -18,16 +18,16 @@ void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_inde
void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS_FORWARD(3, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) value = (1 - value) * (1 - value) * (1 - value) / 6.0;
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6.0;
else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6.0;
else value = value * value * value / 6.0;
if (k_mod == 0) { value = (1 - value); value = value * value * value / 6.0; }
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6;
else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6;
else value = value * value * value / 6;
)
}
void spline_(linear_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(1, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
value = (1 - k_mod) * (1 - value) + k_mod * value;
value = 1 - value - k_mod + 2 * value * k_mod;
,
value = -1 + k_mod + k_mod;
)
......@@ -39,7 +39,7 @@ void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_bas
else if (k_mod == 1) value = -value * value + value + 0.5;
else value = 0.5 * value * value;
,
if (k_mod == 0) value = 2 * value - 1;
if (k_mod == 0) value = value - 1;
else if (k_mod == 1) value = -2 * value + 1;
else value = value;
)
......@@ -47,9 +47,15 @@ void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_bas
void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(3, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
value = (1 - k_mod) * (1 - value) + k_mod * value;
if (k_mod == 0) { value = (1 - value); value = value * value * value / 6.0; }
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6;
else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6;
else value = value * value * value / 6;
,
value = -(1 - k_mod) + k_mod;
if (k_mod == 0) value = (-value * value + 2 * value - 1) / 2;
else if (k_mod == 1) value = (3 * value * value - 4 * value) / 2;
else if (k_mod == 2) value = (-3 * value * value + 2 * value + 1) / 2;
else value = value * value / 2;
)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment