Commit b5ac9f33 authored by rusty1s's avatar rusty1s
Browse files

weigthing forward (cpu+gpu)

parent d48533ea
void THFloatTensor_convForward( THFloatTensor *self, THFloatTensor *src, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_convForward(THDoubleTensor *self, THDoubleTensor *src, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_convBackwardSrc( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_convBackwardSrc(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_convBackwardBasis( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *weight, THLongTensor *weightIndex);
void THDoubleTensor_convBackwardBasis(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *weight, THLongTensor *weightIndex);
void THFloatTensor_convBackwardWeight( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_convBackwardWeight(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *basis, THLongTensor *weightIndex);
#include <TH/TH.h>
#include "generic/THConv.c"
#include "generic/THWeighting.c"
#include "THGenerateFloatTypes.h"
void THFloatTensor_weightingForward( THFloatTensor *self, THFloatTensor *src, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_weightingForward(THDoubleTensor *self, THDoubleTensor *src, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_weightingBackwardSrc( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_weightingBackwardSrc(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_weightingBackwardWeight( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_weightingBackwardWeight(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_weightingBackwardBasis( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *weight, THLongTensor *weightIndex);
void THDoubleTensor_weightingBackwardBasis(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *weight, THLongTensor *weightIndex);
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THConv.c"
#else
void THTensor_(convForward)(THTensor *self, THTensor *src, THTensor *weight, THTensor *basis, THLongTensor *weightIndex) {
}
void THTensor_(convBackwardSrc)(THTensor *self, THTensor *gradOutput, THTensor *weight, THTensor *basis, THLongTensor *weightIndex) {
}
void THTensor_(convBackwardBasis)(THTensor *self, THTensor *gradOutput, THTensor *src, THTensor *weight, THLongTensor *weightIndex) {
}
void THTensor_(convBackwardWeight)(THTensor *self, THTensor *gradOutput, THTensor *src, THTensor *basis, THLongTensor *weightIndex) {
}
#endif // TH_GENERIC_FILE
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THWeighting.c"
#else
void THTensor_(weightingForward)(THTensor *self, THTensor *src, THTensor *weight, THTensor *basis,
THLongTensor *weightIndex) {
real *selfData = THTensor_(data)(self);
real *srcData = THTensor_(data)(src);
real *weightData = THTensor_(data)(weight);
real *basisData = THTensor_(data)(basis);
int64_t *weightIndexData = THLongTensor_data(weightIndex);
ptrdiff_t e, mOut, s, mIn;
real v, b;
int64_t wi;
for (e = 0; e < THTensor_(size)(src, 0); e++) {
for (mOut = 0; mOut < THTensor_(size)(weight, 2); mOut++) {
v = 0;
for (s = 0; s < THTensor_(size)(basis, 1); s++) {
b = basisData[e * basis->stride[0] + s * basis->stride[1]];
wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
for (mIn = 0; mIn < THTensor_(size)(weight, 1); mIn++) {
v += b * weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]] * srcData[e * src->stride[0] + mIn * src->stride[1]];
}
}
selfData[e * self->stride[0] + mOut * self->stride[1]] = v;
}
}
}
void THTensor_(weightingBackwardSrc)(THTensor *self, THTensor *gradOutput, THTensor *weight,
THTensor *basis, THLongTensor *weightIndex) {
}
void THTensor_(weightingBackwardWeight)(THTensor *self, THTensor *gradOutput, THTensor *src,
THTensor *basis, THLongTensor *weightIndex) {
}
void THTensor_(weightingBackwardBasis)(THTensor *self, THTensor *gradOutput, THTensor *src,
THTensor *weight, THLongTensor *weightIndex) {
}
#endif // TH_GENERIC_FILE
#include "THCBasis.cu"
#include "THCConv.cu"
#include "THCWeighting.cu"
......@@ -2,7 +2,7 @@
#define THC_INC
#include "THCBasis.h"
#include "THCConv.h"
#include "THCWeighting.h"
#endif // THC_INC
#include "THCConv.h"
#include "generic/THCConv.cu"
#include "THC/THCGenerateFloatTypes.h"
#include "THCWeighting.h"
#include "common.cuh"
#include "THCNumerics.cuh"
template<typename T>
__global__ void weightingForwardKernel(TensorInfo<T> self, TensorInfo<T> src, TensorInfo<T> weight,
TensorInfo<T> basis, TensorInfo<int64_t> weightIndex,
int n) {
KERNEL_LOOP(i, n) {
ptrdiff_t e = i / self.size[1], mOut = i % self.size[1], s, mIn;
T v = ScalarConvert<int, T>::to(0), b, tmp;
int64_t wi;
for (s = 0; s < basis.size[1]; s++) {
b = basis.data[e * basis.stride[0] + s * basis.stride[1]];
wi = weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]];
for (mIn = 0; mIn < src.size[1]; mIn++) {
tmp = weight.data[wi * weight.stride[0] + mIn * weight.stride[1] + mOut * weight.stride[2]];
tmp = THCNumerics<T>::mul(tmp, b);
tmp = THCNumerics<T>::mul(tmp, src.data[e * src.stride[0] + mIn * src.stride[1]]);
v = THCNumerics<T>::add(v, tmp);
}
}
self.data[e * self.stride[0] + mOut * self.stride[1]] = v;
}
}
#include "generic/THCWeighting.cu"
#include "THC/THCGenerateFloatTypes.h"
#ifndef THC_CONV_INC
#define THC_CONV_INC
#ifndef THC_WEIGHTING_INC
#define THC_WEIGHTING_INC
#include <THC/THC.h>
......@@ -7,11 +7,11 @@
extern "C" {
#endif // __cplusplus
#include "generic/THCConv.h"
#include "generic/THCWeighting.h"
#include "THC/THCGenerateFloatTypes.h"
#ifdef __cplusplus
}
#endif // __cplusplus
#endif // THC_CONV_INC
#endif // THC_WEIGHTING_INC
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCConv.cu"
#else
void THCTensor_(convForward)(THCState *state, THCTensor *self, THCTensor *src, THCTensor *weight,
THCTensor *basis, THCudaLongTensor *weightIndex) {
}
void THCTensor_(convBackwardSrc)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex) {
}
void THCTensor_(convBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *weight,
THCudaLongTensor *weightIndex) {
}
void THCTensor_(convBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *basis,
THCudaLongTensor *weightIndex) {
}
#endif // THC_GENERIC_FILE
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCConv.h"
#else
void THCTensor_(convForward)(THCState *state, THCTensor *self, THCTensor *src, THCTensor *weight,
THCTensor *basis, THCudaLongTensor *weightIndex);
void THCTensor_(convBackwardSrc)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex);
void THCTensor_(convBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *weight,
THCudaLongTensor *weightIndex);
void THCTensor_(convBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *basis,
THCudaLongTensor *weightIndex);
#endif // THC_GENERIC_FILE
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCWeighting.cu"
#else
void THCTensor_(weightingForward)(THCState *state, THCTensor *self, THCTensor *src,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 5, self, src, weight, basis, weightIndex));
TensorInfo<real> selfInfo = THCTensor_(getTensorInfo)(state, self);
TensorInfo<real> srcInfo = THCTensor_(getTensorInfo)(state, src);
TensorInfo<real> weightInfo = THCTensor_(getTensorInfo)(state, weight);
TensorInfo<real> basisInfo = THCTensor_(getTensorInfo)(state, basis);
TensorInfo<int64_t> weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex);
KERNEL_REAL_RUN(weightingForwardKernel, THCTensor_(nElement)(state, self), selfInfo, srcInfo,
weightInfo, basisInfo, weightIndexInfo);
}
void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex) {
}
void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *basis,
THCudaLongTensor *weightIndex) {
}
void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *weight,
THCudaLongTensor *weightIndex) {
}
#endif // THC_GENERIC_FILE
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCWeighting.h"
#else
void THCTensor_(weightingForward)(THCState *state, THCTensor *self, THCTensor *src,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex);
void THCTensor_(weightingBackwardSrc)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex);
void THCTensor_(weightingBackwardWeight)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *basis,
THCudaLongTensor *weightIndex);
void THCTensor_(weightingBackwardBasis)(THCState *state, THCTensor *self, THCTensor *gradOutput,
THCTensor *src, THCTensor *weight,
THCudaLongTensor *weightIndex);
#endif // THC_GENERIC_FILE
void THCCFloatTensor_convForward( THCudaTensor *self, THCudaTensor *src, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_convForward(THCudaDoubleTensor *self, THCudaDoubleTensor *src, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
void THCCFloatTensor_convBackwardSrc( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_convBackwardSrc(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
void THCCFloatTensor_convBackwardBasis( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *src, THCudaTensor *weight, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_convBackwardBasis(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *src, THCudaDoubleTensor *weight, THCudaLongTensor *weightIndex);
void THCCFloatTensor_convBackwardWeight( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *src, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_convBackwardWeight(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *src, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
......@@ -6,6 +6,6 @@
extern THCState *state;
#include "generic/THCCConv.c"
#include "generic/THCCWeighting.c"
#include "THCGenerateFloatTypes.h"
void THCCFloatTensor_weightingForward( THCudaTensor *self, THCudaTensor *src, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_weightingForward(THCudaDoubleTensor *self, THCudaDoubleTensor *src, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
void THCCFloatTensor_weightingBackwardSrc( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_weightingBackwardSrc(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
void THCCFloatTensor_weightingBackwardWeight( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *src, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_weightingBackwardWeight(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *src, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
void THCCFloatTensor_weightingBackwardBasis( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *src, THCudaTensor *weight, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_weightingBackwardBasis(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *src, THCudaDoubleTensor *weight, THCudaLongTensor *weightIndex);
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCCConv.c"
#else
void THCCTensor_(convForward)(THCTensor *self, THCTensor *src, THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex) {
THCTensor_(convForward)(state, self, src, weight, basis, weightIndex);
}
void THCCTensor_(convBackwardSrc)(THCTensor *self, THCTensor *gradOutput, THCTensor *weight,
THCTensor *basis, THCudaLongTensor *weightIndex) {
THCTensor_(convBackwardSrc)(state, self, gradOutput, weight, basis, weightIndex);
}
void THCCTensor_(convBackwardBasis)(THCTensor *self, THCTensor *gradOutput, THCTensor *src,
THCTensor *weight, THCudaLongTensor *weightIndex) {
THCTensor_(convBackwardBasis)(state, self, gradOutput, src, weight, weightIndex);
}
void THCCTensor_(convBackwardWeight)(THCTensor *self, THCTensor *gradOutput, THCTensor *src,
THCTensor *basis, THCudaLongTensor *weightIndex) {
THCTensor_(convBackwardWeight)(state, self, gradOutput, src, basis, weightIndex);
}
#endif // THC_GENERIC_FILE
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCCWeighting.c"
#else
void THCCTensor_(weightingForward)(THCTensor *self, THCTensor *src, THCTensor *weight,
THCTensor *basis, THCudaLongTensor *weightIndex) {
THCTensor_(weightingForward)(state, self, src, weight, basis, weightIndex);
}
void THCCTensor_(weightingBackwardSrc)(THCTensor *self, THCTensor *gradOutput, THCTensor *weight,
THCTensor *basis, THCudaLongTensor *weightIndex) {
THCTensor_(weightingBackwardSrc)(state, self, gradOutput, weight, basis, weightIndex);
}
void THCCTensor_(weightingBackwardWeight)(THCTensor *self, THCTensor *gradOutput, THCTensor *src,
THCTensor *basis, THCudaLongTensor *weightIndex) {
THCTensor_(weightingBackwardWeight)(state, self, gradOutput, src, basis, weightIndex);
}
void THCCTensor_(weightingBackwardBasis)(THCTensor *self, THCTensor *gradOutput, THCTensor *src,
THCTensor *weight, THCudaLongTensor *weightIndex) {
THCTensor_(weightingBackwardBasis)(state, self, gradOutput, src, weight, weightIndex);
}
#endif // THC_GENERIC_FILE
......@@ -8,7 +8,7 @@ from torch.utils.ffi import create_extension
if osp.exists('build'):
shutil.rmtree('build')
files = ['Basis', 'Conv']
files = ['Basis', 'Weighting']
headers = ['aten/TH/TH{}.h'.format(f) for f in files]
sources = ['aten/TH/TH{}.c'.format(f) for f in files]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment