"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "b06a1f2ba2441ad1e40dcd8a188528e291a6ab6d"
Commit 411e3e38 authored by rusty1s's avatar rusty1s
Browse files

arg check

parent 8b3f88a2
...@@ -2,31 +2,39 @@ ...@@ -2,31 +2,39 @@
#define THC_GENERIC_FILE "generic/kernel.cu" #define THC_GENERIC_FILE "generic/kernel.cu"
#else #else
void scatter_(mul)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) { void check(THCState *state, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, output, input)); THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, output, input));
THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 2, index)); THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 2, index));
THArgCheck(THCTensor_(nDimension)(state, output) <= MAX_DIMS, 1, "Tensor too large or too many dimensions"); THArgCheck(THCTensor_(nDimension)(state, output) <= MAX_DIMS, 1, "Tensor too large or too many dimensions");
}
void scatter_(mul)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
check(state, output, index, input);
printf("mul"); printf("mul");
} }
void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) { void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
check(state, output, index, input);
printf("div"); printf("div");
} }
void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *num_output) { void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *num_output) {
check(state, output, index, input);
printf("mean"); printf("mean");
} }
void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) { void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
check(state, output, index, input);
printf("max"); printf("max");
} }
void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) { void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
check(state, output, index, input);
printf("min"); printf("min");
} }
void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg_grad) { void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg_grad) {
check(state, output, index, grad);
printf("index_backward"); printf("index_backward");
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real) #define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real)
#define index_backward TH_CONCAT_2(index_backward_kernel_, Real) #define index_backward TH_CONCAT_2(index_backward_kernel_, Real)
#define check TH_CONCAT_2(check_kernel_, Real)
#define MAX_DIMS 25 #define MAX_DIMS 25
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment