"docs/vscode:/vscode.git/clone" did not exist on "449cc090254cb8eb878a3581e2cf7471838767e3"
Commit 4134714f authored by rusty1s's avatar rusty1s
Browse files

bugfixes

parent ac0825f5
......@@ -48,7 +48,19 @@ def test_scatter_cuda_max(str):
expected_arg_output = [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
output, index, input = output.cuda(), index.cuda(), input.cuda()
_, arg_output = scatter_max_(output, index, input, dim=1)
print(output)
print(arg_output)
output = Variable(output).fill_(0)
index = Variable(index)
input = Variable(input, requires_grad=True)
scatter_max_(output, index, input, dim=1)
grad_output = [[10, 20, 30, 40, 50, 60], [15, 25, 35, 45, 55, 65]]
grad_output = Tensor(str, grad_output).cuda()
expected_grad_input = [[50, 60, 0, 30, 40], [0, 15, 0, 35, 25]]
output.backward(grad_output)
print(input.grad.data)
# print(input)
# assert input.grad.data.tolist() == expected_grad_input
......@@ -41,7 +41,8 @@ def _scatter(name, dim, *data):
def index_backward(dim, index, grad, arg_grad):
typename = type(grad).__name__.replace('Tensor', '')
func = getattr(ffi, 'index_backward_{}'.format(typename))
cuda = 'cuda_' if grad.is_cuda else ''
func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
output = grad.new(index.size()).fill_(0)
func(dim, output, index, grad, arg_grad)
return output
......
......@@ -64,7 +64,14 @@ void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor
void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg) {
thc_(check)(state, output, index, grad);
printf("index_backward");
const int n = THCudaLongTensor_nElement(state, index);
TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
TensorInfo<real> gradInfo = thc_(getTensorInfo)(state, grad);
TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
KERNEL_RUN(indexBackwardKernel, indexInfo.dims, n, outputInfo, indexInfo, gradInfo, argInfo, dim)
}
#endif
......@@ -68,5 +68,14 @@ __global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
}
}
template<typename Real, int Dims>
__global__ void indexBackwardKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> grad, TensorInfo<int64_t> arg, const int dim, const int n) {
KERNEL_LOOP(i, n) {
/* int outputOffset = 0; int indexOffset = 0; int gradOffset = 0; int argOffset = 0; */
/* IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, grad, &gradOffset, output, &outputOffset, arg, &argOffset); */
/* if (eq(input.data[inputOffset], output.data[outputOffset])) arg.data[argOffset] = inputOffset % input.size[dim]; */
}
}
#include "generic/kernel.cu"
#include "THCGenerateAllTypes.h"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment