kernel.cu 1.88 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <THC/THC.h>

3
#include "kernel.h"
rusty1s's avatar
rusty1s committed
4

rusty1s's avatar
rusty1s committed
5
#include "common.cuh"
rusty1s's avatar
rusty1s committed
6
7
#include "THCIndex.cuh"
#include "THCAtomics.cuh"
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
10
// Name-building helpers for the per-type generic code included further down.
// scatter_(NAME) passes `scatter_`, NAME, `_kernel_` and the current `Real`
// token to TH_CONCAT_4; index_backward passes `index_backward_kernel_` and
// `Real` to TH_CONCAT_2.
// NOTE(review): `Real` is presumably (re)defined for each scalar type by the
// THCGenerate*.h include that follows each generic file — confirm.
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real)
#define index_backward TH_CONCAT_2(index_backward_kernel_, Real)
rusty1s's avatar
rusty1s committed
11
12
13
14
15
// thc_(NAME) passes `thc_`, NAME, `_` and the current `Real` token to
// TH_CONCAT_4, presumably yielding a per-type identifier used by the generic
// file included below — confirm against generic/common.cu.
#define thc_(NAME) TH_CONCAT_4(thc_, NAME, _, Real)

#include "generic/common.cu"
#include "THCGenerateAllTypes.h"

rusty1s's avatar
rusty1s committed
16
// Scatter-max kernel.
//
// For every element id visited by KERNEL_LOOP over `n`, asks
// IndexToScatterOffsets3 for the linear offsets of that element in `index`,
// `input` and `output`, then folds the input value into the output slot with
// atomicMax.
//
//   output : tensor that receives the running maximum (updated atomically)
//   index  : int64 tensor handed to the offset computation
//   input  : tensor providing the candidate values
//   dim    : dimension forwarded to the offset computation
//   n      : upper bound passed to KERNEL_LOOP
//
// NOTE(review): KERNEL_LOOP and IndexToScatterOffsets3 are defined outside this
// file; `n` is presumably the number of input elements — confirm at the caller.
template <typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output,
                          TensorInfo<int64_t> index,
                          TensorInfo<Real> input,
                          const int dim,
                          const int n) {
  KERNEL_LOOP(elem, n) {
    // Offsets are written by compute() through the pointers passed below.
    int idxOff = 0;
    int srcOff = 0;
    int dstOff = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(
        elem, dim,
        index, &idxOff,
        input, &srcOff,
        output, &dstOff);

    atomicMax(&output.data[dstOff], input.data[srcOff]);
  }
}

// Scatter-min kernel.
//
// For every element id visited by KERNEL_LOOP over `n`, asks
// IndexToScatterOffsets3 for the linear offsets of that element in `index`,
// `input` and `output`, then folds the input value into the output slot with
// atomicMin.
//
//   output : tensor that receives the running minimum (updated atomically)
//   index  : int64 tensor handed to the offset computation
//   input  : tensor providing the candidate values
//   dim    : dimension forwarded to the offset computation
//   n      : upper bound passed to KERNEL_LOOP
//
// NOTE(review): KERNEL_LOOP and IndexToScatterOffsets3 are defined outside this
// file; `n` is presumably the number of input elements — confirm at the caller.
template <typename Real, int Dims>
__global__ void minKernel(TensorInfo<Real> output,
                          TensorInfo<int64_t> index,
                          TensorInfo<Real> input,
                          const int dim,
                          const int n) {
  KERNEL_LOOP(elem, n) {
    // Offsets are written by compute() through the pointers passed below.
    int idxOff = 0;
    int srcOff = 0;
    int dstOff = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(
        elem, dim,
        index, &idxOff,
        input, &srcOff,
        output, &dstOff);

    atomicMin(&output.data[dstOff], input.data[srcOff]);
  }
}

// Arg-index kernel for scatter max/min.
//
// For every element id visited by KERNEL_LOOP over `n`, asks
// IndexToScatterOffsets4 for the linear offsets of that element in `index`,
// `input`, `output` and `arg`. When the input value equals the value stored in
// the matching output slot, the element's position along `dim` is written to
// `arg`.
//
//   output : tensor holding the already-reduced values (read only here)
//   index  : int64 tensor handed to the offset computation
//   input  : tensor providing the candidate values
//   arg    : int64 tensor that receives the winning position along `dim`
//   dim    : dimension along which the position is reported
//   n      : upper bound passed to KERNEL_LOOP
//
// If several inputs tie with the output value, each of them stores its own
// position with a plain (non-atomic) write, so which one remains is not
// determined by this code.
//
// NOTE(review): relies on TensorInfo exposing `stride[]` next to `size[]`, and
// on `output` having been fully reduced before this kernel runs — confirm both.
template<typename Real, int Dims>
__global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;
    IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, arg, &argOffset);
    if (eq(input.data[inputOffset], output.data[outputOffset])) {
      // Recover the coordinate along `dim` from the linear offset: strip the
      // contribution of the faster-varying dimensions by dividing by the stride
      // of `dim`, then wrap by its extent. A bare `inputOffset % size[dim]` is
      // only right when `dim` has stride 1 (i.e. is the innermost dimension).
      arg.data[argOffset] = (inputOffset / input.stride[dim]) % input.size[dim];
    }
  }
}
rusty1s's avatar
max dim  
rusty1s committed
42

rusty1s's avatar
rusty1s committed
43
44
#include "generic/kernel.cu"
#include "THCGenerateAllTypes.h"