kernel.cu 1.55 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#else

5
6
7
8
// Host-side launcher for gridKernel, instantiated once per scalar type `real`
// through the THC generic-file mechanism.
//
// Arguments:
//   state    - THC state; provides the CUDA stream the kernel is launched on.
//   C        - forwarded unchanged to gridKernel.
//   output   - CUDA long tensor; its element count N sizes the launch and its
//              data pointer is handed to the kernel.
//   position - tensor of type `real`; handed to the kernel as a TensorInfo, and
//              its dimension count selects the kernel specialization.
//   size     - tensor of type `real`; its data pointer is handed to the kernel.
//   count    - CUDA long tensor; its data pointer is handed to the kernel.
//
// Raises a THC error when the tensors fail the same-GPU checks, when `output`
// has more than MAX_DIMS dimensions or more elements than fit in an int, or
// when the kernel launch reports an error.
//
// NOTE(review): output, size and count are accessed through raw data pointers,
// so they are presumably expected to be contiguous - confirm with the callers.
void cluster_(grid)(THCState *state, int C, THCudaLongTensor *output, THCTensor *position, THCTensor *size, THCudaLongTensor *count) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, position, size));
  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 2, output, count));
  THArgCheck(THCudaLongTensor_nDimension(state, output) <= MAX_DIMS, 1, "Tensor too large or too many dimensions");

  // The kernel receives the element count as an int; reject tensors whose
  // element count would be truncated instead of launching with a bogus N.
  const int64_t numel = THCudaLongTensor_nElement(state, output);
  const int N = static_cast<int>(numel);
  THArgCheck(static_cast<int64_t>(N) == numel, 1, "Tensor too large or too many dimensions");

  // Nothing to process. Returning here also avoids a launch with zero blocks,
  // which CUDA rejects as an invalid configuration.
  if (N == 0) {
    return;
  }

  int64_t *outputData = THCudaLongTensor_data(state, output);
  TensorInfo<real> positionInfo = thc_(getTensorInfo)(state, position);
  real *sizeData = THCTensor_(data)(state, size);
  int64_t *countData = THCudaLongTensor_data(state, count);

  int grid = GET_BLOCKS(N);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // Dimension counts 1-4 get their own specialization; every other value
  // falls back to the -1 instantiation.
  switch (positionInfo.dims) {
    case  1: gridKernel<real,  1><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, countData, C, N); break;
    case  2: gridKernel<real,  2><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, countData, C, N); break;
    case  3: gridKernel<real,  3><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, countData, C, N); break;
    case  4: gridKernel<real,  4><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, countData, C, N); break;
    default: gridKernel<real, -1><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, countData, C, N); break;
  }

  // Kernel launches do not return a status; pick up launch errors here.
  THCudaCheck(cudaGetLastError());
}

#endif