Commit 865ef24d authored by rusty1s's avatar rusty1s
Browse files

Backward pass on GPU complete

parent 60ab8eea
......@@ -48,15 +48,37 @@ __global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gra
}
}
template<typename T>
__global__ void weightingBackwardWeightKernel(TensorInfo<T> self, TensorInfo<T> gradOutput,
                                              TensorInfo<T> src, TensorInfo<T> basis,
                                              TensorInfo<int64_t> weightIndex, int n) {
  // Weight-gradient scatter: one kernel-loop iteration per gradOutput element
  // (i flattens [example e, output channel mOut]). For every spline sample s of
  // example e, accumulate grad * basis * src into the weight slot picked by
  // weightIndex. atomicAdd is required because different examples can map to
  // the same weight index wi.
  KERNEL_LOOP(i, n) {
    const ptrdiff_t e = i / gradOutput.size[1];
    const ptrdiff_t mOut = i % gradOutput.size[1];
    // Gradient flowing into output element (e, mOut); invariant over s and mIn.
    T g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
    for (ptrdiff_t s = 0; s < weightIndex.size[1]; s++) {
      T b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
      int64_t wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
      for (ptrdiff_t mIn = 0; mIn < src.size[1]; mIn++) {
        // value = src[e, mIn] * basis[e, s] * grad[e, mOut], computed via
        // THCNumerics so half precision works; multiply order kept as-is.
        T v = src.data[e * src.stride[0] + mIn * src.stride[1]];
        v = THCNumerics<T>::mul(v, b);
        v = THCNumerics<T>::mul(v, g);
        atomicAdd(&self.data[wi * self.stride[0] + mIn * self.stride[1] + mOut * self.stride[2]], v);
      }
    }
  }
}
template<typename T>
__global__ void weightingBackwardBasisKernel(TensorInfo<T> self, TensorInfo<T> gradOutput,
TensorInfo<T> src, TensorInfo<T> weight,
TensorInfo<int64_t> weightIndex, int n) {
KERNEL_LOOP(i, n) {
ptrdiff_t e = i / gradOutput.size[1], mOut = i % gradOutput.size[1], s, mIn;
T v, g, tmp;
T v, tmp;
int64_t wi;
g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
T g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
for (s = 0; s < weightIndex.size[1]; s++) {
v = ScalarConvert<int, T>::to(0);
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
......
......@@ -39,6 +39,18 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
                                         THCTensor *src, THCTensor *basis,
                                         THCudaLongTensor *weightIndex) {
  // Host-side launcher for weightingBackwardWeightKernel: verify device
  // placement, zero the weight-gradient accumulator, build TensorInfo
  // descriptors, and launch one kernel-loop element per gradOutput entry.
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, src, basis, weightIndex));

  // self is accumulated into with atomicAdd, so it must start at zero.
  THCTensor_(fill)(state, self, ScalarConvert<int, real>::to(0));

  TensorInfo<real> tiSelf = THCTensor_(getTensorInfo)(state, self);
  TensorInfo<real> tiGradOutput = THCTensor_(getTensorInfo)(state, gradOutput);
  TensorInfo<real> tiSrc = THCTensor_(getTensorInfo)(state, src);
  TensorInfo<real> tiBasis = THCTensor_(getTensorInfo)(state, basis);
  TensorInfo<int64_t> tiWeightIndex = THCudaLongTensor_getTensorInfo(state, weightIndex);

  // Thread count is the number of gradOutput elements (examples * out channels).
  KERNEL_REAL_RUN(weightingBackwardWeightKernel, THCTensor_(nElement)(state, gradOutput), tiSelf,
                  tiGradOutput, tiSrc, tiBasis, tiWeightIndex);
}
void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput,
......
......@@ -72,9 +72,9 @@ def test_spline_basis_backward_gpu():
pseudo = torch.cuda.DoubleTensor(4, 2).uniform_(0, 1)
basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline)
src = Variable(src, requires_grad=False)
weight = Variable(weight, requires_grad=False)
basis = Variable(basis, requires_grad=True)
src = Variable(src, requires_grad=True)
weight = Variable(weight, requires_grad=True)
basis = Variable(basis, requires_grad=False)
op = SplineWeighting(weight_index)
assert gradcheck(op, (src, weight, basis), eps=1e-6, atol=1e-4) is True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment