Commit de11bfdf authored by rusty1s's avatar rusty1s
Browse files

swap m_in/m_out loop

parent 1de99c93
......@@ -17,7 +17,6 @@ def spline_conv(x,
print('TODO: Degree of 0')
print('TODO: Kernel size of 1')
print('swap M_in and M_out in backward implementation')
n, e = x.size(0), edge_index.size(1)
K, m_in, m_out = weight.size()
......
......@@ -36,7 +36,7 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
grad_input = x.new(x.size(0), weight.size(1)).fill_(0)
grad_input = x.new(x.size(0), weight.size(1))
grad_weight = x.new(weight.size()).fill_(0)
func = get_func('weighting_backward', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
......
......@@ -78,20 +78,22 @@ void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, TH
int64_t M_out = THTensor_(size)(grad_output, 1);
int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i, w_idx; real g, b;
int64_t m_out, m_in, s, i, w_idx; real g_in, value, b, g_out;
TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1,
for (m_out = 0; m_out < M_out; m_out++) {
g = *(grad_output_data + m_out * grad_output_stride);
for (m_in = 0; m_in < M_in; m_in++) {
g_in = 0; value = *(input_data + m_in * input_stride);
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) {
for (m_out = 0; m_out < M_out; m_out++) {
w_idx = i * M_in * M_out + m_in * M_out + m_out;
grad_input_data[m_in] += b * g * *(weight_data + w_idx);
grad_weight_data[w_idx] += b * g * *(input_data + m_in * input_stride);
g_out = *(grad_output_data + m_out * grad_output_stride);
grad_weight_data[w_idx] += b * g_out * value;
g_in += b * g_out * *(weight_data + w_idx);
}
}
grad_input_data[m_in] = g_in;
}
)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment