Commit c005e19d authored by rusty1s's avatar rusty1s
Browse files

beginning of basis computation

parent 07804abc
from .functions.spline_conv import spline_conv
__version__ = '0.1.0'
__all__ = ['__version__']
__all__ = ['spline_conv', '__version__']
......@@ -2,7 +2,7 @@ import torch
# from torch.autograd import Variable as Var
from .degree import node_degree
from .utils import spline_bases, spline_weighting
from .utils import spline_basis, spline_weighting
def spline_conv(x,
......@@ -21,8 +21,8 @@ def spline_conv(x,
output = x[index[1]]
# Get B-spline basis products and weight indices for each edge.
basis, weight_index = spline_bases(pseudo, kernel_size, is_open_spline,
degree)
basis, weight_index = spline_basis(degree, pseudo, kernel_size,
is_open_spline, weight.size(0))
# Weight gathered features based on B-spline basis and trainable weights.
output = spline_weighting(output, weight, basis, weight_index)
......
......@@ -3,6 +3,8 @@ from torch.autograd import Function
from .._ext import ffi
degrees = {1: 'linear', 2: 'quadric', 3: 'cubic'}
def get_func(name, tensor):
typename = type(tensor).__name__.replace('Tensor', '')
......@@ -11,9 +13,19 @@ def get_func(name, tensor):
return func
def spline_bases(pseudo, kernel_size, is_open_spline, degree):
# raise NotImplementedError for degree > 3
pass
def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
degree = degrees.get(degree)
if degree is None:
raise NotImplementedError('Basis computation not implemented for '
'specified B-spline degree')
s = (degree + 1)**kernel_size.size(0)
basis = pseudo.new(pseudo.size(0), s)
weight_index = kernel_size.new(pseudo.size(0), s)
func = get_func('basis_{}', degree, pseudo)
func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
return basis, weight_index
def spline_weighting_forward(x, weight, basis, weight_index):
......
void spline_linear_Float (THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_linear_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_quadratic_Float (THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_quadratic_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_cubic_Float (THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_cubic_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open);
void spline_basis_linear_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
void spline_basis_linear_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
......@@ -2,34 +2,62 @@
#define TH_GENERIC_FILE "generic/cpu.c"
#else
void spline_(linear)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) {
// s = (m+1)^d
// amount: E x s
// index: E x s
// input: E x d
// kernel: d
// open: d
//
int64_t i, d;
int64_t E = THLongTensor_size(index, 0);
int64_t K = THLongTensor_size(index, 1);
int64_t D = THLongTensor_size(kernel, 0);
for (i = 0; i < E * K; i++) {
for (d = 0; d < D; d++) {
}
}
}
void spline_(basis_linear)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THTensor *kernel_size, THByteTensor *is_open_spline, int K) {
int64_t k, s, S, d, D;
real value;
D = THTensor_(size)(pseudo, 1);
S = THLongTEnsor_size(weight_index, 1);
TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EX_EXCEPT_DIM,
for (s = 0; s < S; s++) {
/* k = K; */
/* b = 1; i = 0; */
for (d = 0; d < D; d++) {
/* k /= kernel_size[d]; */
void spline_(quadratic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) {
int64_t i;
for (i = 0; i < THLongTensor_size(input, dim); i++) {
}
}
void spline_(cubic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) {
int64_t i;
for (i = 0; i < THLongTensor_size(input, dim); i++) {
}
/* value = *(pseudo_data + d * pseudo_stride) * (kernel_size[d] - is_open_spline[d]); */
/* int bot = int64_t(value); */
/* int top = (bot + 1) % kernel_size[d]; */
/* bot %= kernel_size[d]; */
}
basis_data[s * basis_stride] = 1;
weight_index[s * weight_index_stride] = 2;
})
}
/* void spline_(linear)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
/* // s = (m+1)^d */
/* // amount: E x s */
/* // index: E x s */
/* // input: E x d */
/* // kernel: d */
/* // open: d */
/* // */
/* int64_t i, d; */
/* int64_t E = THLongTensor_size(index, 0); */
/* int64_t K = THLongTensor_size(index, 1); */
/* int64_t D = THLongTensor_size(kernel, 0); */
/* for (i = 0; i < E * K; i++) { */
/* for (d = 0; d < D; d++) { */
/* } */
/* } */
/* } */
/* void spline_(quadratic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
/* int64_t i; */
/* for (i = 0; i < THLongTensor_size(input, dim); i++) { */
/* } */
/* } */
/* void spline_(cubic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
/* int64_t i; */
/* for (i = 0; i < THLongTensor_size(input, dim); i++) { */
/* } */
/* } */
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment