kernel.cu 1.42 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <THC/THC.h>

rusty1s's avatar
rusty1s committed
3
#include "THCAtomics.cuh"
4
#include "kernel.h"
rusty1s's avatar
rusty1s committed
5
#include "common.cuh"
rusty1s's avatar
rusty1s committed
6

rusty1s's avatar
rusty1s committed
7
8
// Per-scalar-type symbol names. Both macros hand their pieces plus `Real` to
// the TH_CONCAT_* helpers, which are declared outside this file (presumably
// they paste the tokens into a single identifier -- confirm in the TH headers).
// `Real` is not defined in this file; presumably it is supplied while
// THCGenerateAllTypes.h is being processed -- TODO confirm.
//   scatter_(NAME)  -> built from: scatter_, NAME, _kernel_, Real
//   index_backward  -> built from: index_backward_kernel_, Real
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real)
#define index_backward TH_CONCAT_2(index_backward_kernel_, Real)
rusty1s's avatar
rusty1s committed
9
// Per-scalar-type name built from `check_kernel_` and `Real` via TH_CONCAT_2
// (declared outside this file; presumably token pasting -- confirm).
// NOTE(review): `check` is a very generic macro name; every later use of the
// bare identifier `check` in this translation unit will be rewritten by it.
#define check TH_CONCAT_2(check_kernel_, Real)
10

rusty1s's avatar
rusty1s committed
11
12
13
14
15
16
// thc_(NAME): per-scalar-type name built from `thc_`, NAME, `_` and `Real`
// via TH_CONCAT_4 (declared outside this file; presumably token pasting --
// confirm). Defined before "generic/common.cu" is included below, so that
// file presumably relies on it -- TODO confirm.
#define thc_(NAME) TH_CONCAT_4(thc_, NAME, _, Real)

#include "generic/common.cu"
#include "THCGenerateAllTypes.h"

/**
 * Scatter-max kernel.
 *
 * For each linear element id `i` produced by KERNEL_LOOP(i, n), the id is
 * decomposed into per-dimension coordinates using `index.size`, walking the
 * dimensions from last to first. Those coordinates address `index` and
 * `input` directly. For `output` and `arg`, the coordinate along `dim` is
 * replaced by the value read from `index`. The input element is then merged
 * into the selected output element with atomicMax.
 *
 * @tparam Real  element type of `input` / `output`.
 * @tparam Dims  not referenced in the body; the number of dimensions is read
 *               from `index.dims` at run time.
 * @param output destination tensor, updated in place via atomicMax.
 * @param index  int64 tensor whose values select the position along `dim`
 *               in `output`; each value must lie in [0, output.size[dim])
 *               (checked with a device-side assert).
 * @param input  source tensor; addressed with the same coordinates as `index`.
 * @param arg    tensor whose offset is computed alongside `output`'s but
 *               which is not written yet (see TODO below).
 * @param dim    dimension along which values are scattered.
 * @param n      number of elements to process (upper bound of KERNEL_LOOP).
 *
 * Offsets are accumulated in `int`, so they must fit in 32 bits.
 */
template <typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0; int indexOffset = 0; int inputOffset = 0; int argOffset = 0;

    // Decompose a private copy of the element id. The loop variable `i`
    // belongs to KERNEL_LOOP (defined outside this file); dividing it in
    // place would leave it at 0 after every element and, if the macro
    // advances `i` itself, corrupt the iteration.
    int linearIndex = i;
    for (int d = index.dims - 1; d >= 0; d--) {
      const int curDimIndex = linearIndex % index.size[d];
      indexOffset += curDimIndex * index.stride[d];
      inputOffset += curDimIndex * input.stride[d];
      // Along `dim` the output/arg coordinate comes from `index`, not from
      // the element id, so it is added after the loop.
      if (d != dim) {
        outputOffset += curDimIndex * output.stride[d];
        argOffset += curDimIndex * arg.stride[d];
      }
      linearIndex /= index.size[d];
    }

    int64_t indexValue = index.data[indexOffset];
    assert(indexValue >= 0 && indexValue < output.size[dim]);
    outputOffset += indexValue * output.stride[dim];
    argOffset += indexValue * arg.stride[dim];

    atomicMax(&output.data[outputOffset], input.data[inputOffset]);
    // TODO: Do something with arg.
  }
}
rusty1s's avatar
max dim  
rusty1s committed
40

rusty1s's avatar
rusty1s committed
41
42
#include "generic/kernel.cu"
#include "THCGenerateAllTypes.h"