kernel.cu 945 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <THC/THC.h>

3
#include "kernel.h"
rusty1s's avatar
rusty1s committed
4

rusty1s's avatar
rusty1s committed
5
#include "common.cuh"
rusty1s's avatar
rusty1s committed
6
7
#include "THCIndex.cuh"
#include "THCAtomics.cuh"
rusty1s's avatar
rusty1s committed
8

rusty1s's avatar
rusty1s committed
9
10
// Name-building macros used by the generic (per-type) sources included below.
// NOTE(review): presumably TH_CONCAT_4 / TH_CONCAT_2 token-paste their arguments and
// `Real` is supplied by THCGenerateAllTypes.h for each generated type — confirm in TH headers.
// scatter_(NAME) is built from the pieces: scatter_, NAME, _kernel_, Real.
// index_backward is built from the pieces: index_backward_kernel_, Real.
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real)
#define index_backward TH_CONCAT_2(index_backward_kernel_, Real)
rusty1s's avatar
rusty1s committed
11
12
13
14
15
// thc_(NAME) is built from the pieces: thc_, NAME, _, Real.
// NOTE(review): presumably a token-pasting helper yielding one identifier per type — confirm.
#define thc_(NAME) TH_CONCAT_4(thc_, NAME, _, Real)

// generic/common.cu is included immediately before THCGenerateAllTypes.h.
// NOTE(review): this looks like the TH "generic file" pattern, where the second header
// re-includes the first once per scalar type — verify against THCGenerateAllTypes.h.
#include "generic/common.cu"
#include "THCGenerateAllTypes.h"

rusty1s's avatar
rusty1s committed
16
// Scatter-max kernel.
//
// For every index `i` that KERNEL_LOOP(i, n) hands to this thread, the kernel asks
// IndexToScatterOffsets to resolve three element offsets (into `index`, `input`
// and `output`) and then merges the selected input element into the selected
// output element with atomicMax.
//
// Template parameters:
//   Real - element type of `output` and `input`.
//   Dims - forwarded unchanged to IndexToScatterOffsets
//          (presumably the tensor rank to specialise on — confirm in THCIndex.cuh).
//
// Parameters:
//   output - tensor whose elements are updated in place through atomicMax.
//   index  - int64 tensor handed to the offset computation.
//   input  - tensor providing the candidate values.
//   arg    - int64 tensor; not read or written yet (see TODO in the body).
//   dim    - forwarded unchanged to the offset computation.
//   n      - bound passed to KERNEL_LOOP.
template<typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    // All three offsets start at zero and are filled in by compute().
    int idxOff = 0;
    int srcOff = 0;
    int dstOff = 0;
    IndexToScatterOffsets<Real, Real, Dims>::compute(
        i, dim,
        index, &idxOff,
        input, &srcOff,
        output, &dstOff);

    // Atomic update: several loop iterations may resolve to the same output slot.
    atomicMax(&output.data[dstOff], input.data[srcOff]);
    // TODO: Do something with arg.
  }
}
rusty1s's avatar
max dim  
rusty1s committed
25

rusty1s's avatar
rusty1s committed
26
27
// generic/kernel.cu is included immediately before THCGenerateAllTypes.h.
// NOTE(review): presumably the second header re-includes the generic file once per
// scalar type, with `Real` defined accordingly — verify against THCGenerateAllTypes.h.
#include "generic/kernel.cu"
#include "THCGenerateAllTypes.h"