common.h 12.2 KB
Newer Older
1
2
#pragma once

3
#ifndef __CUDACC_RTC__
4
#include <cuda_runtime.h>
5
6
#endif

7
#include "atomic.h"
8
#include <cute/arch/util.hpp>
9
10
11
12
#include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h>
#include <math_constants.h>

13
14
15
#include <cutlass/bfloat16.h>
#include <cutlass/float8.h>

16
17
18
19
// Re-export the CUTLASS scalar types and CuTe's shared-memory pointer cast at
// global scope so the rest of this header can name them unqualified.
using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;

using cute::cast_smem_ptr_to_uint;

// Project alias for CUDA's built-in four-int vector type.
typedef int4 int4_t;

// Half-style math names mapped onto the CUTLASS fast-math routines.
// hpow is the exception: it maps to single-precision powf.
#define hexp cutlass::fast_exp
#define hlog cutlass::fast_log
#define hsqrt cutlass::fast_sqrt
#define hsin cutlass::fast_sin
#define hcos cutlass::fast_cos
#define htanh cutlass::fast_tanh
#define hpow powf

// Short unsigned type names.
// NOTE(review): these are object-like macros, so they rewrite every later
// `uint` / `uchar` / `ushort` token in the translation unit (including any
// header included after this one); typedefs would be less intrusive.
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short

// Function qualifiers: TL_DEVICE for force-inlined device helpers,
// TL_DEVICE_NOINLINE for device functions that must stay out-of-line.
#define TL_DEVICE __forceinline__ __device__
#define TL_DEVICE_NOINLINE __noinline__ __device__
// Expands to nothing; serves only as a marker on patched helper overloads.
#define TL_PATCH

40
41
42
43
44
45
46
47
48
49
// Host-side check for a CUDA runtime call.
// Evaluates `stmt` once. On any result other than cudaSuccess it writes
// "file:line: <error name> - <error string>" into `error_buf` (at most
// ERROR_BUF_SIZE bytes) and makes the enclosing function return -1.
// Both `error_buf` and ERROR_BUF_SIZE must be visible at the expansion site.
#define TILELANG_CHECK(stmt)                                                   \
  do {                                                                         \
    const cudaError_t tl_check_status_ = (stmt);                               \
    if (tl_check_status_ == cudaSuccess)                                       \
      break;                                                                   \
    snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__, __LINE__, \
             cudaGetErrorName(tl_check_status_),                               \
             cudaGetErrorString(tl_check_status_));                            \
    return -1;                                                                 \
  } while (0)

50
51
52
53
// Host-side check of the sticky CUDA error state (cudaGetLastError, which also
// clears it).
// On failure it writes "<kernel_name>: <error name> - <error string>" into
// `error_buf` and makes the enclosing function return -1. `kernel_name` is
// pasted next to a string literal, so it must itself be a string literal.
#define TILELANG_CHECK_LAST_ERROR(kernel_name)                                 \
  do {                                                                         \
    const cudaError_t tl_last_status_ = cudaGetLastError();                    \
    if (tl_last_status_ == cudaSuccess)                                        \
      break;                                                                   \
    snprintf(error_buf, ERROR_BUF_SIZE, kernel_name ": %s - %s",               \
             cudaGetErrorName(tl_last_status_),                                \
             cudaGetErrorString(tl_last_status_));                             \
    return -1;                                                                 \
  } while (0)

60
61
62
63
64
// __habs overloads for the CUTLASS 16-bit float types. Each one forwards to
// the abs() overload selected for its argument type (presumably cutlass's,
// found via argument-dependent lookup).
TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
  const half_t magnitude = abs(x);
  return magnitude;
}

TL_PATCH TL_DEVICE bfloat16_t __habs(const bfloat16_t x) {
  const bfloat16_t magnitude = abs(x);
  return magnitude;
}
65

66
67
68
69
70
// Reciprocal square root for cutlass half_t: converts to the native CUDA half,
// applies the native hrsqrt overload, and wraps the result back into half_t.
TL_PATCH TL_DEVICE half_t hrsqrt(const half_t x) {
  const half native = x.to_half();
  const half reciprocal_root = hrsqrt(native);
  return half_t(reciprocal_root);
}

71
72
// Pack two half values into one 32-bit word: the bits of `x` go into bits
// [0,16) and the bits of `y` go into bits [16,32).
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
  // __half_as_ushort reinterprets the raw bits without the strict-aliasing
  // violation of dereferencing an `unsigned short *` that points at a half.
  unsigned v0 = __half_as_ushort(x);
  unsigned v1 = __half_as_ushort(y);
  return (v1 << 16) | v0;
}

// Pack two half_t values into one 32-bit word: the bits of `x` go into bits
// [0,16) and the bits of `y` go into bits [16,32).
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
  // raw() exposes the 16-bit storage directly, avoiding the strict-aliasing
  // violation of reading a half_t through an `unsigned short *`.
  unsigned v0 = x.raw();
  unsigned v1 = y.raw();
  return (v1 << 16) | v0;
}

// Pack two bfloat16_t values into one 32-bit word: the bits of `x` go into
// bits [0,16) and the bits of `y` go into bits [16,32).
TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
  // raw() exposes the 16-bit storage directly, avoiding the strict-aliasing
  // violation of reading a bfloat16_t through an `unsigned short *`.
  unsigned v0 = x.raw();
  unsigned v1 = y.raw();
  return (v1 << 16) | v0;
}

92
93
94
95
96
97
98
// Pack two bfloat16_t values into one 32-bit word (nv_bfloat162 bit layout):
// the bits of `x` go into bits [0,16) and the bits of `y` go into [16,32).
TL_DEVICE unsigned __pack_nv_bfloat162(const bfloat16_t x, const bfloat16_t y) {
  // raw() exposes the 16-bit storage directly, avoiding the strict-aliasing
  // violation of reading a bfloat16_t through an `unsigned short *`.
  unsigned v0 = x.raw();
  unsigned v1 = y.raw();
  return (v1 << 16) | v0;
}

99
// Pack four signed 8-bit values into one 32-bit int: x0 occupies bits [0,8),
// x1 bits [8,16), x2 bits [16,24) and x3 bits [24,32).
TL_DEVICE int make_int(signed char x0, signed char x1, signed char x2,
                       signed char x3) {
  // Reinterpret each lane as its raw byte before combining. Shifting/OR-ing
  // the signed chars directly promotes them to int with sign extension, so a
  // negative lane (e.g. x0 == -1 -> 0xFFFFFFFF) would overwrite every higher
  // byte; left-shifting a negative int is also undefined before C++20.
  const unsigned int b0 = static_cast<unsigned char>(x0);
  const unsigned int b1 = static_cast<unsigned char>(x1);
  const unsigned int b2 = static_cast<unsigned char>(x2);
  const unsigned int b3 = static_cast<unsigned char>(x3);
  return static_cast<int>((b3 << 24) | (b2 << 16) | (b1 << 8) | b0);
}

105
106
107
108
109
110
111
112
113
114
// Pack eight signed chars into an int2: make_int combines x0..x3 into the .x
// component and y0..y3 into the .y component.
TL_DEVICE int2 make_int2(signed char x0, signed char x1, signed char x2,
                         signed char x3, signed char y0, signed char y1,
                         signed char y2, signed char y3) {
  const int first_word = make_int(x0, x1, x2, x3);
  const int second_word = make_int(y0, y1, y2, y3);

  int2 packed;
  packed.x = first_word;
  packed.y = second_word;
  return packed;
}

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Pack sixteen signed chars into an int4_t: each group of four (x*, y*, z*,
// w*) is combined by make_int into the matching .x/.y/.z/.w component.
TL_DEVICE int4_t make_int4(signed char x0, signed char x1, signed char x2,
                           signed char x3, signed char y0, signed char y1,
                           signed char y2, signed char y3, signed char z0,
                           signed char z1, signed char z2, signed char z3,
                           signed char w0, signed char w1, signed char w2,
                           signed char w3) {
  const int word_x = make_int(x0, x1, x2, x3);
  const int word_y = make_int(y0, y1, y2, y3);
  const int word_z = make_int(z0, z1, z2, z3);
  const int word_w = make_int(w0, w1, w2, w3);

  int4_t packed;
  packed.x = word_x;
  packed.y = word_y;
  packed.z = word_z;
  packed.w = word_w;
  return packed;
}

130
131
132
133
134
135
136
137
138
139
140
// Pack eight ints into a longlong4. Each 64-bit component holds one pair:
// the first int of the pair in the low 32 bits and the second in the high
// 32 bits (e.g. result.x = x0 | (x1 << 32)).
TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0,
                                   int z1, int w0, int w1) {
  // Compose each lane arithmetically instead of storing an int2 through a
  // `long long` member reinterpreted as `int2 *`, which violates strict
  // aliasing. NVIDIA GPUs are little-endian, so this produces the same bits
  // as the in-memory {lo, hi} int2 layout.
  auto pack = [](int lo, int hi) -> long long {
    const unsigned long long bits =
        (static_cast<unsigned long long>(static_cast<unsigned int>(hi)) << 32) |
        static_cast<unsigned long long>(static_cast<unsigned int>(lo));
    return static_cast<long long>(bits);
  };

  longlong4 result;
  result.x = pack(x0, x1);
  result.y = pack(y0, y1);
  result.z = pack(z0, z1);
  result.w = pack(w0, w1);
  return result;
}

