Commit 99a4ff83 authored by rusty1s's avatar rusty1s
Browse files

rename

parent de11bfdf
...@@ -55,16 +55,16 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei ...@@ -55,16 +55,16 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
int64_t M_out = THTensor_(size)(output, 1); int64_t M_out = THTensor_(size)(output, 1);
int64_t M_in = THTensor_(size)(input, 1); int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1); int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i; real b, value; int64_t m_out, m_in, s, w_idx; real b, value;
TH_TENSOR_DIM_APPLY4(real, output, real, input, real, basis, int64_t, weight_index, 1, TH_TENSOR_DIM_APPLY4(real, output, real, input, real, basis, int64_t, weight_index, 1,
for (m_out = 0; m_out < M_out; m_out++) { for (m_out = 0; m_out < M_out; m_out++) {
value = 0; value = 0;
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride); b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride); w_idx = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
value += b * *(weight_data + i * M_in * M_out + m_in * M_out + m_out) * *(input_data + m_in * input_stride); value += b * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out) * *(input_data + m_in * input_stride);
} }
} }
output_data[m_out * output_stride] = value; output_data[m_out * output_stride] = value;
...@@ -78,19 +78,19 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH ...@@ -78,19 +78,19 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
int64_t M_out = THTensor_(size)(grad_output, 1); int64_t M_out = THTensor_(size)(grad_output, 1);
int64_t M_in = THTensor_(size)(input, 1); int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1); int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i, w_idx; real g_in, value, b, g_out; int64_t m_out, m_in, s, w_idx, idx; real g_in, value, b, g_out;
TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1, TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1,
for (m_in = 0; m_in < M_in; m_in++) { for (m_in = 0; m_in < M_in; m_in++) {
g_in = 0; value = *(input_data + m_in * input_stride); g_in = 0; value = *(input_data + m_in * input_stride);
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride); b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride); w_idx = *(weight_index_data + s * weight_index_stride);
for (m_out = 0; m_out < M_out; m_out++) { for (m_out = 0; m_out < M_out; m_out++) {
w_idx = i * M_in * M_out + m_in * M_out + m_out; idx = w_idx * M_in * M_out + m_in * M_out + m_out;
g_out = *(grad_output_data + m_out * grad_output_stride); g_out = *(grad_output_data + m_out * grad_output_stride);
grad_weight_data[w_idx] += b * g_out * value; grad_weight_data[idx] += b * g_out * value;
g_in += b * g_out * *(weight_data + w_idx); g_in += b * g_out * *(weight_data + idx);
} }
} }
grad_input_data[m_in] = g_in; grad_input_data[m_in] = g_in;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment