// generic/kernel.cu
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#else

// scatter_(mul): CUDA entry point for the "mul" scatter reduction.
//
// NOTE(review): still a stub. It runs the shared argument check and prints a
// marker string; no kernel is launched here and this function never writes
// to `output`. TODO: launch the multiplication kernel once it exists.
//
// Parameters:
//   state  - THC state handle, forwarded to thc_(check).
//   dim    - scatter dimension (currently unused by this stub).
//   output - destination tensor.
//   index  - int64 index tensor.
//   input  - source tensor.
void scatter_(mul)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  thc_(check)(state, output, index, input);
  // Debug marker only; note there is no trailing newline, so it may stay
  // buffered in stdout.
  printf("mul");
}

// scatter_(div): CUDA entry point for the "div" scatter reduction.
//
// NOTE(review): still a stub. It runs the shared argument check and prints a
// marker string; no kernel is launched here and this function never writes
// to `output`. TODO: launch the division kernel once it exists.
//
// Parameters:
//   state  - THC state handle, forwarded to thc_(check).
//   dim    - scatter dimension (currently unused by this stub).
//   output - destination tensor.
//   index  - int64 index tensor.
//   input  - source tensor.
void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  thc_(check)(state, output, index, input);
  // Debug marker only; no trailing newline.
  printf("div");
}

// scatter_(mean): CUDA entry point for the "mean" scatter reduction.
//
// NOTE(review): still a stub. It runs the shared argument check on
// (output, index, input) and prints a marker string; no kernel is launched.
// `count` is not checked or used yet. Presumably it will hold per-slot
// element counts for the final division (TODO confirm when implementing).
//
// Parameters:
//   state  - THC state handle, forwarded to thc_(check).
//   dim    - scatter dimension (currently unused by this stub).
//   output - destination tensor.
//   index  - int64 index tensor.
//   input  - source tensor.
//   count  - auxiliary tensor (currently unused by this stub).
void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *count) {
  thc_(check)(state, output, index, input);
  // Debug marker only; no trailing newline.
  printf("mean");
}

// scatter_(max): scatter-max of `input` into `output` along `dim`, with a
// second pass that fills `arg`.
//
// Two launches are issued via KERNEL_RUN, both sized by n = numel(index) and
// dispatched on indexInfo.dims:
//   1. maxKernel(output, index, input, dim). Presumably it writes the
//      per-slot maxima into `output`.
//   2. argKernel(output, index, input, arg, dim). It reads the reduced
//      `output`; presumably it writes the winning input position into `arg`
//      (TODO confirm against argKernel).
// NOTE(review): pass 2 depends on pass 1 having finished. This is only safe
// if KERNEL_RUN issues both launches on the same stream. Confirm this.
// NOTE(review): THCudaLongTensor_nElement returns ptrdiff_t. Storing it in
// `int` truncates for index tensors with more than INT_MAX elements.
//
// Parameters:
//   state  - THC state handle.
//   dim    - scatter dimension, forwarded to both kernels.
//   output - destination tensor for the reduced values.
//   index  - int64 index tensor; its element count sets the launch size.
//   input  - source tensor.
//   arg    - int64 tensor written by argKernel.
void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
  thc_(check)(state, output, index, input);

  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);

  // KERNEL_RUN is used without a trailing ';' (as elsewhere in this file).
  KERNEL_RUN(maxKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
  KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
}

// scatter_(min): scatter-min of `input` into `output` along `dim`, with a
// second pass that fills `arg`. This mirrors scatter_(max) except that the
// first pass uses minKernel.
//
// Two launches are issued via KERNEL_RUN, both sized by n = numel(index) and
// dispatched on indexInfo.dims:
//   1. minKernel(output, index, input, dim). Presumably it writes the
//      per-slot minima into `output`.
//   2. argKernel(output, index, input, arg, dim). It reads the reduced
//      `output`; presumably it writes the winning input position into `arg`
//      (TODO confirm against argKernel).
// NOTE(review): pass 2 depends on pass 1 having finished. This is only safe
// if KERNEL_RUN issues both launches on the same stream. Confirm this.
// NOTE(review): THCudaLongTensor_nElement returns ptrdiff_t. Storing it in
// `int` truncates for index tensors with more than INT_MAX elements.
//
// Parameters:
//   state  - THC state handle.
//   dim    - scatter dimension, forwarded to both kernels.
//   output - destination tensor for the reduced values.
//   index  - int64 index tensor; its element count sets the launch size.
//   input  - source tensor.
//   arg    - int64 tensor written by argKernel.
void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
  thc_(check)(state, output, index, input);

  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);

  // KERNEL_RUN is used without a trailing ';' (as elsewhere in this file).
  KERNEL_RUN(minKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
  KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
}

// index_backward: CUDA entry point for the backward pass of the index-based
// reductions. The name suggests it routes `grad` through `arg`; this is
// unverified.
//
// NOTE(review): still a stub. It runs thc_(check) on (output, index, grad)
// and prints a marker string; `dim` and `arg` are unused and no kernel is
// launched.
// NOTE(review): unlike the other entry points in this file, the name is not
// wrapped in scatter_(). Each instantiation of this generic file therefore
// defines a plain `index_backward`, distinguished only by the THCTensor
// parameter type. This collides if the file is compiled as C / extern "C".
// Confirm this is intended.
//
// Parameters:
//   state  - THC state handle, forwarded to thc_(check).
//   dim    - scatter dimension (currently unused by this stub).
//   output - destination tensor.
//   index  - int64 index tensor.
//   grad   - incoming gradient tensor.
//   arg    - int64 tensor (currently unused by this stub).
void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg) {
  thc_(check)(state, output, index, grad);
  // Debug marker only; no trailing newline.
  printf("index_backward");
}

#endif