141
// Convert a generic pointer that refers to shared memory into a 32-bit
// shared-window address, using the CUDA __cvta_generic_to_shared intrinsic.
// The intrinsic's wider result is truncated to 32 bits.
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
  const auto shared_addr = __cvta_generic_to_shared(ptr);
  return static_cast<uint32_t>(shared_addr);
}

146
147
148
149
150
151
152
153
154
155
156
157
158
/**
 * Map a generic pointer into shared memory to its 32-bit shared-window
 * address.
 *
 * Runs inline PTX: `cvta.to.shared.u64` converts the 64-bit generic address
 * to a shared-state-space address, and `cvt.u32.u64` then truncates it to
 * 32 bits.
 *
 * @param smem_ptr Generic pointer that must point into shared memory.
 * @return The truncated 32-bit shared-memory address.
 *
 * @note Passing a pointer that is not in shared memory gives a meaningless
 *       result (the PTX conversion assumes a shared-memory source).
 */
TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
  unsigned int addr;
  asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
               "cvt.u32.u64 %0, smem_int; }"
               : "=r"(addr)
               : "l"(smem_ptr));
  return addr;
}

168
// DP4A
/**
 * Four-lane 8-bit dot product with 32-bit accumulation via CUDA's __dp4a.
 *
 * `a` and `b` are each read as one 32-bit word holding four packed signed
 * 8-bit lanes, and `c` is read as a 32-bit int accumulator. The value
 * __dp4a(*a, *b, *c), i.e. the sum of the four lane products added to the
 * accumulator, is then stored back through `c`.
 *
 * @param a Pointer to four packed 8-bit inputs (dereferenced as one int).
 * @param b Pointer to four packed 8-bit inputs (dereferenced as one int).
 * @param c Accumulator pointer; read as an int, then overwritten with the
 *          int32 result.
 *
 * NOTE(review): __dp4a requires compute capability 6.1 or newer, and all three
 * pointers are dereferenced as int, so they presumably need 4-byte
 * alignment. Confirm this at the call sites.
 */
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
  const int *const lhs_word = (const int *)a;
  const int *const rhs_word = (const int *)b;
  const int *const acc_word = (const int *)c;
  *c = __dp4a(*lhs_word, *rhs_word, *acc_word);
}
194
195

namespace tl {

/*!
 * \brief PTX data type.
 * \note
 * PTX fundamental data types:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
 * PTX matrix data types:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
 */
enum class DataType : int {
  kInt4 = 0,
  kUInt4 = 1,
  kInt8 = 2,
  kUInt8 = 3,
  kInt16 = 4,
  kUInt16 = 5,
  kInt32 = 6,
  kUInt32 = 7,
  kInt64 = 8,
  kUInt64 = 9,
  kFloat8_e4m3 = 10,
  kFloat8_e5m2 = 11,
  kFloat16 = 12,
  kBFloat16 = 13,
  kFloat16x2 = 14,
  kFloat32 = 15,
  kTensorFloat32 = 16,
  kFloat64 = 17,
  kBit1 = 18,
  kBit8 = 19,
  kBit16 = 20,
  kBit32 = 21,
  kBit64 = 22
};

/*!
 * \brief 64-bit matrix descriptor for a shared-memory operand.
 *
 * The same 64 bits can be accessed as a single uint64_t (desc_), as two
 * uint32_t halves (reg32_), as four uint16_t quarters (reg16_), or through
 * the named bitfields. initialize_descriptor() below fills it from a
 * shared-memory address. increase_descriptor_offset() and operator+ advance
 * it by adding to the low 32-bit half only.
 */
union GmmaDescriptor {
  CUTE_HOST_DEVICE constexpr GmmaDescriptor() noexcept : desc_(0) {}
  CUTE_HOST_DEVICE constexpr GmmaDescriptor(uint64_t desc) noexcept
      : desc_(desc) {}
  CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept
      : desc_(t.desc_) {}
  CUTE_HOST_DEVICE constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept
      : desc_(t.desc_) {}

  CUTE_HOST_DEVICE constexpr GmmaDescriptor &
  operator=(GmmaDescriptor const &t) noexcept {
    desc_ = t.desc_;
    return *this;
  }

  CUTE_HOST_DEVICE constexpr GmmaDescriptor &
  operator=(GmmaDescriptor &&t) noexcept {
    desc_ = t.desc_;
    return *this;
  }

  uint64_t desc_;
  uint32_t reg32_[2];
  uint16_t reg16_[4];

  // Bitfield implementation avoids the need for shifts in assignment.
  // The "bit [a,b)" notes give positions within the full 64-bit descriptor.
  // The trailing comments give positions within each field's own storage unit.
  struct {
    // start_address, bit [0,14), 4LSB not included
    uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused
    // leading dimension byte offset, bit [16,30), 4LSB not included
    // For N: This is the stride from the first col to the second col of the 8x2
    // brick in INTERLEAVED
    //   Unused for all SWIZZLE_* layouts (and assumed to be 1)
    // For T: This is the stride from the first 8 rows to the next 8 rows.
    uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
    // stride dimension byte offset, bit [32,46), 4LSB not included
    // For N: This is the stride from the first 8 rows to the next 8 rows.
    // For T: This is the stride from the first 8 cols to the next 8 cols.
    uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused
    // base_offset, bit [49,52)
    // Valid only for SWIZZLE_128B and SWIZZLE_64B
    uint8_t : 1, base_offset_ : 3,
        : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused
    // layout type, bit [62,64)
    // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
    uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8)
  } bitfield;

  // Decay to a uint64_t
  CUTE_HOST_DEVICE constexpr operator uint64_t() const noexcept {
    return desc_;
  }
  // Returns a copy whose low 32-bit word is increased by `offset`; the high
  // word is copied unchanged. Unlike increase_descriptor_offset(), `offset` is
  // NOT shifted here. It is presumably already in 16-byte units — confirm at
  // the call sites.
  template <typename T>
  CUTE_HOST_DEVICE constexpr GmmaDescriptor operator+(const T &offset) const {
    GmmaDescriptor ret;
    ret.reg32_[0] = reg32_[0] + uint32_t(offset);
    ret.reg32_[1] = reg32_[1];
    return ret;
  }
};

// Any: true if at least one of the first `size` elements of `a` is truthy.
template <typename T> TL_DEVICE bool Any(T *a, int size) {
  for (int i = 0; i < size; i++) {
    if (a[i]) {
      return true;
    }
  }
  return false;
}

// All: true if every one of the first `size` elements of `a` is truthy
// (vacuously true when size <= 0).
template <typename T> TL_DEVICE bool All(T *a, int size) {
  for (int i = 0; i < size; i++) {
    if (!a[i]) {
      return false;
    }
  }
  return true;
}

// Pow of int: x raised to the compile-time, non-negative integer exponent y,
// computed by repeated multiplication.
// Previously y == 0 returned x instead of 1, and a negative y silently returned
// x. Now x^0 is T(1), and a negative exponent is rejected at compile time.
template <int y = 1, typename T> TL_DEVICE T pow_of_int(T x) {
  static_assert(y >= 0, "pow_of_int requires a non-negative exponent");
  if constexpr (y == 0) {
    return T(1);
  } else {
    T result = x;
    for (int i = 1; i < y; i++) {
      result *= x;
    }
    return result;
  }
}

// Thread partial barrier synchronization: `bar.sync barrier_id, thread_count`
// on a named barrier.
// https://docs.nvidia.com/cuda/parallel-thread-execution/#memory-consistency-model
// The "memory" clobber keeps the compiler from moving loads and stores across
// the barrier. Without it, shared-memory accesses could be reordered around
// the synchronization point.
// NOTE(review): confirm that the default thread_count of 0 is a valid
// participant count for bar.sync at every call site.
template <int barrier_id = 0, int thread_count = 0>
TL_DEVICE void __sync_thread_partial() {
  asm volatile("bar.sync %0, %1;"
               :
               : "r"(barrier_id), "r"(thread_count)
               : "memory");
}

// Fill `descriptor` for the shared-memory buffer at `start_address`. The
// address is stored with its 4 low bits dropped. base_offset_ is set to 0, and
// the layout type and byte offsets come from the template arguments.
template <int layout_type = 0, int leading_byte_offset = 0,
          int stride_byte_offset = 0, typename T>
TL_DEVICE void initialize_descriptor(GmmaDescriptor &descriptor,
                                     T *start_address) {
  descriptor.bitfield.start_address_ =
      cute::cast_smem_ptr_to_uint(start_address) >> 4;
  descriptor.bitfield.layout_type_ = layout_type;
  descriptor.bitfield.base_offset_ = 0;
  descriptor.bitfield.leading_byte_offset_ = leading_byte_offset;
  descriptor.bitfield.stride_byte_offset_ = stride_byte_offset;
}

// Advance the descriptor's start address by `offset` bytes. The offset is
// converted to the 16-byte units used by start_address_ and added to the low
// 32-bit word.
template <typename T>
TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
                                          T offset) {
  descriptor.reg32_[0] += (offset >> 4);
}

// Thin wrappers around the CuTe FP8 types. They inherit every base constructor
// and add an explicit conversion from __nv_bfloat16, which goes through float.
struct float_e4m3_t : public cute::float_e4m3_t {
  using cute::float_e4m3_t::float_e4m3_t;
  CUTLASS_HOST_DEVICE
  float_e4m3_t() = default;

  CUTLASS_HOST_DEVICE
  explicit float_e4m3_t(__nv_bfloat16 x)
      : float_e4m3_t(static_cast<float>(x)) {}
};

struct float_e5m2_t : public cute::float_e5m2_t {
  using cute::float_e5m2_t::float_e5m2_t;
  CUTLASS_HOST_DEVICE
  float_e5m2_t() = default;

  CUTLASS_HOST_DEVICE
  explicit float_e5m2_t(__nv_bfloat16 x)
      : float_e5m2_t(static_cast<float>(x)) {}
};

// Maps the tl FP8 wrapper types back to their cute base types. Every other
// type maps to itself.
template <typename T> struct to_cute_type {
  using type = T;
};
template <> struct to_cute_type<tl::float_e4m3_t> {
  using type = cute::float_e4m3_t;
};
template <> struct to_cute_type<tl::float_e5m2_t> {
  using type = cute::float_e5m2_t;
};

} // namespace tl
Wenhao Xie's avatar
Wenhao Xie committed
377
378
379
380
381

namespace cutlass {
// bfloat16_t overload of cutlass::fast_exp, which lets `hexp` (a macro for
// cutlass::fast_exp in this header) accept bfloat16_t operands.
//
// The previous body, `return ::hexp(x);`, recursed without bound. The `hexp`
// macro expands it to `::cutlass::fast_exp(x)`, and the best match for a
// bfloat16_t argument is this very overload. Instead, evaluate with the float
// overload and round the result back to bfloat16_t.
TL_DEVICE
bfloat16_t fast_exp(bfloat16_t x) {
  return bfloat16_t(fast_exp(static_cast<float>(x)));
}
} // namespace cutlass