Commit de11bfdf authored by rusty1s's avatar rusty1s
Browse files

swap m_in/m_out loop

parent 1de99c93
...@@ -17,7 +17,6 @@ def spline_conv(x, ...@@ -17,7 +17,6 @@ def spline_conv(x,
print('TODO: Degree of 0') print('TODO: Degree of 0')
print('TODO: Kernel size of 1') print('TODO: Kernel size of 1')
print('swap M_in and M_out in backward implementation')
n, e = x.size(0), edge_index.size(1) n, e = x.size(0), edge_index.size(1)
K, m_in, m_out = weight.size() K, m_in, m_out = weight.size()
......
...@@ -36,7 +36,7 @@ def spline_weighting_forward(x, weight, basis, weight_index): ...@@ -36,7 +36,7 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def spline_weighting_backward(grad_output, x, weight, basis, weight_index): def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
grad_input = x.new(x.size(0), weight.size(1)).fill_(0) grad_input = x.new(x.size(0), weight.size(1))
grad_weight = x.new(weight.size()).fill_(0) grad_weight = x.new(weight.size()).fill_(0)
func = get_func('weighting_backward', x) func = get_func('weighting_backward', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index) func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
......
...@@ -78,20 +78,22 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH ...@@ -78,20 +78,22 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
int64_t M_out = THTensor_(size)(grad_output, 1); int64_t M_out = THTensor_(size)(grad_output, 1);
int64_t M_in = THTensor_(size)(input, 1); int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1); int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i, w_idx; real g, b; int64_t m_out, m_in, s, i, w_idx; real g_in, value, b, g_out;
TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1, TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1,
for (m_out = 0; m_out < M_out; m_out++) { for (m_in = 0; m_in < M_in; m_in++) {
g = *(grad_output_data + m_out * grad_output_stride); g_in = 0; value = *(input_data + m_in * input_stride);
for (s = 0; s < S; s++) { for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride); b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride); i = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) { for (m_out = 0; m_out < M_out; m_out++) {
w_idx = i * M_in * M_out + m_in * M_out + m_out; w_idx = i * M_in * M_out + m_in * M_out + m_out;
grad_input_data[m_in] += b * g * *(weight_data + w_idx); g_out = *(grad_output_data + m_out * grad_output_stride);
grad_weight_data[w_idx] += b * g * *(input_data + m_in * input_stride); grad_weight_data[w_idx] += b * g_out * value;
g_in += b * g_out * *(weight_data + w_idx);
} }
} }
grad_input_data[m_in] = g_in;
} }
) )
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment