Commit de4f89b7 authored by rusty1s's avatar rusty1s
Browse files

added conv headers

parent 24a25dd3
#include <TH/TH.h>
#include "generic/THConv.c"
#include "THGenerateFloatTypes.h"
void THFloatTensor_convForward( THFloatTensor *self, THFloatTensor *src, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_convForward(THDoubleTensor *self, THDoubleTensor *src, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_convBackwardSrc( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_convBackwardSrc(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_convBackwardBasis( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *weight, THLongTensor *weightIndex);
void THDoubleTensor_convBackwardBasis(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *weight, THLongTensor *weightIndex);
void THFloatTensor_convBackwardWeight( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_convBackwardWeight(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *basis, THLongTensor *weightIndex);
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THConv.c"
#else
void THTensor_(convForward)(THTensor *self, THTensor *src, THTensor *weight, THTensor *basis, THLongTensor *weightIndex) {
}
void THTensor_(convBackwardSrc)(THTensor *self, THTensor *gradOutput, THTensor *weight, THTensor *basis, THLongTensor *weightIndex) {
}
void THTensor_(convBackwardBasis)(THTensor *self, THTensor *gradOutput, THTensor *src, THTensor *weight, THLongTensor *weightIndex) {
}
void THTensor_(convBackwardWeight)(THTensor *self, THTensor *gradOutput, THTensor *src, THTensor *basis, THLongTensor *weightIndex) {
}
#endif // TH_GENERIC_FILE
#include <THC/THC.h>
#include "THC.h"
#define THCCTensor_(NAME) TH_CONCAT_4(THCC,Real,Tensor_,NAME)
extern THCState *state;
#include "generic/THCCConv.c"
#include "THCGenerateFloatTypes.h"
void THCCFloatTensor_convForward( THCudaTensor *self, THCudaTensor *src, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_convForward(THCudaDoubleTensor *self, THCudaDoubleTensor *src, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
void THCCFloatTensor_convBackwardSrc( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *weight, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_convBackwardSrc(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *weight, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
void THCCFloatTensor_convBackwardBasis( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *src, THCudaTensor *weight, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_convBackwardBasis(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *src, THCudaDoubleTensor *weight, THCudaLongTensor *weightIndex);
void THCCFloatTensor_convBackwardWeight( THCudaTensor *self, THCudaTensor *gradOutput, THCudaTensor *src, THCudaTensor *basis, THCudaLongTensor *weightIndex);
void THCCDoubleTensor_convBackwardWeight(THCudaDoubleTensor *self, THCudaDoubleTensor *gradOutput, THCudaDoubleTensor *src, THCudaDoubleTensor *basis, THCudaLongTensor *weightIndex);
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/THCCConv.c"
#else
void THCCTensor_(convForward)(THCTensor *self, THCTensor *src, THCTensor *weight, THCTensor *basis,
THCudaLongTensor *weightIndex) {
}
void THCCTensor_(convBackwardSrc)(THCTensor *self, THCTensor *gradOutput, THCTensor *weight,
THCTensor *basis, THCudaLongTensor *weightIndex) {
}
void THCCTensor_(convBackwardBasis)(THCTensor *self, THCTensor *gradOutput, THCTensor *src,
THCTensor *weight, THCudaLongTensor *weightIndex) {
}
void THCCTensor_(convBackwardWeight)(THCTensor *self, THCTensor *gradOutput, THCTensor *src,
THCTensor *basis, THCudaLongTensor *weightIndex) {
}
#endif // THC_GENERIC_FILE
......@@ -8,7 +8,7 @@ from torch.utils.ffi import create_extension
if osp.exists('build'):
shutil.rmtree('build')
files = ['Basis']
files = ['Basis', 'Conv']
headers = ['aten/TH/TH{}.h'.format(f) for f in files]
sources = ['aten/TH/TH{}.c'.format(f) for f in files]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment