Commit cf245f4f authored by rusty1s's avatar rusty1s
Browse files

comment

parent ab832e02
...@@ -37,6 +37,7 @@ def spline_weighting_forward(x, weight, basis, weight_index): ...@@ -37,6 +37,7 @@ def spline_weighting_forward(x, weight, basis, weight_index):
def spline_weighting_backward(grad_output, x, weight, basis, weight_index): def spline_weighting_backward(grad_output, x, weight, basis, weight_index):
grad_input = x.new(x.size(0), weight.size(1)) grad_input = x.new(x.size(0), weight.size(1))
# grad_weight computation via `atomic_add` => Initialize with zeros.
grad_weight = x.new(weight.size()).fill_(0) grad_weight = x.new(weight.size()).fill_(0)
func = get_func('weighting_backward', x) func = get_func('weighting_backward', x)
func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index) func(grad_input, grad_weight, grad_output, x, weight, basis, weight_index)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment