Commit 99a4ff83 authored by rusty1s's avatar rusty1s
Browse files

rename

parent de11bfdf
......@@ -55,16 +55,16 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
int64_t M_out = THTensor_(size)(output, 1);
int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i; real b, value;
int64_t m_out, m_in, s, w_idx; real b, value;
TH_TENSOR_DIM_APPLY4(real, output, real, input, real, basis, int64_t, weight_index, 1,
for (m_out = 0; m_out < M_out; m_out++) {
value = 0;
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride);
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) {
value += b * *(weight_data + i * M_in * M_out + m_in * M_out + m_out) * *(input_data + m_in * input_stride);
value += b * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out) * *(input_data + m_in * input_stride);
}
}
output_data[m_out * output_stride] = value;
......@@ -78,19 +78,19 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
int64_t M_out = THTensor_(size)(grad_output, 1);
int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i, w_idx; real g_in, value, b, g_out;
int64_t m_out, m_in, s, w_idx, idx; real g_in, value, b, g_out;
TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1,
for (m_in = 0; m_in < M_in; m_in++) {
g_in = 0; value = *(input_data + m_in * input_stride);
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride);
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_out = 0; m_out < M_out; m_out++) {
w_idx = i * M_in * M_out + m_in * M_out + m_out;
idx = w_idx * M_in * M_out + m_in * M_out + m_out;
g_out = *(grad_output_data + m_out * grad_output_stride);
grad_weight_data[w_idx] += b * g_out * value;
g_in += b * g_out * *(weight_data + w_idx);
grad_weight_data[idx] += b * g_out * value;
g_in += b * g_out * *(weight_data + idx);
}
}
grad_input_data[m_in] = g_in;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment