Commit 3472bc29 authored by rusty1s

debug

parent 143a57ec
@@ -7,7 +7,7 @@ from .utils import tensor_strs, Tensor
 # @pytest.mark.parametrize('str', tensor_strs)
-@pytest.mark.parametrize('str', ['DoubleTensor'])
+@pytest.mark.parametrize('str', ['IntTensor'])
 def test_scatter_mean(str):
     input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
     index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
@@ -25,22 +25,23 @@ def test_scatter_mean(str):
     assert output.tolist() == expected_output
     assert output_index.tolist() == expected_output_index
-    output = Variable(output).fill_(0)
-    index = Variable(index)
-    input = Variable(input, requires_grad=True)
-    scatter_max_(output, index, input, dim=1)
-    grad_output = [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]]
-    grad_output = Tensor(str, grad_output)
-    output.backward(grad_output)
+    # output = Variable(output).fill_(0)
+    # index = Variable(index)
+    # input = Variable(input, requires_grad=True)
+    # scatter_max_(output, index, input, dim=1)
+    # grad_output = [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]]
+    # grad_output = Tensor(str, grad_output)
+    # output.backward(grad_output)
     # assert index.data.tolist() == input.grad.data.tolist()
     # output = Variable(torch.FloatTensor([0, 0, 0, 0, 0]))
     index = Variable(torch.LongTensor([3, 4, 4, 2, 1]))
-    input = Variable(torch.FloatTensor([1, 2, 3, 4, 5]), requires_grad=True)
+    input = Variable(torch.IntTensor([1, 2, 3, 4, 5]), requires_grad=True)
     output, output_index = scatter_max(index, input)
+    # print(output, output_index)
     # print(output_index)
-    output.backward(torch.FloatTensor([10, 20, 30, 40]))
-    print(input.grad)
+    output.backward(torch.IntTensor([10, 20, 30, 40]))
+    # print(input.grad)
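
For reference, a minimal sketch of the scatter-max semantics this test exercises (an assumption for illustration, not this repository's implementation): each output slot keeps the maximum of the input values routed to it, output_index records which input position produced that maximum, and the backward pass sends each output gradient only to that recorded position.

import torch

def scatter_max_ref(index, src):
    # Hypothetical reference: output size inferred as max(index) + 1.
    size = int(index.max()) + 1
    out = torch.full((size,), float('-inf'))
    arg = torch.full((size,), -1, dtype=torch.long)
    for i in range(src.numel()):
        j = int(index[i])
        if float(src[i]) > float(out[j]):
            out[j] = src[i]
            arg[j] = i  # remember which input produced the maximum
    return out, arg

def scatter_max_grad_ref(grad_out, arg, in_size):
    # Gradient is routed only to the argmax positions; all others stay zero.
    grad_in = torch.zeros(in_size)
    mask = arg >= 0
    grad_in[arg[mask]] = grad_out[mask]
    return grad_in

With index = [3, 4, 4, 2, 1] and input = [1, 2, 3, 4, 5] as in the test, this sketch gives out = [-inf, 5, 4, 1, 3] and arg = [-1, 4, 3, 0, 2]; whether the library fills empty slots with -inf, zero, or a user-supplied output tensor is not shown in this diff.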
@@ -66,11 +66,12 @@ void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input
 }
 void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *grad_index) {
+  int64_t idx;
   TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, grad_index, dim,
     for (int64_t i = 0; i < THLongTensor_size(index, dim); i++) {
-      int64_t idx = *(index_data + i * index_stride);
+      idx = *(index_data + i * index_stride);
       /* if (grad_index_data[index_data[i]] == i) { */
-      /* printf("i: %i, ", i); */
+      printf("i: %lli, idx: %lli grad_index: %i grad: %i \n", i, idx, *(grad_index_data + idx * grad_index_stride), *(grad_data + idx * grad_stride));
       /* output_data[i] = grad_data[idx]; */
       /* } */
   })
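
The added printf traces, for each position i along the scatter dimension, the target index idx, the argmax stored in grad_index, and the incoming gradient value. A rough Python analogue of the disabled loop body (an assumption reconstructed from the commented-out lines, not a verified translation of the TH_TENSOR_DIM_APPLY4 macro):

def index_backward_ref(index, grad, grad_index):
    # Hypothetical analogue: input position i receives grad[idx] only if it is
    # the position recorded as the argmax for its target slot idx = index[i].
    out = [0.0] * len(index)
    for i, idx in enumerate(index):
        if grad_index[idx] == i:
            out[i] = grad[idx]
    return out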