Commit e3940621 authored by rusty1s's avatar rusty1s
Browse files

grad test

parent 29e28847
......@@ -24,18 +24,8 @@ def test_scatter_add():
input = Variable(input, requires_grad=True)
scatter_add_(output, index, input, dim=1)
c = output.sum()
c.backward()
grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
grad_output = torch.FloatTensor(grad_output)
# # a = input * 2
# # b = output * 2
# a = input * 2
# b = output * 2
# ScatterAdd(1)(b, index, a)
# # b.scatter_add_(1, index, a)
# c = b.sum()
# c.backward()
# print(input.grad)
# print(output.grad)
output.backward(grad_output)
assert_equal(index.data.tolist(), input.grad.data.tolist())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment