Commit c005e19d authored by rusty1s's avatar rusty1s
Browse files

beginning of basis computation

parent 07804abc
from .functions.spline_conv import spline_conv
__version__ = '0.1.0' __version__ = '0.1.0'
__all__ = ['__version__'] __all__ = ['spline_conv', '__version__']
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
# from torch.autograd import Variable as Var # from torch.autograd import Variable as Var
from .degree import node_degree from .degree import node_degree
from .utils import spline_bases, spline_weighting from .utils import spline_basis, spline_weighting
def spline_conv(x, def spline_conv(x,
...@@ -21,8 +21,8 @@ def spline_conv(x, ...@@ -21,8 +21,8 @@ def spline_conv(x,
output = x[index[1]] output = x[index[1]]
# Get B-spline basis products and weight indices for each edge. # Get B-spline basis products and weight indices for each edge.
basis, weight_index = spline_bases(pseudo, kernel_size, is_open_spline, basis, weight_index = spline_basis(degree, pseudo, kernel_size,
degree) is_open_spline, weight.size(0))
# Weight gathered features based on B-spline basis and trainable weights. # Weight gathered features based on B-spline basis and trainable weights.
output = spline_weighting(output, weight, basis, weight_index) output = spline_weighting(output, weight, basis, weight_index)
......
...@@ -3,6 +3,8 @@ from torch.autograd import Function ...@@ -3,6 +3,8 @@ from torch.autograd import Function
from .._ext import ffi from .._ext import ffi
degrees = {1: 'linear', 2: 'quadric', 3: 'cubic'}
def get_func(name, tensor): def get_func(name, tensor):
typename = type(tensor).__name__.replace('Tensor', '') typename = type(tensor).__name__.replace('Tensor', '')
...@@ -11,9 +13,19 @@ def get_func(name, tensor): ...@@ -11,9 +13,19 @@ def get_func(name, tensor):
return func return func
def spline_bases(pseudo, kernel_size, is_open_spline, degree): def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
# raise NotImplementedError for degree > 3 degree = degrees.get(degree)
pass if degree is None:
raise NotImplementedError('Basis computation not implemented for '
'specified B-spline degree')
s = (degree + 1)**kernel_size.size(0)
basis = pseudo.new(pseudo.size(0), s)
weight_index = kernel_size.new(pseudo.size(0), s)
func = get_func('basis_{}', degree, pseudo)
func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
return basis, weight_index
def spline_weighting_forward(x, weight, basis, weight_index): def spline_weighting_forward(x, weight, basis, weight_index):
......
void spline_linear_Float (THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open); void spline_basis_linear_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_linear_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open); void spline_basis_linear_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_quadratic_Float (THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_quadratic_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_cubic_Float (THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_cubic_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open);
...@@ -2,34 +2,62 @@ ...@@ -2,34 +2,62 @@
#define TH_GENERIC_FILE "generic/cpu.c" #define TH_GENERIC_FILE "generic/cpu.c"
#else #else
void spline_(linear)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { void spline_(basis_linear)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THTensor *kernel_size, THByteTensor *is_open_spline, int K) {
// s = (m+1)^d int64_t k, s, S, d, D;
// amount: E x s real value;
// index: E x s D = THTensor_(size)(pseudo, 1);
// input: E x d S = THLongTEnsor_size(weight_index, 1);
// kernel: d TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EX_EXCEPT_DIM,
// open: d for (s = 0; s < S; s++) {
// /* k = K; */
int64_t i, d; /* b = 1; i = 0; */
int64_t E = THLongTensor_size(index, 0);
int64_t K = THLongTensor_size(index, 1); for (d = 0; d < D; d++) {
int64_t D = THLongTensor_size(kernel, 0); /* k /= kernel_size[d]; */
for (i = 0; i < E * K; i++) {
for (d = 0; d < D; d++) {
}
}
}
void spline_(quadratic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) {
int64_t i;
for (i = 0; i < THLongTensor_size(input, dim); i++) {
}
}
void spline_(cubic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { /* value = *(pseudo_data + d * pseudo_stride) * (kernel_size[d] - is_open_spline[d]); */
int64_t i;
for (i = 0; i < THLongTensor_size(input, dim); i++) { /* int bot = int64_t(value); */
} /* int top = (bot + 1) % kernel_size[d]; */
/* bot %= kernel_size[d]; */
}
basis_data[s * basis_stride] = 1;
weight_index[s * weight_index_stride] = 2;
})
} }
/* void spline_(linear)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
/* // s = (m+1)^d */
/* // amount: E x s */
/* // index: E x s */
/* // input: E x d */
/* // kernel: d */
/* // open: d */
/* // */
/* int64_t i, d; */
/* int64_t E = THLongTensor_size(index, 0); */
/* int64_t K = THLongTensor_size(index, 1); */
/* int64_t D = THLongTensor_size(kernel, 0); */
/* for (i = 0; i < E * K; i++) { */
/* for (d = 0; d < D; d++) { */
/* } */
/* } */
/* } */
/* void spline_(quadratic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
/* int64_t i; */
/* for (i = 0; i < THLongTensor_size(input, dim); i++) { */
/* } */
/* } */
/* void spline_(cubic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
/* int64_t i; */
/* for (i = 0; i < THLongTensor_size(input, dim); i++) { */
/* } */
/* } */
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment