/* generic/cuda.c (1.13 KB), last committed by rusty1s.
 * Header recovered from a blame-view page; not part of the original source. */
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/cuda.c"
#else

/*
 * Host-side entry point for the "mul" scatter on THC (CUDA) tensors.
 *
 * This wrapper does no work itself. It forwards every argument, unchanged and
 * in the same order, to scatter_kernel_(mul), prepending the THC `state` handle.
 *
 *   dim    - scatter dimension, forwarded as-is.
 *   output - destination tensor, forwarded as-is
 *            (presumably updated in place by the kernel -- TODO confirm).
 *   index  - long tensor of target positions, forwarded as-is.
 *   input  - source tensor, forwarded as-is.
 *
 * NOTE(review): `state` is not declared in this file. It is assumed to be a
 * THCState* provided by the translation unit that includes this generic file.
 * NOTE(review): scatter_() / scatter_kernel_() are presumably per-scalar-type
 * name-mangling macros, since this file uses the THC_GENERIC_FILE
 * re-inclusion pattern -- confirm against the including file.
 */
void scatter_(mul)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  scatter_kernel_(mul)(state, dim, output, index, input);
}

/*
 * Host-side entry point for the "div" scatter on THC (CUDA) tensors.
 *
 * This wrapper does no work itself. It forwards every argument, unchanged and
 * in the same order, to scatter_kernel_(div), prepending the THC `state` handle.
 *
 *   dim    - scatter dimension, forwarded as-is.
 *   output - destination tensor, forwarded as-is
 *            (presumably updated in place by the kernel -- TODO confirm).
 *   index  - long tensor of target positions, forwarded as-is.
 *   input  - source tensor, forwarded as-is.
 *
 * NOTE(review): `state` is not declared in this file. It is assumed to be a
 * THCState* provided by the translation unit that includes this generic file.
 */
void scatter_(div)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
  scatter_kernel_(div)(state, dim, output, index, input);
}

/*
 * Host-side entry point for the "mean" scatter on THC (CUDA) tensors.
 *
 * This wrapper does no work itself. It forwards every argument, unchanged and
 * in the same order, to scatter_kernel_(mean), prepending the THC `state` handle.
 *
 *   dim        - scatter dimension, forwarded as-is.
 *   output     - destination tensor, forwarded as-is.
 *   index      - long tensor of target positions, forwarded as-is.
 *   input      - source tensor, forwarded as-is.
 *   num_output - extra tensor of the same scalar type as `output`, forwarded as-is
 *                (presumably accumulates per-slot element counts used for
 *                the mean -- TODO confirm in the kernel).
 *
 * NOTE(review): `state` is not declared in this file. It is assumed to be a
 * THCState* provided by the translation unit that includes this generic file.
 */
void scatter_(mean)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *num_output) {
  scatter_kernel_(mean)(state, dim, output, index, input, num_output);
}

/*
 * Host-side entry point for the "max" scatter on THC (CUDA) tensors.
 *
 * This wrapper does no work itself. It forwards every argument, unchanged and
 * in the same order, to scatter_kernel_(max), prepending the THC `state` handle.
 *
 *   dim        - scatter dimension, forwarded as-is.
 *   output     - destination tensor, forwarded as-is.
 *   index      - long tensor of target positions, forwarded as-is.
 *   input      - source tensor, forwarded as-is.
 *   arg_output - long tensor, forwarded as-is
 *                (presumably receives the positions in `input` that won each
 *                slot, i.e. an argmax -- TODO confirm in the kernel).
 *
 * NOTE(review): `state` is not declared in this file. It is assumed to be a
 * THCState* provided by the translation unit that includes this generic file.
 */
void scatter_(max)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
  scatter_kernel_(max)(state, dim, output, index, input, arg_output);
}

/*
 * Host-side entry point for the "min" scatter on THC (CUDA) tensors.
 *
 * This wrapper does no work itself. It forwards every argument, unchanged and
 * in the same order, to scatter_kernel_(min), prepending the THC `state` handle.
 *
 *   dim        - scatter dimension, forwarded as-is.
 *   output     - destination tensor, forwarded as-is.
 *   index      - long tensor of target positions, forwarded as-is.
 *   input      - source tensor, forwarded as-is.
 *   arg_output - long tensor, forwarded as-is
 *                (presumably receives the positions in `input` that won each
 *                slot, i.e. an argmin -- TODO confirm in the kernel).
 *
 * NOTE(review): `state` is not declared in this file. It is assumed to be a
 * THCState* provided by the translation unit that includes this generic file.
 */
void scatter_(min)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
  scatter_kernel_(min)(state, dim, output, index, input, arg_output);
}

/*
 * Host-side entry point for the backward pass that uses a long "arg" tensor.
 *
 * This wrapper does no work itself. It forwards every argument, unchanged and
 * in the same order, to index_backward_kernel, prepending the THC `state` handle.
 *
 *   dim      - dimension, forwarded as-is.
 *   output   - destination tensor, forwarded as-is.
 *   index    - long tensor of positions, forwarded as-is.
 *   grad     - incoming gradient tensor, forwarded as-is.
 *   arg_grad - long tensor, forwarded as-is
 *              (presumably the arg tensor produced by the max/min scatter,
 *              used to route gradients -- TODO confirm in the kernel).
 *
 * NOTE(review): unlike its siblings, this name is not wrapped in a macro.
 * If the file is re-included once per scalar type, the plain identifier may be
 * defined more than once -- confirm that index_backward / index_backward_kernel
 * are macros in the including translation unit.
 * NOTE(review): `state` is not declared in this file. It is assumed to be a
 * THCState* provided by the translation unit that includes this generic file.
 */
void index_backward(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg_grad) {
  index_backward_kernel(state, dim, output, index, grad, arg_grad);
}

#endif