// debug.h: device-side debug printing and assertion helpers.
1
2
#pragma once

3
#if __CUDA_ARCH_LIST__ >= 890
4
#include "./cuda_fp8.h"
5
6
#endif

7
#include "common.h"
8
9
10
11

#ifndef __CUDACC_RTC__
#include <cstdio>
#endif
12
13

// Primary template for printing a single device-side value together with the
// calling thread's block/thread coordinates. Only declared here: each
// supported element type gets an explicit specialization further down, and
// pointers are handled by a dedicated overload. No generic definition is
// provided, so an unsupported T has nothing to link against.
template <typename T>
__device__ void debug_print_var(const char *msg, T var);

16
17
18
19
20
21
22
23
24
// Overload for pointer arguments (any cv-qualified T*). Partial ordering
// prefers it over the generic `T var` declaration whenever a pointer is
// passed, and it prints the address itself rather than the pointee.
// Being a (non-specialized) template, it is implicitly inline.
template <typename T> __device__ void debug_print_var(const char *msg, T *var) {
  // printf's %p requires a void pointer argument; convert explicitly so the
  // vararg matches the specifier for every pointee type (the C-style cast
  // also strips any const/volatile qualification of T).
  printf(
      "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=pointer "
      "value=%p\n",
      msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
      threadIdx.z, (const void *)var);
}

25
26
27
28
29
30
31
32
33
34
// Specialization for signed char: prints msg, blockIdx/threadIdx and the value
// as a decimal integer.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<signed char>(const char *msg,
                                                    signed char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed "
         "char "
         "value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (int)var);
}

35
36
37
38
39
40
41
42
// Specialization for plain char (a distinct type from signed/unsigned char):
// prints msg, blockIdx/threadIdx and the character's numeric value.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<char>(const char *msg, char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=char "
         "value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (int)var);
}

43
44
45
46
47
48
49
50
51
52
// Specialization for unsigned char: prints msg, blockIdx/threadIdx and the
// value as a decimal integer (0..255).
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<unsigned char>(const char *msg,
                                                      unsigned char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=unsigned char "
         "value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (int)var);
}
53
54

// Specialization for int: prints msg, blockIdx/threadIdx and the value.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<int>(const char *msg, int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
         "value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

62
63
64
65
66
67
68
69
70
71
// Specialization for unsigned int: prints msg, blockIdx/threadIdx and the
// value with %u. The dtype label says "unsigned int" so the output is not
// confused with the signed int specialization.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<unsigned int>(const char *msg,
                                                     unsigned int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=unsigned int "
         "value=%u\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

72
73
74
75
76
77
78
79
// Specialization for bool: prints msg, blockIdx/threadIdx and the value as the
// literal text "true" or "false".
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<bool>(const char *msg, bool var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
         "value=%s\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var ? "true" : "false");
}

80
// Specialization for float: prints msg, blockIdx/threadIdx and the value with
// %f (the float is promoted to double by the vararg call).
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<float>(const char *msg, float var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for half: the value is converted to float before being
// printed with %f, since printf has no half-precision specifier.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<half>(const char *msg, half var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for half_t: the value is converted to float before being
// printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<half_t>(const char *msg, half_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for bfloat16_t: the value is converted to float before being
// printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<bfloat16_t>(const char *msg,
                                                   bfloat16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=bfloat16_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for double: prints msg, blockIdx/threadIdx and the value.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<double>(const char *msg, double var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double "
         "value=%lf\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

123
#if __CUDA_ARCH_LIST__ >= 890
124
125
126
127
128
129
130
131
132
// Specialization for fp8_e4_t: the value is converted to float before being
// printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<fp8_e4_t>(const char *msg,
                                                 fp8_e4_t var) {
  printf(
      "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e4_t "
      "value=%f\n",
      msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
      threadIdx.z, (float)var);
}
133

134
135
136
137
138
139
140
141
142
// Specialization for fp8_e5_t: the value is converted to float before being
// printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_var<fp8_e5_t>(const char *msg,
                                                 fp8_e5_t var) {
  printf(
      "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e5_t "
      "value=%f\n",
      msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
      threadIdx.z, (float)var);
}
143

144
145
#endif

146
147
// Primary template for printing one element of a named buffer: the message,
// the calling thread's block/thread coordinates, the buffer name, the element
// index and its value. Only declared here; the supported element types are
// provided as explicit specializations further down, and no generic
// definition exists.
template <typename T>
__device__ void debug_print_buffer_value(const char *msg,
                                         const char *buf_name, int index,
                                         T var);

// Buffer-element specialization for signed char: the value is printed as a
// decimal integer.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void
debug_print_buffer_value<signed char>(const char *msg, const char *buf_name,
                                      int index, signed char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=signed char value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (int)var);
}

162
// Buffer-element specialization for unsigned char: the value is printed as a
// decimal integer. The dtype label says "unsigned char", matching the scalar
// debug_print_var<unsigned char> output; it previously said "char".
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void
debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,
                                        int index, unsigned char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=unsigned char value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (int)var);
}
172
173
174

// Buffer-element specialization for int.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_buffer_value<int>(const char *msg,
                                                     const char *buf_name,
                                                     int index, int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=int value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Buffer-element specialization for unsigned int: the value is printed with
// %u and labelled "unsigned int" (previously mislabelled "int").
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void
debug_print_buffer_value<unsigned int>(const char *msg, const char *buf_name,
                                       int index, unsigned int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=unsigned int value=%u\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Buffer-element specialization for float (promoted to double for %f).
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_buffer_value<float>(const char *msg,
                                                       const char *buf_name,
                                                       int index, float var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=float value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Buffer-element specialization for half: the value is converted to float
// before being printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_buffer_value<half>(const char *msg,
                                                      const char *buf_name,
                                                      int index, half var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Buffer-element specialization for half_t: the value is converted to float
// before being printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_buffer_value<half_t>(const char *msg,
                                                        const char *buf_name,
                                                        int index, half_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Buffer-element specialization for bfloat16_t: the value is converted to
// float before being printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void
debug_print_buffer_value<bfloat16_t>(const char *msg, const char *buf_name,
                                     int index, bfloat16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=bfloat16_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Buffer-element specialization for double.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_buffer_value<double>(const char *msg,
                                                        const char *buf_name,
                                                        int index, double var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=double value=%lf\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}
249
250

// FP8 (fp8_e4_t / fp8_e5_t) buffer specializations, compiled only under the
// `#if __CUDA_ARCH_LIST__ >= 890` guard below.
251
#if __CUDA_ARCH_LIST__ >= 890
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
// Buffer-element specialization for fp8_e4_t: the value is converted to float
// before being printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_buffer_value<fp8_e4_t>(const char *msg,
                                                          const char *buf_name,
                                                          int index,
                                                          fp8_e4_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=fp8_e4_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Buffer-element specialization for fp8_e5_t: the value is converted to float
// before being printed with %f.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg,
                                                          const char *buf_name,
                                                          int index,
                                                          fp8_e5_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=fp8_e5_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}
272

273
274
#endif

275
276
277
278
279
280
281
282
283
284
// Buffer-element specialization for int16_t: widened to int32_t so it matches
// the %d specifier.
// `inline` because an explicit specialization is not implicitly inline and
// this definition lives in a header that may be included by several TUs.
template <>
inline __device__ void debug_print_buffer_value<int16_t>(const char *msg,
                                                         const char *buf_name,
                                                         int index,
                                                         int16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=int16_t value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (int32_t)var);
}
285
286
287
288
289
290
291
292
293

// Thin device-side wrapper around the standard `assert` macro. A false `cond`
// fails the assertion: the kernel traps, and the error surfaces at the next
// synchronizing host call. When NDEBUG is defined the check compiles away.
// The argument expression is kept verbatim because `assert` prints it.
// NOTE(review): <cassert> is not included by this header; presumably `assert`
// is made available via common.h or the compiler environment — confirm.
TL_DEVICE void device_assert(bool cond) { assert(cond); }

// Device-side assertion with a caller-supplied message. A passing `cond`
// returns immediately. On failure, "Device assert failed: <msg>" is printed
// and then assert(0) is raised. If assertions are compiled out (NDEBUG), only
// the message is printed and execution continues.
TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
  if (cond)
    return;

  printf("Device assert failed: %s\n", msg);
  assert(0);
}