Commit 865ef24d authored by rusty1s's avatar rusty1s
Browse files

backward gpu complete

parent 60ab8eea
...@@ -48,15 +48,37 @@ __global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gra ...@@ -48,15 +48,37 @@ __global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gra
} }
} }
template<typename T>
__global__ void weightingBackwardWeightKernel(TensorInfo<T> self, TensorInfo<T> gradOutput,
TensorInfo<T> src, TensorInfo<T> basis,
TensorInfo<int64_t> weightIndex, int n) {
KERNEL_LOOP(i, n) {
ptrdiff_t e = i / gradOutput.size[1], mOut = i % gradOutput.size[1], s, mIn;
T b, v;
int64_t wi;
T g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
for (s = 0; s < weightIndex.size[1]; s++) {
b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
for (mIn = 0; mIn < src.size[1]; mIn++) {
v = src.data[e * src.stride[0] + mIn * src.stride[1]];
v = THCNumerics<T>::mul(v, b);
v = THCNumerics<T>::mul(v, g);
atomicAdd(&self.data[wi * self.stride[0] + mIn * self.stride[1] + mOut * self.stride[2]], v);
}
}
}
}
template<typename T> template<typename T>
__global__ void weightingBackwardBasisKernel(TensorInfo<T> self, TensorInfo<T> gradOutput, __global__ void weightingBackwardBasisKernel(TensorInfo<T> self, TensorInfo<T> gradOutput,
TensorInfo<T> src, TensorInfo<T> weight, TensorInfo<T> src, TensorInfo<T> weight,
TensorInfo<int64_t> weightIndex, int n) { TensorInfo<int64_t> weightIndex, int n) {
KERNEL_LOOP(i, n) { KERNEL_LOOP(i, n) {
ptrdiff_t e = i / gradOutput.size[1], mOut = i % gradOutput.size[1], s, mIn; ptrdiff_t e = i / gradOutput.size[1], mOut = i % gradOutput.size[1], s, mIn;
T v, g, tmp; T v, tmp;
int64_t wi; int64_t wi;
g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]]; T g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
for (s = 0; s < weightIndex.size[1]; s++) { for (s = 0; s < weightIndex.size[1]; s++) {
v = ScalarConvert<int, T>::to(0); v = ScalarConvert<int, T>::to(0);
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]]; wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
......
...@@ -39,6 +39,18 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso ...@@ -39,6 +39,18 @@ void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTenso
void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput, void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *basis, THCTensor *src, THCTensor *basis,
THCudaLongTensor *weightIndex) { THCudaLongTensor *weightIndex) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, src, basis, weightIndex));
THCTensor_(fill)(state, self, ScalarConvert<int, real>::to(0));
TensorInfo<real> selfInfo = THCTensor_(getTensorInfo)(state, self);
TensorInfo<real> gradOutputInfo = THCTensor_(getTensorInfo)(state, gradOutput);
TensorInfo<real> srcInfo = THCTensor_(getTensorInfo)(state, src);
TensorInfo<real> basisInfo = THCTensor_(getTensorInfo)(state, basis);
TensorInfo<int64_t> weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex);
KERNEL_REAL_RUN(weightingBackwardWeightKernel, THCTensor_(nElement)(state, gradOutput), selfInfo,
gradOutputInfo, srcInfo, basisInfo, weightIndexInfo);
} }
void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput, void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput,
......
...@@ -72,9 +72,9 @@ def test_spline_basis_backward_gpu(): ...@@ -72,9 +72,9 @@ def test_spline_basis_backward_gpu():
pseudo = torch.cuda.DoubleTensor(4, 2).uniform_(0, 1) pseudo = torch.cuda.DoubleTensor(4, 2).uniform_(0, 1)
basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline) basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline)
src = Variable(src, requires_grad=False) src = Variable(src, requires_grad=True)
weight = Variable(weight, requires_grad=False) weight = Variable(weight, requires_grad=True)
basis = Variable(basis, requires_grad=True) basis = Variable(basis, requires_grad=False)
op = SplineWeighting(weight_index) op = SplineWeighting(weight_index)
assert gradcheck(op, (src, weight, basis), eps=1e-6, atol=1e-4) is True assert gradcheck(op, (src, weight, basis), eps=1e-6, atol=1e-4) is True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment