// debug.h: device-side debug printing and assertion helpers.
#pragma once

#include "./cuda_fp8.h"
#include "common.h"

#ifndef __CUDACC_RTC__
#include <cstdio>
#endif

// Primary template for device-side debug printing of a single variable.
// Declared without a body in this header: only the explicit specializations
// below (and the pointer overload) are defined, so a call with any other T
// has no definition to resolve to here.
template <typename T> __device__ void debug_print_var(const char *msg, T var);

// Overload for pointer arguments (supports any cv-qualified T*).
// For pointer arguments, overload resolution prefers this over the by-value
// primary template. Prints `msg`, the calling thread's blockIdx/threadIdx
// and the pointer's address.
template <typename T> __device__ void debug_print_var(const char *msg, T *var) {
  // %p requires a `void *` argument; passing an arbitrary `T *` through
  // varargs does not match the conversion, so cast explicitly (the C-style
  // cast also strips any cv-qualifiers on T).
  printf(
      "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=pointer "
      "value=%p\n",
      msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
      threadIdx.z, (const void *)var);
}

// Specialization for signed char type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` as a
// decimal integer (%d; the char is promoted to int through varargs).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<signed char>(const char *msg,
                                                    signed char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=signed char value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for plain char type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var`'s numeric
// value (%d) rather than the character itself.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<char>(const char *msg, char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=char "
         "value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (int)var);
}

// Specialization for unsigned char type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` as a
// decimal integer (%d; the byte is promoted to int through varargs).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<unsigned char>(const char *msg,
                                                      unsigned char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=unsigned char value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for int type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` (%d).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<int>(const char *msg, int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
         "value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for unsigned int type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` (%u).
// The dtype label reads "unsigned int" so the output is not confused with
// the signed int specialization.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<unsigned int>(const char *msg,
                                                     unsigned int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=unsigned int value=%u\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for bool type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` as the
// literal text "true" or "false".
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<bool>(const char *msg, bool var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
         "value=%s\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var ? "true" : "false");
}

// Specialization for float type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` (%f).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<float>(const char *msg, float var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for half type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` converted
// to float with a (float) cast (relies on half's float conversion).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<half>(const char *msg, half var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for half_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` converted
// to float with a (float) cast (relies on half_t's float conversion).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<half_t>(const char *msg, half_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for bfloat16_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` converted
// to float with a (float) cast (relies on bfloat16_t's float conversion).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<bfloat16_t>(const char *msg,
                                                   bfloat16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=bfloat16_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for double type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` (%lf;
// for printf the `l` length modifier has no effect on %f).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<double>(const char *msg, double var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double "
         "value=%lf\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for fp8_e4_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` converted
// to float with a (float) cast (relies on fp8_e4_t's float conversion).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<fp8_e4_t>(const char *msg,
                                                 fp8_e4_t var) {
  printf(
      "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e4_t "
      "value=%f\n",
      msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
      threadIdx.z, (float)var);
}

// Specialization for fp8_e5_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, and `var` converted
// to float with a (float) cast (relies on fp8_e5_t's float conversion).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_var<fp8_e5_t>(const char *msg,
                                                 fp8_e5_t var) {
  printf(
      "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e5_t "
      "value=%f\n",
      msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
      threadIdx.z, (float)var);
}

// Primary template for device-side debug printing of one buffer element.
// Arguments: `msg` is a caller-supplied label, `buf_name` names the buffer,
// `index` is the element position and `var` is the element value.
// Declared without a body in this header: only the explicit specializations
// below are defined.
template <typename T>
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
                                         int index, T var);

// Specialization for signed char type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value `var` (%d; promoted to int through varargs).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void
debug_print_buffer_value<signed char>(const char *msg, const char *buf_name,
                                      int index, signed char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=signed char value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for unsigned char type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value `var` (%d; promoted to int through varargs).
// The dtype label reads "unsigned char", matching debug_print_var, rather
// than the previous misleading "char".
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void
debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,
                                        int index, unsigned char var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=unsigned char value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for int type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value `var` (%d).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_buffer_value<int>(const char *msg,
                                                     const char *buf_name,
                                                     int index, int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=int value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for unsigned int type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value `var` (%u). The dtype label reads "unsigned int" so
// the output is not confused with the signed int specialization.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void
debug_print_buffer_value<unsigned int>(const char *msg, const char *buf_name,
                                       int index, unsigned int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=unsigned int value=%u\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for float type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value `var` (%f).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_buffer_value<float>(const char *msg,
                                                       const char *buf_name,
                                                       int index, float var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=float value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for half type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value converted to float with a (float) cast.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_buffer_value<half>(const char *msg,
                                                      const char *buf_name,
                                                      int index, half var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for half_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value converted to float with a (float) cast.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_buffer_value<half_t>(const char *msg,
                                                        const char *buf_name,
                                                        int index, half_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for bfloat16_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value converted to float with a (float) cast.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void
debug_print_buffer_value<bfloat16_t>(const char *msg, const char *buf_name,
                                     int index, bfloat16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=bfloat16_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for double type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value `var` (%lf; for printf `l` has no effect on %f).
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_buffer_value<double>(const char *msg,
                                                        const char *buf_name,
                                                        int index, double var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=double value=%lf\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for fp8_e4_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value converted to float with a (float) cast.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void
debug_print_buffer_value<fp8_e4_t>(const char *msg, const char *buf_name,
                                   int index, fp8_e4_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=fp8_e4_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for fp8_e5_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value converted to float with a (float) cast.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void
debug_print_buffer_value<fp8_e5_t>(const char *msg, const char *buf_name,
                                   int index, fp8_e5_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=fp8_e5_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for int16_t type.
// Prints `msg`, the calling thread's blockIdx/threadIdx, `buf_name`, `index`
// and the element value widened to int32_t for %d.
// `inline`: explicit specializations are not implicitly inline, and this one
// is defined in a header.
template <>
inline __device__ void debug_print_buffer_value<int16_t>(const char *msg,
                                                         const char *buf_name,
                                                         int index,
                                                         int16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=int16_t value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (int32_t)var);
}

// Device-side assertion without a message: forwards `cond` straight to the
// standard `assert` macro, so it follows that macro's semantics (including
// being compiled out when NDEBUG is defined).
TL_DEVICE void device_assert(bool cond) {
  assert(cond);
}

// Device-side assertion with a message: when `cond` is false, prints
// "Device assert failed: <msg>" and then triggers `assert(0)`.
// NOTE(review): if NDEBUG is defined, assert(0) compiles away and execution
// continues after the message; confirm whether __trap() is intended instead.
TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
  if (cond)
    return;  // fast path: nothing to report
  printf("Device assert failed: %s\n", msg);
  assert(0);
}