#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCWeighting.cu"
#else

// Forward pass of the weighting op.
// Builds TensorInfo descriptors for all operands and launches
// weightingForwardKernel with THCTensor_(nElement)(self) work items.
// self is not cleared here.
// NOTE(review): presumably the kernel writes every element of self — confirm.
void THCTensor_(weightingForward)(THCState *state, THCTensor *self, THCTensor *src,
                                  THCTensor *weight, THCTensor *basis,
                                  THCudaLongTensor *weightIndex) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, src, weight, basis, weightIndex));

  TensorInfo selfInfo = THCTensor_(getTensorInfo)(state, self);
  TensorInfo srcInfo = THCTensor_(getTensorInfo)(state, src);
  TensorInfo weightInfo = THCTensor_(getTensorInfo)(state, weight);
  TensorInfo basisInfo = THCTensor_(getTensorInfo)(state, basis);
  TensorInfo weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex);

  KERNEL_REAL_RUN(weightingForwardKernel, THCTensor_(nElement)(state, self), selfInfo, srcInfo,
                  weightInfo, basisInfo, weightIndexInfo);
}

// Gradient of the weighting op with respect to src, written into self.
// weight is passed to the kernel with dimensions 1 and 2 swapped, as a
// contiguous tensor. The kernel is launched with THCTensor_(nElement)(self)
// work items. The caller's weight tensor is not modified.
void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTensor *gradOutput,
                                      THCTensor *weight, THCTensor *basis,
                                      THCudaLongTensor *weightIndex) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, weight, basis, weightIndex));

  // Build a transposed view of weight (dims 1 <-> 2), then materialize it
  // contiguously. Both tensors are owned by this function and must be freed.
  THCTensor *tweight = THCTensor_(new)(state);
  THCTensor_(transpose)(state, tweight, weight, 1, 2);
  THCTensor *contigWeight = THCTensor_(newContiguous)(state, tweight);
  // contigWeight holds its own reference, so the view can be dropped now.
  THCTensor_(free)(state, tweight);

  TensorInfo selfInfo = THCTensor_(getTensorInfo)(state, self);
  TensorInfo gradOutputInfo = THCTensor_(getTensorInfo)(state, gradOutput);
  TensorInfo weightInfo = THCTensor_(getTensorInfo)(state, contigWeight);
  TensorInfo basisInfo = THCTensor_(getTensorInfo)(state, basis);
  TensorInfo weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex);

  KERNEL_REAL_RUN(weightingBackwardSrcKernel, THCTensor_(nElement)(state, self), selfInfo,
                  gradOutputInfo, weightInfo, basisInfo, weightIndexInfo);

  // Release the temporary contiguous weight. Previously both this tensor and
  // tweight were leaked on every call. Freeing after an async launch follows
  // the usual THC pattern of relying on the stream-aware caching allocator.
  THCTensor_(free)(state, contigWeight);
}

// Gradient of the weighting op with respect to weight, written into self.
// self is zero-filled first. The kernel is then launched with
// THCTensor_(nElement)(gradOutput) work items rather than self's count.
// NOTE(review): presumably the kernel accumulates contributions into self,
// which is why it must start at zero — confirm against the kernel.
void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
                                         THCTensor *src, THCTensor *basis,
                                         THCudaLongTensor *weightIndex) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, src, basis, weightIndex));
  THCTensor_(fill)(state, self, ScalarConvert::to(0));

  TensorInfo selfInfo = THCTensor_(getTensorInfo)(state, self);
  TensorInfo gradOutputInfo = THCTensor_(getTensorInfo)(state, gradOutput);
  TensorInfo srcInfo = THCTensor_(getTensorInfo)(state, src);
  TensorInfo basisInfo = THCTensor_(getTensorInfo)(state, basis);
  TensorInfo weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex);

  KERNEL_REAL_RUN(weightingBackwardWeightKernel, THCTensor_(nElement)(state, gradOutput), selfInfo,
                  gradOutputInfo, srcInfo, basisInfo, weightIndexInfo);
}

// Gradient of the weighting op with respect to basis, written into self.
// Same structure as weightingBackwardWeight: self is zero-filled, then the
// kernel is launched with THCTensor_(nElement)(gradOutput) work items.
// NOTE(review): presumably the kernel accumulates into self — confirm.
void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput,
                                        THCTensor *src, THCTensor *weight,
                                        THCudaLongTensor *weightIndex) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, gradOutput, src, weight, weightIndex));
  THCTensor_(fill)(state, self, ScalarConvert::to(0));

  TensorInfo selfInfo = THCTensor_(getTensorInfo)(state, self);
  TensorInfo gradOutputInfo = THCTensor_(getTensorInfo)(state, gradOutput);
  TensorInfo srcInfo = THCTensor_(getTensorInfo)(state, src);
  TensorInfo weightInfo = THCTensor_(getTensorInfo)(state, weight);
  TensorInfo weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex);

  KERNEL_REAL_RUN(weightingBackwardBasisKernel, THCTensor_(nElement)(state, gradOutput), selfInfo,
                  gradOutputInfo, srcInfo, weightInfo, weightIndexInfo);
}

#endif // THC_GENERIC_FILE