Commit 5ba5c620 authored by rusty1s's avatar rusty1s
Browse files

backward bugfix

parent 3472bc29
...@@ -6,8 +6,7 @@ from torch_scatter import scatter_max_, scatter_max ...@@ -6,8 +6,7 @@ from torch_scatter import scatter_max_, scatter_max
from .utils import tensor_strs, Tensor from .utils import tensor_strs, Tensor
# @pytest.mark.parametrize('str', tensor_strs) @pytest.mark.parametrize('str', tensor_strs)
@pytest.mark.parametrize('str', ['IntTensor'])
def test_scatter_mean(str): def test_scatter_mean(str):
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]] input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]] index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
...@@ -25,23 +24,14 @@ def test_scatter_mean(str): ...@@ -25,23 +24,14 @@ def test_scatter_mean(str):
assert output.tolist() == expected_output assert output.tolist() == expected_output
assert output_index.tolist() == expected_output_index assert output_index.tolist() == expected_output_index
# output = Variable(output).fill_(0) output = Variable(output).fill_(0)
# index = Variable(index) index = Variable(index)
# input = Variable(input, requires_grad=True) input = Variable(input, requires_grad=True)
# scatter_max_(output, index, input, dim=1) scatter_max_(output, index, input, dim=1)
# grad_output = [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]] grad_output = [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]]
# grad_output = Tensor(str, grad_output) grad_output = Tensor(str, grad_output)
expected_grad_input = [[50, 60, 0, 30, 40], [0, 15, 0, 35, 25]]
# output.backward(grad_output) output.backward(grad_output)
# assert index.data.tolist() == input.grad.data.tolist() assert input.grad.data.tolist() == expected_grad_input
# output = Variable(torch.FloatTensor([0, 0, 0, 0, 0]))
index = Variable(torch.LongTensor([3, 4, 4, 2, 1]))
input = Variable(torch.IntTensor([1, 2, 3, 4, 5]), requires_grad=True)
output, output_index = scatter_max(index, input)
# print(output, output_index)
# print(output_index)
output.backward(torch.IntTensor([10, 20, 30, 40]))
# print(input.grad)
...@@ -66,14 +66,9 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *inp ...@@ -66,14 +66,9 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *inp
} }
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_index) { void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_index) {
int64_t idx;
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_index, dim, TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_index, dim,
for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) { for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
idx = *(index_data + i * index_stride); if (grad_index_data[index_data[i]] == i) output_data[i] = grad_data[index_data[i]];
/* if (grad_index_data[index_data[i]] == i) { */
printf("i: %lli, idx: %lli grad_index: %i grad: %i \n", i, idx, *(grad_index_data + idx * grad_index_stride), *(grad_data + idx * grad_stride));
/* output_data[i] = grad_data[idx]; */
/* } */
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment