Commit f7d3df7b authored by rusty1s
Browse files

clean up

parent 04dc2518
......@@ -6,8 +6,8 @@
#define THC_TENSOR_BASIS_BACKWARD(NAME, state, self, gradBasis, pseudo, kernelSize, \
isOpenSpline) { \
THCAssertSameGPU( \
THCTensor_(checkGPU)(state, 5, self, gradBasis, pseudo, kernelSize, isOpenSpline)); \
THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradBasis, pseudo, kernelSize, \
isOpenSpline)); \
\
TensorInfo<real> selfInfo = THCTensor_(getTensorInfo)(state, self); \
TensorInfo<real> gradBasisInfo = THCTensor_(getTensorInfo)(state, gradBasis); \
......
......@@ -6,8 +6,8 @@
#define THC_TENSOR_BASIS_FORWARD(NAME, state, basis, weightIndex, pseudo, kernelSize, \
isOpenSpline) { \
THCAssertSameGPU( \
THCTensor_(checkGPU)(state, 5, basis, weightIndex, pseudo, kernelSize, isOpenSpline)); \
THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, basis, weightIndex, pseudo, kernelSize, \
isOpenSpline)); \
\
TensorInfo<real> basisInfo = THCTensor_(getTensorInfo)(state, basis); \
TensorInfo<int64_t> weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex); \
......
......@@ -4,6 +4,18 @@
#include "THCNumerics.cuh"
#include "THCAtomics.cuh"
/*
 * TH_TENSOR_WEIGHTING(NAME, N, TENSOR1, TENSOR2, TENSOR3, TENSOR4, weightIndex)
 *
 * Shared launch boilerplate for the weighting kernels:
 *   1. asserts that the four real tensors and the index tensor pass
 *      THCTensor_(checkGPU),
 *   2. builds a TensorInfo for each of the five tensors,
 *   3. hands kernel NAME, the element count N and the five TensorInfo objects
 *      to KERNEL_REAL_RUN.
 *
 * `state` is not a macro parameter; a THCState pointer with that name must be
 * in scope at the expansion site. The macro expands to a braced block, so the
 * call sites in this code base do not put a semicolon after it.
 *
 * Note: comments must stay outside the macro body because every body line ends
 * in a line continuation.
 */
#define TH_TENSOR_WEIGHTING(NAME, N, TENSOR1, TENSOR2, TENSOR3, TENSOR4, weightIndex) { \
  THCAssertSameGPU( \
      THCTensor_(checkGPU)(state, 5, TENSOR1, TENSOR2, TENSOR3, TENSOR4, weightIndex)); \
  \
  TensorInfo<real> firstInfo = THCTensor_(getTensorInfo)(state, TENSOR1); \
  TensorInfo<real> secondInfo = THCTensor_(getTensorInfo)(state, TENSOR2); \
  TensorInfo<real> thirdInfo = THCTensor_(getTensorInfo)(state, TENSOR3); \
  TensorInfo<real> fourthInfo = THCTensor_(getTensorInfo)(state, TENSOR4); \
  TensorInfo<int64_t> indexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex); \
  \
  KERNEL_REAL_RUN(NAME, N, firstInfo, secondInfo, thirdInfo, fourthInfo, indexInfo); \
}
template<typename T>
__global__ void weightingForwardKernel(TensorInfo<T> self, TensorInfo<T> src, TensorInfo<T> weight,
TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
......@@ -31,22 +43,20 @@ __global__ void weightingBackwardSrcKernel(TensorInfo<T> self, TensorInfo<T> gra
TensorInfo<T> weight, TensorInfo<T> basis,
TensorInfo<int64_t> weightIndex, int n) {
KERNEL_LOOP(i, n) {
ptrdiff_t e = i / gradOutput.size[1], mOut = i % gradOutput.size[1], s, mIn;
T v, b, tmp;
ptrdiff_t e = i / self.size[1], mIn = i % self.size[1], s, mOut;
T v = ScalarConvert<int, T>::to(0), b, tmp;
int64_t wi;
T g = gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]];
for (mIn = 0; mIn < self.size[1]; mIn++) {
v = ScalarConvert<int, T>::to(0);
for (s = 0; s < basis.size[1]; s++) {
b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
tmp = weight.data[wi * weight.stride[0] + mIn * weight.stride[1] + mOut * weight.stride[2]];
for (s = 0; s < basis.size[1]; s++) {
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
for (mOut = 0; mOut < gradOutput.size[1]; mOut++) {
tmp = weight.data[wi * weight.stride[0] + mOut * weight.stride[1] + mIn * weight.stride[2]];
tmp = THCNumerics<T>::mul(tmp, gradOutput.data[e * gradOutput.stride[0] + mOut * gradOutput.stride[1]]);
tmp = THCNumerics<T>::mul(tmp, b);
tmp = THCNumerics<T>::mul(tmp, g);
v = THCNumerics<T>::add(v, tmp);
}
atomicAdd(&self.data[e * self.stride[0] + mIn * self.stride[1]], v);
}
self.data[e * self.stride[0] + mIn * self.stride[1]] = v;
}
}
......
......@@ -5,67 +5,37 @@
/*
 * Forward pass of the weighting operation.
 *
 * All of the launch boilerplate (same-GPU assertion, TensorInfo construction,
 * KERNEL_REAL_RUN) lives in TH_TENSOR_WEIGHTING. The kernel is launched exactly
 * once, with nElement(self) as the element count.
 *
 * Previously this body also carried a hand-written copy of the same
 * boilerplate, so weightingForwardKernel was launched twice per call.
 *
 * state       - THC state; also picked up implicitly by TH_TENSOR_WEIGHTING.
 * self        - tensor passed to the kernel as its first argument; its element
 *               count sizes the launch.
 * src, weight, basis - real tensors forwarded to the kernel in this order.
 * weightIndex - long tensor forwarded to the kernel as its last tensor argument.
 */
void THCTensor_(weightingForward)(THCState *state, THCTensor *self, THCTensor *src,
                                  THCTensor *weight, THCTensor *basis,
                                  THCudaLongTensor *weightIndex) {
  TH_TENSOR_WEIGHTING(weightingForwardKernel, THCTensor_(nElement)(state, self), self, src, weight,
                      basis, weightIndex)
}
/*
 * Backward pass of the weighting operation with respect to the source features.
 *
 * The kernel is launched once, through TH_TENSOR_WEIGHTING, with
 * nElement(self) as the element count and a contiguous copy of `weight` whose
 * dimensions 1 and 2 have been swapped.
 *
 * Previously a second, hand-written launch followed the macro. It used
 * nElement(gradOutput) as the element count against the same transposed
 * weight, so the kernel ran twice with two different work decompositions.
 *
 * state       - THC state; also picked up implicitly by TH_TENSOR_WEIGHTING.
 * self        - output tensor; zero-filled, then written by the kernel.
 * gradOutput, basis - real tensors forwarded to the kernel.
 * weight      - real tensor; only a transposed contiguous copy reaches the
 *               kernel. The caller's tensor is not modified.
 * weightIndex - long tensor forwarded to the kernel.
 */
void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTensor *gradOutput,
                                      THCTensor *weight, THCTensor *basis,
                                      THCudaLongTensor *weightIndex) {
  // Validate device placement before allocating temporaries or touching self.
  // The macro below repeats the check against the transposed copy.
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, weight, basis, weightIndex));

  // Swap dims 1 and 2 and materialise the result contiguously.
  // NOTE(review): presumably so the kernel's innermost weight index is
  // contiguous - confirm against the kernel's indexing.
  THCTensor *tWeight = THCTensor_(newTranspose)(state, weight, 1, 2);
  weight = THCTensor_(newContiguous)(state, tWeight);

  // NOTE(review): kept defensively. It is redundant if the kernel overwrites
  // every element of self, and required if the kernel accumulates into it.
  THCTensor_(fill)(state, self, ScalarConvert<int, real>::to(0));

  TH_TENSOR_WEIGHTING(weightingBackwardSrcKernel, THCTensor_(nElement)(state, self), self,
                      gradOutput, weight, basis, weightIndex)

  // Release both temporaries: the transposed view and its contiguous copy.
  THCTensor_(free)(state, tWeight);
  THCTensor_(free)(state, weight);
}
/*
 * Backward pass of the weighting operation with respect to the weights.
 *
 * `self` is zero-filled and the kernel is then launched exactly once, through
 * TH_TENSOR_WEIGHTING, with nElement(gradOutput) as the element count.
 *
 * Previously the kernel was launched twice after a single zero-fill (once by
 * hand-written boilerplate, once by the macro). If the kernel accumulates into
 * self, as the zero-fill suggests, that doubled the result.
 *
 * state       - THC state; also picked up implicitly by TH_TENSOR_WEIGHTING.
 * self        - output tensor; zero-filled before the launch.
 * gradOutput  - real tensor; its element count sizes the launch.
 * src, basis  - real tensors forwarded to the kernel.
 * weightIndex - long tensor forwarded to the kernel.
 */
void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
                                         THCTensor *src, THCTensor *basis,
                                         THCudaLongTensor *weightIndex) {
  // Validate device placement before self is modified; the macro checks again.
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, src, basis, weightIndex));

  THCTensor_(fill)(state, self, ScalarConvert<int, real>::to(0));

  TH_TENSOR_WEIGHTING(weightingBackwardWeightKernel, THCTensor_(nElement)(state, gradOutput), self,
                      gradOutput, src, basis, weightIndex)
}
/*
 * Backward pass of the weighting operation with respect to the basis values.
 *
 * `self` is zero-filled and the kernel is then launched exactly once, through
 * TH_TENSOR_WEIGHTING, with nElement(gradOutput) as the element count.
 *
 * Previously the kernel was launched twice after a single zero-fill (once by
 * hand-written boilerplate, once by the macro). If the kernel accumulates into
 * self, as the zero-fill suggests, that doubled the result.
 *
 * state       - THC state; also picked up implicitly by TH_TENSOR_WEIGHTING.
 * self        - output tensor; zero-filled before the launch.
 * gradOutput  - real tensor; its element count sizes the launch.
 * src, weight - real tensors forwarded to the kernel.
 * weightIndex - long tensor forwarded to the kernel.
 */
void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput,
                                        THCTensor *src, THCTensor *weight,
                                        THCudaLongTensor *weightIndex) {
  // Validate device placement before self is modified; the macro checks again.
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, src, weight, weightIndex));

  THCTensor_(fill)(state, self, ScalarConvert<int, real>::to(0));

  TH_TENSOR_WEIGHTING(weightingBackwardBasisKernel, THCTensor_(nElement)(state, gradOutput), self,
                      gradOutput, src, weight, weightIndex)
}
#endif // THC_GENERIC_FILE
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment