// common.cuh: shared launch-configuration helpers, tensor metadata, and device-side equality overloads.
// Capacity of the fixed-size size/stride arrays in TensorInfo.
const int MAX_DIMS = 25;
// Threads per block used for every kernel launch in this header.
const int NUM_THREADS = 1024;

// Number of NUM_THREADS-sized blocks needed to cover `n` work items,
// i.e. ceil(n / NUM_THREADS). Returns 0 when n <= 0.
// Written as (n - 1) / NUM_THREADS + 1 instead of
// (n + NUM_THREADS - 1) / NUM_THREADS so that n close to INT_MAX cannot
// overflow signed int (which would be undefined behavior).
inline int GET_BLOCKS(const int n) {
  return n <= 0 ? 0 : (n - 1) / NUM_THREADS + 1;
}

// Plain metadata bundle for a tensor: a data pointer, its rank, and the
// per-dimension sizes and strides.
// NOTE(review): stride units (elements vs. bytes) are not visible here; confirm
// against the code that fills `st`.
template <typename T>
struct TensorInfo {
  // Stores `t` and `d`, then copies the first `d` entries of `sz` and `st`.
  // Entries at index >= d are left uninitialized.
  // Precondition: 0 <= d <= MAX_DIMS, because size/stride hold only MAX_DIMS
  // entries and nothing here checks the bound.
  TensorInfo(T *t, int d, int sz[MAX_DIMS], int st[MAX_DIMS])
      : data(t), dims(d) {
    for (int k = 0; k < d; ++k) {
      size[k] = sz[k];
      stride[k] = st[k];
    }
  }

  T *data;              // base pointer of the tensor
  int dims;             // number of valid entries in size/stride
  int size[MAX_DIMS];   // extent of each dimension (first `dims` entries valid)
  int stride[MAX_DIMS]; // stride of each dimension (first `dims` entries valid)
};

// Grid-stride loop over [0, N) for use inside device code.
// Each thread starts at its global index (blockIdx.x * blockDim.x + threadIdx.x)
// and advances by the total number of launched threads (blockDim.x * gridDim.x),
// so every index in [0, N) is visited regardless of the grid size.
// The loop variable `I` is declared as `int`, so N must fit in int.
// The increment uses the macro parameter `I`. The original hard-coded `i`,
// which only worked when the caller happened to name the index `i`.
#define KERNEL_LOOP(I, N) \
  for (int I = blockIdx.x * blockDim.x + threadIdx.x; I < (N); I += blockDim.x * gridDim.x)

// Launches the kernel template NAME<real, D> on the current THC stream with
// GET_BLOCKS(N) blocks of NUM_THREADS threads and no dynamic shared memory.
// The kernel receives __VA_ARGS__ followed by N as its last argument.
// DIMS values 1, 2 and 3 select the matching specialization; any other value
// uses the NAME<real, -1> instantiation.
// The calling scope must provide `real` (the element type used as the first
// template argument) and `state` (passed to THCState_getCurrentStream).
// When the grid would be empty (N <= 0), no kernel is launched, because a
// zero-block launch is an invalid configuration that THCudaCheck would report.
// Local names carry a `kernel_run_` prefix so they cannot shadow identifiers
// the caller passes through __VA_ARGS__.
#define KERNEL_RUN(NAME, DIMS, N, ...) { \
  int kernel_run_grid_ = GET_BLOCKS(N); \
  if (kernel_run_grid_ > 0) { \
    cudaStream_t kernel_run_stream_ = THCState_getCurrentStream(state); \
    switch (DIMS) { \
      case  1: NAME<real,  1><<<kernel_run_grid_, NUM_THREADS, 0, kernel_run_stream_>>>(__VA_ARGS__, N); break; \
      case  2: NAME<real,  2><<<kernel_run_grid_, NUM_THREADS, 0, kernel_run_stream_>>>(__VA_ARGS__, N); break; \
      case  3: NAME<real,  3><<<kernel_run_grid_, NUM_THREADS, 0, kernel_run_stream_>>>(__VA_ARGS__, N); break; \
      default: NAME<real, -1><<<kernel_run_grid_, NUM_THREADS, 0, kernel_run_stream_>>>(__VA_ARGS__, N); break; \
    } \
  } \
  THCudaCheck(cudaGetLastError()); \
}

// Device-side equality overloads, one per element type, so a single call
// eq(x, y) resolves for each type listed here.
// Presumably used by kernels templated on the element type; confirm against callers.
static inline __device__ bool eq(uint8_t lhs, uint8_t rhs) {
  return lhs == rhs;
}
static inline __device__ bool eq(int8_t lhs, int8_t rhs) {
  return lhs == rhs;
}
static inline __device__ bool eq(int16_t lhs, int16_t rhs) {
  return lhs == rhs;
}
static inline __device__ bool eq(int32_t lhs, int32_t rhs) {
  return lhs == rhs;
}
static inline __device__ bool eq(int64_t lhs, int64_t rhs) {
  return lhs == rhs;
}
static inline __device__ bool eq(float lhs, float rhs) {
  return lhs == rhs;
}
static inline __device__ bool eq(double lhs, double rhs) {
  return lhs == rhs;
}

#ifdef CUDA_HALF_TENSOR
// half values are compared after widening both operands to float.
static inline __device__ bool eq(half lhs, half rhs) {
  const float l = __half2float(lhs);
  const float r = __half2float(rhs);
  return l == r;
}
#endif