// common.cuh — kernel-launch helpers shared by the THC CUDA kernels.
#ifndef THC_COMMON_INC
#define THC_COMMON_INC

// Grid-stride loop over the flat index range [0, N).
//
// Declares a ptrdiff_t loop variable I, starting at this thread's global index.
// I advances by the total number of launched threads
// (blockDim.x * gridDim.x), so the loop is correct for any grid size.
//
// Both products are widened to ptrdiff_t BEFORE multiplying. blockIdx/blockDim/
// gridDim are 32-bit unsigned, and multiplying them first would wrap on very
// large launches. N is parenthesized so expression arguments compare correctly.
// Only usable inside device code (it reads blockIdx/blockDim/threadIdx/gridDim).
#define KERNEL_LOOP(I, N)                                                   \
  for (ptrdiff_t I = (ptrdiff_t)blockIdx.x * blockDim.x + threadIdx.x;      \
       I < (N); I += (ptrdiff_t)blockDim.x * gridDim.x)

// Maximum tensor rank representable by the fixed-size size/stride arrays in
// TensorInfo.
const int MAX_DIMS = 25;

// Threads per block used by every KERNEL_RUN / KERNEL_REAL_RUN launch.
// A 1024-thread alternative is left commented out below.
/* const int NUM_THREADS = 1024; */
const int NUM_THREADS = 256;

// Number of NUM_THREADS-wide blocks needed to cover N elements, i.e.
// ceil(N / NUM_THREADS).
//
// Returns 0 when N <= 0; the old formula went negative for N <= -NUM_THREADS.
// The quotient/remainder form never computes N + NUM_THREADS - 1, so it cannot
// overflow int (undefined behavior) when N is close to INT_MAX.
inline int GET_BLOCKS(int N) {
  if (N <= 0) {
    return 0;
  }
  const int full_blocks = N / NUM_THREADS;
  const int has_tail = (N % NUM_THREADS != 0) ? 1 : 0;
  return full_blocks + has_tail;
}

// Launches kernel NAME over N elements on the current THC stream and checks
// for launch errors.
//
// Call-site requirements:
//   - a `THCState *state` must be in scope (it is passed to
//     THCState_getCurrentStream);
//   - NAME's last parameter receives N; the variadic arguments precede it.
//   - write a trailing semicolon: KERNEL_RUN(kernel, n, a, b);
//
// The body is wrapped in do { ... } while (0), so `grid`/`stream` no longer
// leak into the caller's scope. The macro can now be used several times in one
// function, and it behaves as a single statement after an unbraced if/else.
//
// When N <= 0 nothing is launched: a grid of 0 blocks is an invalid launch
// configuration, and THCudaCheck would abort on an empty input.
// Note: N is evaluated more than once, so do not pass an expression with side
// effects.
#define KERNEL_RUN(NAME, N, ...)                                            \
  do {                                                                      \
    if ((N) > 0) {                                                          \
      int grid = GET_BLOCKS(N);                                             \
      cudaStream_t stream = THCState_getCurrentStream(state);               \
      NAME<<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N);               \
      THCudaCheck(cudaGetLastError());                                      \
    }                                                                       \
  } while (0)

// Same as KERNEL_RUN, but launches the `real` instantiation of a kernel
// template, NAME<real>. `real` must be defined at the call site, presumably by
// the THC generic-type machinery — confirm at the including file.
//
// Call-site requirements:
//   - a `THCState *state` must be in scope;
//   - NAME<real>'s last parameter receives N; the variadic arguments precede it.
//   - write a trailing semicolon: KERNEL_REAL_RUN(kernel, n, a, b);
//
// Scoped in do { ... } while (0) so `grid`/`stream` do not leak into the caller.
// The launch is skipped when N <= 0, because a 0-block grid is an invalid launch
// configuration that THCudaCheck would turn into an abort.
// N is evaluated more than once; avoid side-effecting expressions.
#define KERNEL_REAL_RUN(NAME, N, ...)                                       \
  do {                                                                      \
    if ((N) > 0) {                                                          \
      int grid = GET_BLOCKS(N);                                             \
      cudaStream_t stream = THCState_getCurrentStream(state);               \
      NAME<real><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N);         \
      THCudaCheck(cudaGetLastError());                                      \
    }                                                                       \
  } while (0)

/*
 * Lightweight, fixed-size description of a tensor that can be passed by value.
 *
 * Holds a typed data pointer plus per-dimension size and stride arrays with
 * room for MAX_DIMS dimensions.
 * NOTE(review): `dims` presumably counts how many leading entries of
 * size/stride are valid, and strides are presumably in elements — confirm
 * against the code that fills this struct.
 */
template <class T>
struct TensorInfo {
  T* data;               // base pointer of the tensor data
  int dims;              // rank of the described tensor (assumed <= MAX_DIMS)
  int size[MAX_DIMS];    // extent of each dimension
  int stride[MAX_DIMS];  // stride of each dimension
};

// Pull in the type-generic half of this header. Following the TH/THC
// generic-file convention, THCGenerateAllTypes.h presumably re-includes the
// generic file once per scalar type with `real` defined — confirm against the
// THC headers in use.
#include "generic/common.cuh"
#include "THC/THCGenerateAllTypes.h"

#endif  // THC_COMMON_INC