kernel.cu 3.19 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#else

rusty1s's avatar
rusty1s committed
5
// Scatter-multiply entry point for the THC real type this generic file is
// currently being instantiated for.
//
// Validates the three tensors with thc_(check), builds TensorInfo views of
// output, index and input, and hands them to mulKernel via KERNEL_RUN
// together with `dim`. The number of elements of `index` is passed as the
// size argument and `index`'s dimensionality as the rank argument.
// NOTE(review): the per-element semantics (presumably output *= input at the
// slot selected by index along `dim`) live in mulKernel; confirm there.
void scatter_(mul)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  thc_(check)(state, output, index, input);

  // NOTE(review): the element count is stored as int; confirm the return type
  // of THCudaLongTensor_nElement if very large index tensors must be supported.
  const int numIndices = static_cast<int>(THCudaLongTensor_nElement(state, index));

  TensorInfo<real> dstInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> idxInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> srcInfo = thc_(getTensorInfo)(state, input);

  KERNEL_RUN(mulKernel, idxInfo.dims, numIndices, dstInfo, idxInfo, srcInfo, dim)
}

// Scatter-divide entry point for the THC real type this generic file is
// currently being instantiated for.
//
// Validates the three tensors with thc_(check), builds TensorInfo views of
// output, index and input, and hands them to divKernel via KERNEL_RUN
// together with `dim`. The number of elements of `index` is passed as the
// size argument and `index`'s dimensionality as the rank argument.
// NOTE(review): the per-element semantics (presumably output /= input at the
// slot selected by index along `dim`) live in divKernel; confirm there.
void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  thc_(check)(state, output, index, input);

  // NOTE(review): the element count is stored as int; confirm the return type
  // of THCudaLongTensor_nElement if very large index tensors must be supported.
  const int numIndices = static_cast<int>(THCudaLongTensor_nElement(state, index));

  TensorInfo<real> dstInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> idxInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> srcInfo = thc_(getTensorInfo)(state, input);

  KERNEL_RUN(divKernel, idxInfo.dims, numIndices, dstInfo, idxInfo, srcInfo, dim)
}

rusty1s's avatar
rusty1s committed
27
// Scatter-mean entry point for the THC real type this generic file is
// currently being instantiated for.
//
// Validates output/index/input with thc_(check), builds TensorInfo views of
// output, index, input and count, and hands them to meanKernel via
// KERNEL_RUN together with `dim`. The number of elements of `index` is
// passed as the size argument and `index`'s dimensionality as the rank.
// NOTE(review): `count` is not passed to thc_(check); presumably meanKernel
// accumulates per-slot counts into it and the division happens elsewhere.
// Confirm in meanKernel and its caller.
void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *count) {
  thc_(check)(state, output, index, input);

  // NOTE(review): the element count is stored as int; confirm the return type
  // of THCudaLongTensor_nElement if very large index tensors must be supported.
  const int numIndices = static_cast<int>(THCudaLongTensor_nElement(state, index));

  TensorInfo<real> dstInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> idxInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> srcInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<real> cntInfo = thc_(getTensorInfo)(state, count);

  KERNEL_RUN(meanKernel, idxInfo.dims, numIndices, dstInfo, idxInfo, srcInfo, cntInfo, dim)
}

rusty1s's avatar
rusty1s committed
39
// Scatter-max entry point for the THC real type this generic file is
// currently being instantiated for.
//
// Validates output/index/input with thc_(check), builds TensorInfo views of
// output, index, input and arg, then issues two KERNEL_RUN launches in
// order: maxKernel (output, index, input, dim), then argKernel (additionally
// given arg). Both use the element count and dimensionality of `index`.
// NOTE(review): presumably the first pass reduces values into output and the
// second records in `arg` which input position produced each result. This
// relies on the two launches being ordered, and on whatever stream
// KERNEL_RUN uses, so confirm in the macro and kernels. `arg` is not checked
// by thc_(check).
void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
  thc_(check)(state, output, index, input);

  // NOTE(review): the element count is stored as int; confirm the return type
  // of THCudaLongTensor_nElement if very large index tensors must be supported.
  const int numIndices = static_cast<int>(THCudaLongTensor_nElement(state, index));

  TensorInfo<real> dstInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> idxInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> srcInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<int64_t> argIdxInfo = thc_getTensorInfo_Long(state, arg);

  // Pass 1: value reduction. Pass 2: argument recovery (must follow pass 1).
  KERNEL_RUN(maxKernel, idxInfo.dims, numIndices, dstInfo, idxInfo, srcInfo, dim)
  KERNEL_RUN(argKernel, idxInfo.dims, numIndices, dstInfo, idxInfo, srcInfo, argIdxInfo, dim)
}

rusty1s's avatar
rusty1s committed
52
// Scatter-min entry point for the THC real type this generic file is
// currently being instantiated for.
//
// Validates output/index/input with thc_(check), builds TensorInfo views of
// output, index, input and arg, then issues two KERNEL_RUN launches in
// order: minKernel (output, index, input, dim), then argKernel (additionally
// given arg). Both use the element count and dimensionality of `index`.
// NOTE(review): presumably the first pass reduces values into output and the
// second records in `arg` which input position produced each result. This
// relies on the two launches being ordered, and on whatever stream
// KERNEL_RUN uses, so confirm in the macro and kernels. `arg` is not checked
// by thc_(check).
void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
  thc_(check)(state, output, index, input);

  // NOTE(review): the element count is stored as int; confirm the return type
  // of THCudaLongTensor_nElement if very large index tensors must be supported.
  const int numIndices = static_cast<int>(THCudaLongTensor_nElement(state, index));

  TensorInfo<real> dstInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> idxInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> srcInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<int64_t> argIdxInfo = thc_getTensorInfo_Long(state, arg);

  // Pass 1: value reduction. Pass 2: argument recovery (must follow pass 1).
  KERNEL_RUN(minKernel, idxInfo.dims, numIndices, dstInfo, idxInfo, srcInfo, dim)
  KERNEL_RUN(argKernel, idxInfo.dims, numIndices, dstInfo, idxInfo, srcInfo, argIdxInfo, dim)
}

rusty1s's avatar
rusty1s committed
65
// Currently a stub. Presumably intended as the backward pass for the
// arg-tracking scatters (max/min); confirm against the caller.
//
// Validates output/index/grad with thc_(check), then writes the text
// "index_backward" (no trailing newline) to stdout and returns. `dim` and
// `arg` are unused, and nothing is written into `output` here.
// TODO(review): implement the gradient routing (presumably grad -> output at
// the positions recorded in `arg`) and remove the debug print.
// NOTE(review): unlike the other entry points in this generic file, this name
// is not wrapped in a per-type macro. Confirm it is not defined once per
// instantiated real type, which would cause a redefinition.
void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg) {
  thc_(check)(state, output, index, grad);

  // Debug marker; output is identical to the former printf call.
  fputs("index_backward", stdout);
}

#endif