Commit f013eb53 authored by rusty1s
Browse files

index backward impl

parent 4134714f
@@ -49,6 +49,12 @@ def test_scatter_cuda_max(str):
     output, index, input = output.cuda(), index.cuda(), input.cuda()
     _, arg_output = scatter_max_(output, index, input, dim=1)
+    assert output.cpu().tolist() == expected_output
+    assert arg_output.cpu().tolist() == expected_arg_output
+    output, arg_output = scatter_max(index, input, dim=1)
+    assert output.cpu().tolist() == expected_output
+    assert arg_output.cpu().tolist() == expected_arg_output
     output = Variable(output).fill_(0)
     index = Variable(index)
@@ -60,7 +66,4 @@ def test_scatter_cuda_max(str):
     expected_grad_input = [[50, 60, 0, 30, 40], [0, 15, 0, 35, 25]]
     output.backward(grad_output)
-    print(input.grad.data)
-    # print(input)
-    # assert input.grad.data.tolist() == expected_grad_input
+    assert input.grad.data.cpu().tolist() == expected_grad_input
@@ -71,9 +71,9 @@ __global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
 template<typename Real, int Dims>
 __global__ void indexBackwardKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> grad, TensorInfo<int64_t> arg, const int dim, const int n) {
   KERNEL_LOOP(i, n) {
-    /* int outputOffset = 0; int indexOffset = 0; int gradOffset = 0; int argOffset = 0; */
-    /* IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, grad, &gradOffset, output, &outputOffset, arg, &argOffset); */
-    /* if (eq(input.data[inputOffset], output.data[outputOffset])) arg.data[argOffset] = inputOffset % input.size[dim]; */
+    int outputOffset = 0; int indexOffset = 0; int gradOffset = 0; int argOffset = 0;
+    IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, output, &outputOffset, grad, &gradOffset, arg, &argOffset);
+    if (arg.data[argOffset] == outputOffset % output.size[dim]) output.data[outputOffset] = grad.data[gradOffset];
   }
 }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment