kernel.cu 1.94 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#else

rusty1s's avatar
rusty1s committed
5
/*
 * scatter_mul for CUDA tensors -- not implemented yet.
 *
 * Arguments are still validated via thc_(check). The function then raises a
 * TH error instead of returning with `output` untouched. Otherwise callers
 * would silently receive unmodified data as if the scatter had succeeded.
 */
void scatter_(mul)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  // Validate first so malformed calls keep reporting the same argument errors.
  thc_(check)(state, output, index, input);
  THError("scatter_mul is not implemented for CUDA tensors yet");
}

/*
 * scatter_div for CUDA tensors -- not implemented yet.
 *
 * Arguments are still validated via thc_(check). The function then raises a
 * TH error instead of returning with `output` untouched. Otherwise callers
 * would silently receive unmodified data as if the scatter had succeeded.
 */
void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  // Validate first so malformed calls keep reporting the same argument errors.
  thc_(check)(state, output, index, input);
  THError("scatter_div is not implemented for CUDA tensors yet");
}

/*
 * scatter_mean for CUDA tensors -- not implemented yet.
 *
 * Arguments are still validated via thc_(check). The function then raises a
 * TH error instead of returning with `output` / `num_output` untouched.
 * Otherwise callers would silently receive unmodified data as if the scatter
 * had succeeded.
 */
void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *num_output) {
  // Validate first so malformed calls keep reporting the same argument errors.
  thc_(check)(state, output, index, input);
  THError("scatter_mean is not implemented for CUDA tensors yet");
}

/*
 * Max-scatter of `input` into `output` along dimension `dim`, driven by `index`.
 *
 * Builds TensorInfo views of output, index, input and arg_output, then
 * launches `maxKernel` (defined elsewhere) through the KERNEL_RUN macro with
 * n = number of elements in `index`.
 *
 * state      - THC state used for tensor metadata lookups.
 * dim        - dimension along which to scatter (forwarded to the kernel).
 * output     - destination value tensor (passed to maxKernel).
 * index      - int64 index tensor; its element count sizes the launch.
 * input      - source value tensor.
 * arg_output - int64 tensor passed to maxKernel.
 *              NOTE(review): presumably receives the position of the winning
 *              input element per output slot -- confirm against maxKernel.
 */
void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
  thc_(check)(state, output, index, input);

  const int n = THCudaLongTensor_nElement(state, index);
  // An empty index has nothing to scatter. Returning here also avoids a
  // launch with a zero-sized grid, which CUDA rejects as an invalid
  // configuration (assumes KERNEL_RUN sizes its grid from n, e.g. via
  // GET_BLOCKS(n) -- see the macro definition).
  if (n == 0) return;

  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg_output);

  // KERNEL_RUN receives indexInfo.dims, presumably to select a kernel
  // specialisation by dimensionality -- see the macro definition.
  KERNEL_RUN(maxKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
  // NOTE(review): an earlier draft also launched a separate argKernel pass
  // to fill arg_output. Confirm that maxKernel fills arg_output on its own
  // without racing against concurrent max updates to the same output slot.
}

/*
 * scatter_min for CUDA tensors -- not implemented yet.
 *
 * Arguments are still validated via thc_(check). The function then raises a
 * TH error instead of returning with `output` / `arg_output` untouched.
 * Otherwise callers would silently receive unmodified data as if the scatter
 * had succeeded.
 */
void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
  // Validate first so malformed calls keep reporting the same argument errors.
  thc_(check)(state, output, index, input);
  THError("scatter_min is not implemented for CUDA tensors yet");
}

/*
 * index_backward for CUDA tensors -- not implemented yet.
 *
 * Arguments are still validated via thc_(check) (with `grad` in the input
 * position). The function then raises a TH error instead of returning with
 * `output` untouched. Otherwise callers would silently receive an unmodified
 * gradient buffer.
 */
void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg_grad) {
  // Validate first so malformed calls keep reporting the same argument errors.
  thc_(check)(state, output, index, grad);
  THError("index_backward is not implemented for CUDA tensors yet");
}

#endif