kernel.cu 1.61 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#else

rusty1s's avatar
rusty1s committed
5
// Shared argument validation for every scatter / backward entry point below.
//
// THCTensor_(checkGPU) and THCudaLongTensor_checkGPU take a tensor COUNT
// followed by exactly that many tensor pointers (variadic). The count must
// match the number of pointers passed:
//   - output and input  -> count 2 (previously 1, so `input` was never checked)
//   - index             -> count 1 (previously 2, which made the callee read a
//                          non-existent second vararg: undefined behaviour)
// Finally, `output` is limited to at most MAX_DIMS dimensions
// (MAX_DIMS is defined outside this file).
void check(THCState *state, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, output, input));
  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));
  THArgCheck(THCTensor_(nDimension)(state, output) <= MAX_DIMS, 1, "Tensor too large or too many dimensions");
}
rusty1s's avatar
max dim  
rusty1s committed
10

rusty1s's avatar
rusty1s committed
11
12
// Scatter-multiply entry point.
//
// Currently a stub. It validates its tensors via check() and then only prints
// "mul". `output` is not written, and `dim` is not used yet.
// TODO: implement the kernel launch. The earlier draft computed
// `n = THCudaLongTensor_nElement(state, index)` and `dim3 block(NUM_THREADS)`
// but never used them. Reintroduce them together with the launch (and check
// cudaGetLastError() afterwards) rather than leaving dead locals behind.
void scatter_(mul)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  check(state, output, index, input);
  printf("mul");
}

// Scatter-divide entry point (stub).
// Runs the shared tensor validation, then emits the tag "div" on stdout.
// Neither `output` nor `dim` is touched yet.
void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  const char *tag = "div";
  check(state, output, index, input);
  printf("%s", tag);
}

// Scatter-mean entry point (stub).
// Runs the shared tensor validation on output/index/input, then emits the tag
// "mean" on stdout. `num_output` is not validated or written, and `output`
// and `dim` are unused for now.
void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *num_output) {
  const char *tag = "mean";
  check(state, output, index, input);
  printf("%s", tag);
}

// Scatter-max entry point (stub).
// Runs the shared tensor validation on output/index/input, then emits the tag
// "max" on stdout. `arg_output` is neither validated nor written yet.
void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
  const char *tag = "max";
  check(state, output, index, input);
  printf("%s", tag);
}

// Scatter-min entry point (stub).
// Runs the shared tensor validation on output/index/input, then emits the tag
// "min" on stdout. `arg_output` is neither validated nor written yet.
void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
  const char *tag = "min";
  check(state, output, index, input);
  printf("%s", tag);
}

// Backward entry point for the arg-indexed reductions (stub).
// Validates output/index/grad through the shared check() (with `grad` taking
// the `input` slot), then emits the tag "index_backward" on stdout.
// `arg_grad` and `dim` are unused for now.
void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg_grad) {
  const char *tag = "index_backward";
  check(state, output, index, grad);
  printf("%s", tag);
}

#endif