Commit 2d394e78 authored by rusty1s's avatar rusty1s
Browse files

first fw try

parent 3a07cc5e
...@@ -51,6 +51,25 @@ void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor ...@@ -51,6 +51,25 @@ void spline_(basis_cubic)(THTensor *basis, THLongTensor *weight_index, THTensor
} }
void spline_(weighting_fw)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_fw)(THTensor *output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset;
int64_t M_out = THTensor_(size)(output, 1);
int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i; real b, value;
TH_TENSOR_DIM_APPLY4(real, output, real, input, real, basis, int64_t, weight_index, 1,
for (m_out = 0; m_out < M_out; m_out++) {
value = 0;
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) {
value += b * *(weight_data + i * M_in * M_out + m_in * M_in + m_out) * *(input_data + m_in * input_stride);
}
}
output_data[m_out * output_stride] = value;
}
)
} }
void spline_(weighting_bw)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) { void spline_(weighting_bw)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment