Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
c30df2a1
Unverified
Commit
c30df2a1
authored
Nov 25, 2025
by
Wenhao Xie
Committed by
GitHub
Nov 25, 2025
Browse files
[Enhancement] Support more dtype in `T.print` (#1329)
* [Enhancement] Support more dtype in `T.print` * upd * upd
parent
caa6dd3f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
267 deletions
+107
-267
src/tl_templates/cuda/debug.h
src/tl_templates/cuda/debug.h
+89
-264
testing/python/debug/test_tilelang_debug_print.py
testing/python/debug/test_tilelang_debug_print.py
+18
-3
No files found.
src/tl_templates/cuda/debug.h
View file @
c30df2a1
...
@@ -5,282 +5,107 @@
...
@@ -5,282 +5,107 @@
#endif
#endif
#include "common.h"
#include "common.h"
#ifndef __CUDACC_RTC__
#ifndef __CUDACC_RTC__
#include <cstdint>
#include <cstdio>
#include <cstdio>
#endif
#endif
// Template declaration for device-side debug printing (variable only)
template
<
typename
T
>
struct
PrintTraits
{
template
<
typename
T
>
__device__
void
debug_print_var
(
const
char
*
msg
,
T
var
);
static
__device__
void
print_var
(
const
char
*
msg
,
T
val
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
// Overload for pointer type (supports any cv-qualified T*)
"dtype=unknown value=%p
\n
"
,
template
<
typename
T
>
__device__
void
debug_print_var
(
const
char
*
msg
,
T
*
var
)
{
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
printf
(
threadIdx
.
z
,
(
const
void
*
)
&
val
);
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=pointer "
}
"value=%p
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
var
);
}
// Specialization for signed char type
template
<
>
__device__
void
debug_print_var
<
signed
char
>
(
const
char
*
msg
,
signed
char
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed "
"char "
"value=%d
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
var
);
}
// Specialization for plain char type
template
<
>
__device__
void
debug_print_var
<
char
>
(
const
char
*
msg
,
char
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=char "
"value=%d
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
(
int
)
var
);
}
// Specialization for unsigned char type
template
<
>
__device__
void
debug_print_var
<
unsigned
char
>
(
const
char
*
msg
,
unsigned
char
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"dtype=unsigned char "
"value=%d
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
var
);
}
// Specialization for integer type
template
<
>
__device__
void
debug_print_var
<
int
>
(
const
char
*
msg
,
int
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
"value=%d
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
var
);
}
// Specialization for unsigned integer type
template
<
>
__device__
void
debug_print_var
<
unsigned
int
>
(
const
char
*
msg
,
unsigned
int
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
"value=%u
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
var
);
}
// Specialization for bool type
template
<
>
__device__
void
debug_print_var
<
bool
>
(
const
char
*
msg
,
bool
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
"value=%s
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
var
?
"true"
:
"false"
);
}
// Specialization for float type
template
<
>
__device__
void
debug_print_var
<
float
>
(
const
char
*
msg
,
float
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
"value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
var
);
}
// Specialization for half type
template
<
>
__device__
void
debug_print_var
<
half
>
(
const
char
*
msg
,
half
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half "
"value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
(
float
)
var
);
}
// Specialization for half_t type
template
<
>
__device__
void
debug_print_var
<
half_t
>
(
const
char
*
msg
,
half_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t "
"value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
(
float
)
var
);
}
// Specialization for bfloat16_t type
static
__device__
void
print_buffer
(
const
char
*
msg
,
const
char
*
buf_name
,
template
<
>
int
index
,
T
val
)
{
__device__
void
debug_print_var
<
bfloat16_t
>
(
const
char
*
msg
,
bfloat16_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"index=%d, dtype=unknown value=%p
\n
"
,
"dtype=bfloat16_t value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
(
const
void
*
)
&
val
);
threadIdx
.
z
,
(
float
)
var
);
}
}
};
#define DEFINE_PRINT_TRAIT(TYPE, NAME, FORMAT, CAST_TYPE) \
template <> struct PrintTraits<TYPE> { \
static __device__ void print_var(const char *msg, TYPE val) { \
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " \
"dtype=" NAME " value=" FORMAT "\n", \
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \
threadIdx.y, threadIdx.z, (CAST_TYPE)val); \
} \
static __device__ void print_buffer(const char *msg, const char *buf_name, \
int index, TYPE val) { \
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " \
"buffer=%s, index=%d, dtype=" NAME " value=" FORMAT "\n", \
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \
threadIdx.y, threadIdx.z, buf_name, index, (CAST_TYPE)val); \
} \
}
// Specialization for double type
DEFINE_PRINT_TRAIT
(
char
,
"char"
,
"%d"
,
int
);
template
<
>
DEFINE_PRINT_TRAIT
(
signed
char
,
"signed char"
,
"%d"
,
int
);
__device__
void
debug_print_var
<
double
>
(
const
char
*
msg
,
double
var
)
{
DEFINE_PRINT_TRAIT
(
unsigned
char
,
"unsigned char"
,
"%u"
,
unsigned
int
);
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double "
DEFINE_PRINT_TRAIT
(
short
,
"short"
,
"%d"
,
int
);
"value=%lf
\n
"
,
DEFINE_PRINT_TRAIT
(
unsigned
short
,
"unsigned short"
,
"%u"
,
unsigned
int
);
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
DEFINE_PRINT_TRAIT
(
int
,
"int"
,
"%d"
,
int
);
threadIdx
.
z
,
var
);
DEFINE_PRINT_TRAIT
(
unsigned
int
,
"uint"
,
"%u"
,
unsigned
int
);
}
DEFINE_PRINT_TRAIT
(
long
,
"long"
,
"%ld"
,
long
);
DEFINE_PRINT_TRAIT
(
unsigned
long
,
"ulong"
,
"%lu"
,
unsigned
long
);
DEFINE_PRINT_TRAIT
(
long
long
,
"long long"
,
"%lld"
,
long
long
);
DEFINE_PRINT_TRAIT
(
float
,
"float"
,
"%f"
,
float
);
DEFINE_PRINT_TRAIT
(
double
,
"double"
,
"%lf"
,
double
);
DEFINE_PRINT_TRAIT
(
half
,
"half"
,
"%f"
,
float
);
DEFINE_PRINT_TRAIT
(
half_t
,
"half_t"
,
"%f"
,
float
);
DEFINE_PRINT_TRAIT
(
bfloat16_t
,
"bfloat16_t"
,
"%f"
,
float
);
#if __CUDA_ARCH_LIST__ >= 890
#if __CUDA_ARCH_LIST__ >= 890
// Specialization for fp8_e4_t type
DEFINE_PRINT_TRAIT
(
fp8_e4_t
,
"fp8_e4_t"
,
"%f"
,
float
);
template
<
>
DEFINE_PRINT_TRAIT
(
fp8_e5_t
,
"fp8_e5_t"
,
"%f"
,
float
);
__device__
void
debug_print_var
<
fp8_e4_t
>
(
const
char
*
msg
,
fp8_e4_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e4_t "
"value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
(
float
)
var
);
}
// Specialization for fp8_e5_t type
template
<
>
__device__
void
debug_print_var
<
fp8_e5_t
>
(
const
char
*
msg
,
fp8_e5_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e5_t "
"value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
(
float
)
var
);
}
#endif
#endif
// Template declaration for device-side debug printing (buffer only)
template
<
>
struct
PrintTraits
<
bool
>
{
template
<
typename
T
>
static
__device__
void
print_var
(
const
char
*
msg
,
bool
val
)
{
__device__
void
debug_print_buffer_value
(
const
char
*
msg
,
const
char
*
buf_name
,
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
int
index
,
T
var
);
"value=%s
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
// Specialization for signed char type
threadIdx
.
z
,
val
?
"true"
:
"false"
);
template
<
>
}
__device__
void
static
__device__
void
print_buffer
(
const
char
*
msg
,
const
char
*
buf_name
,
debug_print_buffer_value
<
signed
char
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
bool
val
)
{
int
index
,
signed
char
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=bool value=%s
\n
"
,
"index=%d, dtype=signed char value=%d
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
val
?
"true"
:
"false"
);
threadIdx
.
z
,
buf_name
,
index
,
var
);
}
}
};
// Specialization for unsigned char type
template
<
typename
T
>
struct
PrintTraits
<
T
*>
{
template
<
>
static
__device__
void
print_var
(
const
char
*
msg
,
T
*
val
)
{
__device__
void
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
debug_print_buffer_value
<
unsigned
char
>
(
const
char
*
msg
,
const
char
*
buf_name
,
"dtype=pointer value=%p
\n
"
,
int
index
,
unsigned
char
var
)
{
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
threadIdx
.
z
,
(
void
*
)
val
);
"index=%d, dtype=char value=%d
\n
"
,
}
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
static
__device__
void
print_buffer
(
const
char
*
msg
,
const
char
*
buf_name
,
threadIdx
.
z
,
buf_name
,
index
,
var
);
int
index
,
T
*
val
)
{
}
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=pointer value=%p
\n
"
,
// Specialization for integer type
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
template
<
>
threadIdx
.
z
,
buf_name
,
index
,
(
void
*
)
val
);
__device__
void
debug_print_buffer_value
<
int
>
(
const
char
*
msg
,
}
const
char
*
buf_name
,
int
index
,
};
int
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int value=%d
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
var
);
}
// Specialization for unsigned integer type
template
<
>
__device__
void
debug_print_buffer_value
<
unsigned
int
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
unsigned
int
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int value=%u
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
var
);
}
// Specialization for float type
template
<
>
__device__
void
debug_print_buffer_value
<
float
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
float
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=float value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
var
);
}
// Specialization for half type
template
<
>
__device__
void
debug_print_buffer_value
<
half
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
half
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=half value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
(
float
)
var
);
}
// Specialization for half_t type
template
<
>
__device__
void
debug_print_buffer_value
<
half_t
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
half_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=half_t value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
(
float
)
var
);
}
// Specialization for bfloat16_t type
template
<
>
__device__
void
debug_print_buffer_value
<
bfloat16_t
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
bfloat16_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=bfloat16_t value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
(
float
)
var
);
}
// Specialization for double type
template
<
>
__device__
void
debug_print_buffer_value
<
double
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
double
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=double value=%lf
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
var
);
}
// Specialization for fp8_e4_t type
#if __CUDA_ARCH_LIST__ >= 890
template
<
>
__device__
void
debug_print_buffer_value
<
fp8_e4_t
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
fp8_e4_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=fp8_e4_t value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
(
float
)
var
);
}
// Specialization for fp8_e5_t type
template
<
typename
T
>
__device__
void
debug_print_var
(
const
char
*
msg
,
T
var
)
{
template
<
>
PrintTraits
<
T
>::
print_var
(
msg
,
var
);
__device__
void
debug_print_buffer_value
<
fp8_e5_t
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
fp8_e5_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=fp8_e5_t value=%f
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
(
float
)
var
);
}
}
#endif
template
<
typename
T
>
__device__
void
debug_print_buffer_value
(
const
char
*
msg
,
const
char
*
buf_name
,
// Specialization for int16 type
int
index
,
T
var
)
{
template
<
>
PrintTraits
<
T
>::
print_buffer
(
msg
,
buf_name
,
index
,
var
);
__device__
void
debug_print_buffer_value
<
int16_t
>
(
const
char
*
msg
,
const
char
*
buf_name
,
int
index
,
int16_t
var
)
{
printf
(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int16_t value=%d
\n
"
,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
(
int32_t
)
var
);
}
}
TL_DEVICE
void
device_assert
(
bool
cond
)
{
assert
(
cond
);
}
TL_DEVICE
void
device_assert
(
bool
cond
)
{
assert
(
cond
);
}
...
@@ -290,4 +115,4 @@ TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
...
@@ -290,4 +115,4 @@ TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
printf
(
"Device assert failed: %s
\n
"
,
msg
);
printf
(
"Device assert failed: %s
\n
"
,
msg
);
assert
(
0
);
assert
(
0
);
}
}
}
}
\ No newline at end of file
testing/python/debug/test_tilelang_debug_print.py
View file @
c30df2a1
...
@@ -19,9 +19,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
...
@@ -19,9 +19,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
def
test_debug_print_buffer
():
def
test_debug_print_buffer
():
debug_print_buffer
(
16
,
16
,
dtype
=
"float"
)
debug_print_buffer
(
dtype
=
'bool'
)
debug_print_buffer
(
16
,
16
,
dtype
=
"float16"
)
debug_print_buffer
(
dtype
=
'int8'
)
debug_print_buffer
(
16
,
16
,
dtype
=
"uint8"
)
debug_print_buffer
(
dtype
=
'int16'
)
debug_print_buffer
(
dtype
=
'int32'
)
debug_print_buffer
(
dtype
=
'int64'
)
debug_print_buffer
(
dtype
=
'uint8'
)
debug_print_buffer
(
dtype
=
'uint16'
)
debug_print_buffer
(
dtype
=
'uint32'
)
debug_print_buffer
(
dtype
=
'uint64'
)
debug_print_buffer
(
dtype
=
'float16'
)
debug_print_buffer
(
dtype
=
'float32'
)
debug_print_buffer
(
dtype
=
'float64'
)
debug_print_buffer
(
dtype
=
'bfloat16'
)
debug_print_buffer
(
dtype
=
'float8_e4m3'
)
debug_print_buffer
(
dtype
=
'float8_e4m3fn'
)
debug_print_buffer
(
dtype
=
'float8_e4m3fnuz'
)
debug_print_buffer
(
dtype
=
'float8_e5m2'
)
debug_print_buffer
(
dtype
=
'float8_e5m2fnuz'
)
def
debug_print_buffer_conditional
(
M
=
16
,
N
=
16
):
def
debug_print_buffer_conditional
(
M
=
16
,
N
=
16
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment