Commit e8e5841d authored by rusty1s's avatar rusty1s
Browse files

added basis grad for all degrees

parent 6100aa77
...@@ -3,7 +3,7 @@ import torch ...@@ -3,7 +3,7 @@ import torch
from torch.autograd import Variable, gradcheck from torch.autograd import Variable, gradcheck
from torch_spline_conv import spline_conv from torch_spline_conv import spline_conv
from torch_spline_conv.functions.spline_weighting import SplineWeighting from torch_spline_conv.functions.spline_weighting import SplineWeighting
# from torch_spline_conv.functions.ffi import implemented_degrees from torch_spline_conv.functions.ffi import implemented_degrees
from .utils import tensors, Tensor from .utils import tensors, Tensor
...@@ -49,7 +49,7 @@ def test_spline_conv_cpu(tensor): ...@@ -49,7 +49,7 @@ def test_spline_conv_cpu(tensor):
def test_spline_weighting_backward_cpu(): def test_spline_weighting_backward_cpu():
for degree in [1]: for degree in implemented_degrees.keys():
kernel_size = torch.LongTensor([5, 5]) kernel_size = torch.LongTensor([5, 5])
is_open_spline = torch.ByteTensor([1, 1]) is_open_spline = torch.ByteTensor([1, 1])
op = SplineWeighting(kernel_size, is_open_spline, degree) op = SplineWeighting(kernel_size, is_open_spline, degree)
......
...@@ -42,6 +42,8 @@ ...@@ -42,6 +42,8 @@
quotient = pow(M + 1, d); \ quotient = pow(M + 1, d); \
for (s = 0; s < S; s++) { \ for (s = 0; s < S; s++) { \
k_mod = (s / quotient) % (M + 1); \ k_mod = (s / quotient) % (M + 1); \
value = *(pseudo_data + d * pseudo_stride) * (kernel_size_data[d] - M * is_open_spline_data[d]); \
value -= floor(value); \
GRAD_CODE \ GRAD_CODE \
g = value; \ g = value; \
\ \
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) { void spline_(linear_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS_FORWARD(1, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS_FORWARD(1, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
value = (1 - k_mod) * (1 - value) + k_mod * value; value = 1 - value - k_mod + 2 * value * k_mod;
) )
} }
...@@ -18,16 +18,16 @@ void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_inde ...@@ -18,16 +18,16 @@ void spline_(quadratic_basis_forward)(THTensor *basis, THLongTensor *weight_inde
void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) { void spline_(cubic_basis_forward)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K) {
SPLINE_BASIS_FORWARD(3, basis, weight_index, pseudo, kernel_size, is_open_spline, K, SPLINE_BASIS_FORWARD(3, basis, weight_index, pseudo, kernel_size, is_open_spline, K,
if (k_mod == 0) value = (1 - value) * (1 - value) * (1 - value) / 6.0; if (k_mod == 0) { value = (1 - value); value = value * value * value / 6.0; }
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6.0; else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6;
else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6.0; else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6;
else value = value * value * value / 6.0; else value = value * value * value / 6;
) )
} }
void spline_(linear_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) { void spline_(linear_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(1, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline, SPLINE_BASIS_BACKWARD(1, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
value = (1 - k_mod) * (1 - value) + k_mod * value; value = 1 - value - k_mod + 2 * value * k_mod;
, ,
value = -1 + k_mod + k_mod; value = -1 + k_mod + k_mod;
) )
...@@ -39,7 +39,7 @@ void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_bas ...@@ -39,7 +39,7 @@ void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_bas
else if (k_mod == 1) value = -value * value + value + 0.5; else if (k_mod == 1) value = -value * value + value + 0.5;
else value = 0.5 * value * value; else value = 0.5 * value * value;
, ,
if (k_mod == 0) value = 2 * value - 1; if (k_mod == 0) value = value - 1;
else if (k_mod == 1) value = -2 * value + 1; else if (k_mod == 1) value = -2 * value + 1;
else value = value; else value = value;
) )
...@@ -47,9 +47,15 @@ void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_bas ...@@ -47,9 +47,15 @@ void spline_(quadratic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_bas
void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) { void spline_(cubic_basis_backward)(THTensor *grad_pseudo, THTensor *grad_basis, THTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline) {
SPLINE_BASIS_BACKWARD(3, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline, SPLINE_BASIS_BACKWARD(3, grad_pseudo, grad_basis, pseudo, kernel_size, is_open_spline,
value = (1 - k_mod) * (1 - value) + k_mod * value; if (k_mod == 0) { value = (1 - value); value = value * value * value / 6.0; }
else if (k_mod == 1) value = (3 * value * value * value - 6 * value * value + 4) / 6;
else if (k_mod == 2) value = (-3 * value * value * value + 3 * value * value + 3 * value + 1) / 6;
else value = value * value * value / 6;
, ,
value = -(1 - k_mod) + k_mod; if (k_mod == 0) value = (-value * value + 2 * value - 1) / 2;
else if (k_mod == 1) value = (3 * value * value - 4 * value) / 2;
else if (k_mod == 2) value = (-3 * value * value + 2 * value + 1) / 2;
else value = value * value / 2;
) )
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment