Commit fbd0ffaf authored by rusty1s's avatar rusty1s
Browse files

added backward

parent 3fbeeabc
......@@ -36,8 +36,8 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
grad_input = x.new(x.size(0), weight.size(1))
grad_weight = x.new(weight)
grad_input = x.new(x.size(0), weight.size(1)).fill_(0)
grad_weight = x.new(weight).fill_(0)
func = get_func('weighting_backward', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
return grad_input, grad_weight
......
......@@ -73,7 +73,27 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
}
void spline_(weighting_backward)(THTensor *grad_input, THTensor *grad_weight, THTensor *grad_output, THTensor *input, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset;
real *grad_weight_data = grad_weight->storage->data + grad_weight->storageOffset;
int64_t M_out = THTensor_(size)(grad_input, 1);
int64_t M_in = THTensor_(size)(input, 1);
int64_t S = THLongTensor_size(weight_index, 1);
int64_t m_out, m_in, s, i, w_idx; real g, b;
TH_TENSOR_DIM_APPLY5(real, grad_input, real, grad_output, real, input, real, basis, int64_t, weight_index, 1,
for (m_out = 0; m_out < M_out; m_out++) {
g = *(grad_output_data + m_out * grad_output_stride);
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
i = *(weight_index_data + s * weight_index_stride);
for (m_in = 0; m_in < M_in; m_in++) {
w_idx = i * M_in * M_out + m_in * M_out + m_out;
grad_input_data[m_in] += b * g * *(weight_data + w_idx);
grad_weight_data[w_idx] += b * g * *(input_data + m_in * input_stride);
}
}
}
)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment