common.h 7.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
#pragma once

#include <cuda_runtime.h>
#include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h>
#include <math_constants.h>

using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;

12
13
// int4_t names CUDA's 128-bit int4 vector type (four 32-bit ints .x/.y/.z/.w);
// it is NOT a 4-bit integer type.
typedef int4 int4_t;

14
15
16
17
18
19
20
21
22
23
24
// Short math-function names mapped onto CUTLASS fast-math helpers (hpow maps to
// single-precision powf). Presumably these are the spellings emitted by
// generated kernels — confirm against the code generator.
// NOTE(review): cuda_fp16.h also declares functions named hexp/hlog/hsqrt; as
// object-like macros these definitions redirect every later use of those names
// in an including file to the CUTLASS helpers.
#define hexp cutlass::fast_exp
#define hlog cutlass::fast_log
#define hsqrt cutlass::fast_sqrt
#define htanh cutlass::fast_tanh
#define hpow powf

// Short unsigned type names. NOTE(review): these are macros rather than
// typedefs, so every later `uint` / `uchar` / `ushort` token in a file that
// includes this header is textually replaced.
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short

// Qualifiers for the device helpers in this header: TL_DEVICE forces inlining
// (which also gives header-defined functions inline linkage), TL_DEVICE_NOINLINE
// forbids inlining.
#define TL_DEVICE __forceinline__ __device__
#define TL_DEVICE_NOINLINE __noinline__ __device__

// Expands to nothing; used purely as a marker annotation on helper overloads
// (e.g. the half_t __habs overload in this header).
#define TL_PATCH

28
29
30
31
32
33
34
35
36
37
// Evaluate a CUDA runtime call `stmt`. If it does not return cudaSuccess, write
// "<file>:<line>: <error name> - <error string>" into `error_buf` and make the
// *enclosing* function return -1.
// Requirements at the expansion site: `error_buf` and `ERROR_BUF_SIZE` must be
// visible, snprintf must be declared, and the enclosing function must return a
// type that -1 converts to.
#define TILELANG_CHECK(stmt)                                                   \
  do {                                                                         \
    const cudaError_t tl_check_status = (stmt);                                \
    if (tl_check_status == cudaSuccess) {                                      \
      break;                                                                   \
    }                                                                          \
    snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__, __LINE__,  \
             cudaGetErrorName(tl_check_status),                                \
             cudaGetErrorString(tl_check_status));                             \
    return -1;                                                                 \
  } while (0)

38
39
40
41
42
43
44
45
46
47
// Check the last CUDA runtime error (typically right after a kernel launch).
// cudaGetLastError() returns the last error produced by a runtime call in this
// host thread and resets it to cudaSuccess. On error, write
// "<kernel_name>: <error name> - <error string>" into `error_buf` and make the
// *enclosing* function return -1.
// Note: this only observes launch/configuration errors that have already been
// reported; faults raised while the kernel executes surface at a later
// synchronizing call.
// Assumes `kernel_name` is a C string (e.g. a string literal) — confirm against
// callers. `error_buf` and `ERROR_BUF_SIZE` must be visible at the expansion site.
// The argument is passed through "%s" rather than pasted into the format string
// (a macro parameter inside a string literal is never substituted).
#define TILELANG_CHECK_LAST_ERROR(kernel_name)                                 \
  do {                                                                         \
    cudaError_t __err = cudaGetLastError();                                    \
    if (__err != cudaSuccess) {                                                \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s: %s - %s", kernel_name,          \
               cudaGetErrorName(__err), cudaGetErrorString(__err));            \
      return -1;                                                               \
    }                                                                          \
  } while (0)

48
49
50
51
52
// Absolute value for cutlass::half_t. CUDA's __habs is declared for the native
// half type and half_t does not convert to it implicitly, so convert explicitly,
// take |x| in native half, and wrap the result back into half_t.
// Only a half_t overload is provided here.
TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
  const half magnitude = __habs(x.to_half());
  return half_t(magnitude);
}
53
54
55

// Pack two native half values into one 32-bit word: the raw bits of `x` go in
// bits [15:0] and the raw bits of `y` in bits [31:16].
// The bits are read with __half_as_ushort (a documented reinterpretation, no
// numeric conversion) instead of casting the object's address.
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
  const unsigned v0 = __half_as_ushort(x);
  const unsigned v1 = __half_as_ushort(y);
  return (v1 << 16) | v0;
}

// Pack two cutlass::half_t values into one 32-bit word: the raw bits of `x` go
// in bits [15:0] and the raw bits of `y` in bits [31:16].
// half_t::raw() returns the 16-bit storage directly (no numeric conversion).
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
  const unsigned v0 = x.raw();
  const unsigned v1 = y.raw();
  return (v1 << 16) | v0;
}

// Pack two cutlass::bfloat16_t values into one 32-bit word: the raw bits of `x`
// go in bits [15:0] and the raw bits of `y` in bits [31:16].
// bfloat16_t::raw() returns the 16-bit storage directly (no numeric conversion).
TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
  const unsigned v0 = x.raw();
  const unsigned v1 = y.raw();
  return (v1 << 16) | v0;
}

75
76
77
78
79
80
81
// Pack two bfloat16_t values into a 32-bit word: raw bits of `x` in the low
// 16 bits, raw bits of `y` in the high 16 bits.
TL_DEVICE unsigned __pack_nv_bfloat162(const bfloat16_t x, const bfloat16_t y) {
  const unsigned lo = x.raw();
  const unsigned hi = y.raw();
  return lo | (hi << 16);
}

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Pack four signed chars into one 32-bit int: x0 in the lowest byte, then x1,
// x2, and x3 in the highest byte.
// Each input is zero-extended through unsigned char before shifting. Promoting
// a negative signed char straight to int would sign-extend it and OR 1-bits
// into every higher byte (e.g. x0 = -1 would yield 0xFFFFFFFF regardless of the
// other inputs). Left-shifting a negative int is also undefined before C++20.
TL_DEVICE int make_int(signed char x0, signed char x1, signed char x2,
                       signed char x3) {
  const unsigned int b0 = static_cast<unsigned char>(x0);
  const unsigned int b1 = static_cast<unsigned char>(x1);
  const unsigned int b2 = static_cast<unsigned char>(x2);
  const unsigned int b3 = static_cast<unsigned char>(x3);
  return static_cast<int>((b3 << 24) | (b2 << 16) | (b1 << 8) | b0);
}

// Pack sixteen signed chars into an int4_t. Each group of four is packed by
// make_int into one 32-bit lane: x0..x3 -> .x, y0..y3 -> .y, z0..z3 -> .z,
// w0..w3 -> .w.
TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
                           signed char x3, signed char y0, signed char y1,
                           signed char y2, signed char y3, signed char z0,
                           signed char z1, signed char z2, signed char z3,
                           signed char w0, signed char w1, signed char w2,
                           signed char w3) {
  const int lane_x = make_int(x0, x1, x2, x3);
  const int lane_y = make_int(y0, y1, y2, y3);
  const int lane_z = make_int(z0, z1, z2, z3);
  const int lane_w = make_int(w0, w1, w2, w3);
  return int4_t{lane_x, lane_y, lane_z, lane_w};
}

103
// Convert a generic pointer into shared memory to a 32-bit shared-window
// address using the __cvta_generic_to_shared intrinsic. The intrinsic returns
// a size_t, which is narrowed to 32 bits here.
// `ptr` must point into shared memory; the conversion is undefined otherwise.
// Presumably the result is used as a shared-address operand in inline PTX —
// confirm against callers.
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}

108
109
110
111
112
113
114
115
116
117
// Helper to cast an SMEM pointer to an unsigned 32-bit shared-window address.
// Uses inline PTX: `cvta.to.shared.u64` converts the 64-bit generic address
// (the "l" operand) into a shared-state-space address, and `cvt.u32.u64`
// truncates it to 32 bits. Same goal as smem_ptr_to_uint above, which uses the
// __cvta_generic_to_shared intrinsic instead of hand-written PTX.
// Assumes a 64-bit generic pointer into shared memory. Per PTX cvta semantics,
// the result for an address outside the shared window is undefined.
TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
  unsigned int smem_int;
  asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
               "cvt.u32.u64 %0, smem_int; }"
               : "=r"(smem_int)
               : "l"(smem_ptr));
  return smem_int;
}

118
119
120
121
122
123
124
125
126
127
// Generic atomic add: converts `val` to the destination element type T1 and
// atomically adds it to *address through the CUDA atomicAdd overload for T1.
// Compiles only for T1 types with such an overload. Explicit specializations in
// this header route half_t / bfloat16_t to the native CUDA 16-bit types.
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 *address, T2 val) {
  const T1 delta = static_cast<T1>(val);
  atomicAdd(address, delta);
}

// // AtomicAdd Functions for FP32
// TL_DEVICE void AtomicAdd(float *address, float val) {
//   atomicAdd(reinterpret_cast<float *>(address), val);
// }

128
// AtomicAdd specialization for FP16: atomically performs *address += val using
// CUDA's native __half atomicAdd overload. half_t and half share the same
// 16-bit layout, so the destination pointer is reinterpreted directly.
// Per the CUDA programming guide, the 16-bit __half atomicAdd requires compute
// capability 7.x or higher.
template <> TL_DEVICE void AtomicAdd(half_t *address, half_t val) {
  atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(val));
}

// AtomicAdd specialization for FP16 with the addend passed by pointer:
// atomically performs *address += *val via CUDA's native __half atomicAdd.
// Only the update of *address is atomic; *val is read with a plain load.
template <> TL_DEVICE void AtomicAdd(half_t *address, half_t *val) {
  atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(*val));
}

139
// AtomicAdd specialization for an FP16 destination with an FP32 addend: `val`
// is first rounded to half with __float2half (round-to-nearest-even), then
// added atomically via CUDA's native __half atomicAdd overload.
template <> TL_DEVICE void AtomicAdd(half_t *address, float val) {
  atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
}

145
146
// AtomicAdd Functions for BFLOAT16
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
147
// AtomicAdd specialization for BF16 with the addend passed by pointer:
// atomically performs *address += *val via CUDA's native __nv_bfloat16
// atomicAdd overload. *val itself is read with a plain load.
// Compiled only under the enclosing `__CUDA_ARCH_LIST__ > 750` guard.
// NOTE(review): __CUDA_ARCH_LIST__ expands to a comma-separated list, so that
// guard effectively compares only its last entry. The CUDA guide lists the
// native __nv_bfloat16 atomicAdd for compute capability 8.x+; confirm the guard
// is adequate for multi-architecture builds.
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t *val) {
  atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
            static_cast<__nv_bfloat16>(*val));
}

153
154
// AtomicAdd specialization for a BF16 destination with an FP32 addend: `val` is
// rounded to bfloat16 with __float2bfloat16 (round-to-nearest-even), then added
// atomically via CUDA's native __nv_bfloat16 atomicAdd overload.
// Compiled only under the enclosing `__CUDA_ARCH_LIST__ > 750` guard.
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) {
  atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), __float2bfloat16(val));
}

158
159
#endif

160
161
162
163
164
165
// Vectorized FP16 atomic add: adds the two consecutive half_t values at `val`
// to the two consecutive half_t values at `address` via the __half2 atomicAdd
// overload. Both pointers are accessed as half2, so they must be 4-byte aligned.
// Per the CUDA programming guide, atomicity is guaranteed for each 16-bit
// element separately, not for the 32-bit pair as a whole.
TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
  half2 *const dst = reinterpret_cast<half2 *>(address);
  const half2 addend = *reinterpret_cast<half2 *>(val);
  atomicAdd(dst, addend);
}

166
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
167
168
169
170
171
172
173

// AtomicAdd specialization for BF16 values: atomically performs
// *address += val via CUDA's native __nv_bfloat16 atomicAdd overload.
// Compiled only under the enclosing `__CUDA_ARCH_LIST__ > 750` guard.
template <> TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t val) {
  __nv_bfloat16 *const target = reinterpret_cast<__nv_bfloat16 *>(address);
  const __nv_bfloat16 delta = static_cast<__nv_bfloat16>(val);
  atomicAdd(target, delta);
}

174
175
176
177
178
179
// Vectorized BF16 atomic add: adds the two consecutive bfloat16_t values at
// `val` to those at `address` via the __nv_bfloat162 atomicAdd overload.
// Both pointers are accessed as __nv_bfloat162, so they must be 4-byte aligned.
// Per the CUDA programming guide, atomicity is guaranteed per 16-bit element only.
// Compiled only under the enclosing `__CUDA_ARCH_LIST__ > 750` guard.
TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
  __nv_bfloat162 *const dst = reinterpret_cast<__nv_bfloat162 *>(address);
  const __nv_bfloat162 addend = *reinterpret_cast<__nv_bfloat162 *>(val);
  atomicAdd(dst, addend);
}
180
#endif
181

182
183
184
185
186
187
188
189
190
191
192
193
194
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
// Vectorized FP32 atomic add (two floats): atomically adds val[0..1] to
// address[0..1] via the float2 atomicAdd overload. Both pointers are accessed
// as float2, so they must be 8-byte aligned.
// Per the CUDA programming guide, the float2 overload requires compute
// capability 9.x+ and is supported only for global-memory addresses; atomicity
// is guaranteed per float element, not for the 64-bit pair.
TL_DEVICE void AtomicAddx2(float *address, float *val) {
  atomicAdd(reinterpret_cast<float2 *>(address),
            *reinterpret_cast<float2 *>(val));
}
// Vectorized FP32 atomic add (four floats): atomically adds val[0..3] to
// address[0..3] via the float4 atomicAdd overload. Both pointers are accessed
// as float4, so they must be 16-byte aligned.
// Per the CUDA programming guide, the float4 overload requires compute
// capability 9.x+ and is supported only for global-memory addresses; atomicity
// is guaranteed per float element, not for the 128-bit vector.
TL_DEVICE void AtomicAddx4(float *address, float *val) {
  atomicAdd(reinterpret_cast<float4 *>(address),
            *reinterpret_cast<float4 *>(val));
}
#endif

195
// DP4A: four-way 8-bit dot product with 32-bit accumulate,
//   *c = a[0]*b[0] + a[1]*b[1] + a[2]*b[2] + a[3]*b[3] + *c,
// computed by the __dp4a(int, int, int) intrinsic, which treats each operand
// as four packed signed bytes.
// The 4 bytes at `a`, at `b`, and at `c` are each loaded as one int, so every
// pointer must reference at least 4 bytes with 4-byte alignment. OutDatatype is
// presumably int — TODO confirm; the int result is converted on store.
// __dp4a requires compute capability 6.1 or higher.
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
  const int a_int = *reinterpret_cast<const int *>(a);
  const int b_int = *reinterpret_cast<const int *>(b);
  const int c_int = *reinterpret_cast<const int *>(c);
  *c = __dp4a(a_int, b_int, c_int);
}
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223

namespace tl {

// Any: returns true if any of the first `size` elements of `a` converts to
// true. Sequential scan with early exit; returns false when size <= 0.
template <typename T> TL_DEVICE bool Any(T *a, int size) {
  for (int i = 0; i < size; i++) {
    if (a[i]) {
      return true;
    }
  }
  return false;
}

// All: returns true if every one of the first `size` elements of `a` converts
// to true. Sequential scan with early exit; returns true when size <= 0.
template <typename T> TL_DEVICE bool All(T *a, int size) {
  for (int i = 0; i < size; i++) {
    if (!a[i]) {
      return false;
    }
  }
  return true;
}

// pow_of_int: x raised to the compile-time non-negative integer power y,
// computed by repeated multiplication (y - 1 multiplies for y >= 1).
// pow_of_int<0>(x) returns T(1); negative exponents are rejected at compile
// time because repeated multiplication cannot express them.
template <int y = 1, typename T> TL_DEVICE T pow_of_int(T x) {
  static_assert(y >= 0, "pow_of_int: exponent must be non-negative");
  if (y == 0) {
    return T(1);
  }
  T result = x;
  for (int i = 1; i < y; i++) {
    result *= x;
  }
  return result;
}

} // namespace tl