"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e4859c4563c3a0c3d7a34e7985ff2b8b41580b9c"
Commit f013eb53 authored by rusty1s's avatar rusty1s
Browse files

index backward impl

parent 4134714f
......@@ -49,6 +49,12 @@ def test_scatter_cuda_max(str):
output, index, input = output.cuda(), index.cuda(), input.cuda()
_, arg_output = scatter_max_(output, index, input, dim=1)
assert output.cpu().tolist() == expected_output
assert arg_output.cpu().tolist() == expected_arg_output
output, arg_output = scatter_max(index, input, dim=1)
assert output.cpu().tolist() == expected_output
assert arg_output.cpu().tolist() == expected_arg_output
output = Variable(output).fill_(0)
index = Variable(index)
......@@ -60,7 +66,4 @@ def test_scatter_cuda_max(str):
expected_grad_input = [[50, 60, 0, 30, 40], [0, 15, 0, 35, 25]]
output.backward(grad_output)
print(input.grad.data)
# print(input)
# assert input.grad.data.tolist() == expected_grad_input
assert input.grad.data.cpu().tolist() == expected_grad_input
......@@ -71,9 +71,9 @@ __global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
template<typename Real, int Dims>
__global__ void indexBackwardKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> grad, TensorInfo<int64_t> arg, const int dim, const int n) {
KERNEL_LOOP(i, n) {
/* int outputOffset = 0; int indexOffset = 0; int gradOffset = 0; int argOffset = 0; */
/* IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, grad, &gradOffset, output, &outputOffset, arg, &argOffset); */
/* if (eq(input.data[inputOffset], output.data[outputOffset])) arg.data[argOffset] = inputOffset % input.size[dim]; */
int outputOffset = 0; int indexOffset = 0; int gradOffset = 0; int argOffset = 0;
IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, output, &outputOffset, grad, &gradOffset, arg, &argOffset);
if (arg.data[argOffset] == outputOffset % output.size[dim]) output.data[outputOffset] = grad.data[gradOffset];
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment