Unverified Commit c30df2a1 authored by Wenhao Xie's avatar Wenhao Xie Committed by GitHub
Browse files

[Enhancement] Support more dtype in `T.print` (#1329)

* [Enhancement] Support more dtype in `T.print`

* upd

* upd
parent caa6dd3f
...@@ -5,282 +5,107 @@ ...@@ -5,282 +5,107 @@
#endif #endif
#include "common.h" #include "common.h"
#ifndef __CUDACC_RTC__ #ifndef __CUDACC_RTC__
#include <cstdint>
#include <cstdio> #include <cstdio>
#endif #endif
// Template declaration for device-side debug printing (variable only) template <typename T> struct PrintTraits {
template <typename T> __device__ void debug_print_var(const char *msg, T var); static __device__ void print_var(const char *msg, T val) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
// Overload for pointer type (supports any cv-qualified T*) "dtype=unknown value=%p\n",
template <typename T> __device__ void debug_print_var(const char *msg, T *var) { msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
printf( threadIdx.z, (const void *)&val);
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=pointer " }
"value=%p\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var);
}
// Specialization for signed char type
template <>
__device__ void debug_print_var<signed char>(const char *msg, signed char var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=signed "
"char "
"value=%d\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var);
}
// Specialization for plain char type
template <> __device__ void debug_print_var<char>(const char *msg, char var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=char "
"value=%d\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, (int)var);
}
// Specialization for unsigned char type
template <>
__device__ void debug_print_var<unsigned char>(const char *msg,
unsigned char var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
"dtype=unsigned char "
"value=%d\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var);
}
// Specialization for integer type
template <> __device__ void debug_print_var<int>(const char *msg, int var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
"value=%d\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var);
}
// Specialization for unsigned integer type
template <>
__device__ void debug_print_var<unsigned int>(const char *msg,
unsigned int var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
"value=%u\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var);
}
// Specialization for bool type
template <> __device__ void debug_print_var<bool>(const char *msg, bool var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
"value=%s\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var ? "true" : "false");
}
// Specialization for float type
template <> __device__ void debug_print_var<float>(const char *msg, float var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
"value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var);
}
// Specialization for half type
template <> __device__ void debug_print_var<half>(const char *msg, half var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half "
"value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, (float)var);
}
// Specialization for half_t type
template <>
__device__ void debug_print_var<half_t>(const char *msg, half_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t "
"value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, (float)var);
}
// Specialization for bfloat16_t type static __device__ void print_buffer(const char *msg, const char *buf_name,
template <> int index, T val) {
__device__ void debug_print_var<bfloat16_t>(const char *msg, bfloat16_t var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " "index=%d, dtype=unknown value=%p\n",
"dtype=bfloat16_t value=%f\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, buf_name, index, (const void *)&val);
threadIdx.z, (float)var); }
} };
#define DEFINE_PRINT_TRAIT(TYPE, NAME, FORMAT, CAST_TYPE) \
template <> struct PrintTraits<TYPE> { \
static __device__ void print_var(const char *msg, TYPE val) { \
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " \
"dtype=" NAME " value=" FORMAT "\n", \
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \
threadIdx.y, threadIdx.z, (CAST_TYPE)val); \
} \
static __device__ void print_buffer(const char *msg, const char *buf_name, \
int index, TYPE val) { \
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): " \
"buffer=%s, index=%d, dtype=" NAME " value=" FORMAT "\n", \
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \
threadIdx.y, threadIdx.z, buf_name, index, (CAST_TYPE)val); \
} \
}
// Specialization for double type DEFINE_PRINT_TRAIT(char, "char", "%d", int);
template <> DEFINE_PRINT_TRAIT(signed char, "signed char", "%d", int);
__device__ void debug_print_var<double>(const char *msg, double var) { DEFINE_PRINT_TRAIT(unsigned char, "unsigned char", "%u", unsigned int);
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double " DEFINE_PRINT_TRAIT(short, "short", "%d", int);
"value=%lf\n", DEFINE_PRINT_TRAIT(unsigned short, "unsigned short", "%u", unsigned int);
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, DEFINE_PRINT_TRAIT(int, "int", "%d", int);
threadIdx.z, var); DEFINE_PRINT_TRAIT(unsigned int, "uint", "%u", unsigned int);
} DEFINE_PRINT_TRAIT(long, "long", "%ld", long);
DEFINE_PRINT_TRAIT(unsigned long, "ulong", "%lu", unsigned long);
DEFINE_PRINT_TRAIT(long long, "long long", "%lld", long long);
DEFINE_PRINT_TRAIT(float, "float", "%f", float);
DEFINE_PRINT_TRAIT(double, "double", "%lf", double);
DEFINE_PRINT_TRAIT(half, "half", "%f", float);
DEFINE_PRINT_TRAIT(half_t, "half_t", "%f", float);
DEFINE_PRINT_TRAIT(bfloat16_t, "bfloat16_t", "%f", float);
#if __CUDA_ARCH_LIST__ >= 890 #if __CUDA_ARCH_LIST__ >= 890
// Specialization for fp8_e4_t type DEFINE_PRINT_TRAIT(fp8_e4_t, "fp8_e4_t", "%f", float);
template <> DEFINE_PRINT_TRAIT(fp8_e5_t, "fp8_e5_t", "%f", float);
__device__ void debug_print_var<fp8_e4_t>(const char *msg, fp8_e4_t var) {
printf(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e4_t "
"value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, (float)var);
}
// Specialization for fp8_e5_t type
template <>
__device__ void debug_print_var<fp8_e5_t>(const char *msg, fp8_e5_t var) {
printf(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e5_t "
"value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, (float)var);
}
#endif #endif
// Template declaration for device-side debug printing (buffer only) template <> struct PrintTraits<bool> {
template <typename T> static __device__ void print_var(const char *msg, bool val) {
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name, printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=bool "
int index, T var); "value=%s\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
// Specialization for signed char type threadIdx.z, val ? "true" : "false");
template <> }
__device__ void static __device__ void print_buffer(const char *msg, const char *buf_name,
debug_print_buffer_value<signed char>(const char *msg, const char *buf_name, int index, bool val) {
int index, signed char var) { printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " "index=%d, dtype=bool value=%s\n",
"index=%d, dtype=signed char value=%d\n", msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, buf_name, index, val ? "true" : "false");
threadIdx.z, buf_name, index, var); }
} };
// Specialization for unsigned char type template <typename T> struct PrintTraits<T *> {
template <> static __device__ void print_var(const char *msg, T *val) {
__device__ void printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name, "dtype=pointer value=%p\n",
int index, unsigned char var) { msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " threadIdx.z, (void *)val);
"index=%d, dtype=char value=%d\n", }
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, static __device__ void print_buffer(const char *msg, const char *buf_name,
threadIdx.z, buf_name, index, var); int index, T *val) {
} printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=pointer value=%p\n",
// Specialization for integer type msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
template <> threadIdx.z, buf_name, index, (void *)val);
__device__ void debug_print_buffer_value<int>(const char *msg, }
const char *buf_name, int index, };
int var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int value=%d\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, var);
}
// Specialization for unsigned integer type
template <>
__device__ void
debug_print_buffer_value<unsigned int>(const char *msg, const char *buf_name,
int index, unsigned int var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int value=%u\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, var);
}
// Specialization for float type
template <>
__device__ void debug_print_buffer_value<float>(const char *msg,
const char *buf_name, int index,
float var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=float value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, var);
}
// Specialization for half type
template <>
__device__ void debug_print_buffer_value<half>(const char *msg,
const char *buf_name, int index,
half var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=half value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (float)var);
}
// Specialization for half_t type
template <>
__device__ void debug_print_buffer_value<half_t>(const char *msg,
const char *buf_name,
int index, half_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=half_t value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (float)var);
}
// Specialization for bfloat16_t type
template <>
__device__ void
debug_print_buffer_value<bfloat16_t>(const char *msg, const char *buf_name,
int index, bfloat16_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=bfloat16_t value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (float)var);
}
// Specialization for double type
template <>
__device__ void debug_print_buffer_value<double>(const char *msg,
const char *buf_name,
int index, double var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=double value=%lf\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, var);
}
// Specialization for fp8_e4_t type
#if __CUDA_ARCH_LIST__ >= 890
template <>
__device__ void debug_print_buffer_value<fp8_e4_t>(const char *msg,
const char *buf_name,
int index, fp8_e4_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=fp8_e4_t value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (float)var);
}
// Specialization for fp8_e5_t type template <typename T> __device__ void debug_print_var(const char *msg, T var) {
template <> PrintTraits<T>::print_var(msg, var);
__device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg,
const char *buf_name,
int index, fp8_e5_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=fp8_e5_t value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (float)var);
} }
#endif template <typename T>
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
// Specialization for int16 type int index, T var) {
template <> PrintTraits<T>::print_buffer(msg, buf_name, index, var);
__device__ void debug_print_buffer_value<int16_t>(const char *msg,
const char *buf_name,
int index, int16_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int16_t value=%d\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (int32_t)var);
} }
TL_DEVICE void device_assert(bool cond) { assert(cond); } TL_DEVICE void device_assert(bool cond) { assert(cond); }
...@@ -290,4 +115,4 @@ TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) { ...@@ -290,4 +115,4 @@ TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
printf("Device assert failed: %s\n", msg); printf("Device assert failed: %s\n", msg);
assert(0); assert(0);
} }
} }
\ No newline at end of file
...@@ -19,9 +19,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"): ...@@ -19,9 +19,24 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
def test_debug_print_buffer(): def test_debug_print_buffer():
debug_print_buffer(16, 16, dtype="float") debug_print_buffer(dtype='bool')
debug_print_buffer(16, 16, dtype="float16") debug_print_buffer(dtype='int8')
debug_print_buffer(16, 16, dtype="uint8") debug_print_buffer(dtype='int16')
debug_print_buffer(dtype='int32')
debug_print_buffer(dtype='int64')
debug_print_buffer(dtype='uint8')
debug_print_buffer(dtype='uint16')
debug_print_buffer(dtype='uint32')
debug_print_buffer(dtype='uint64')
debug_print_buffer(dtype='float16')
debug_print_buffer(dtype='float32')
debug_print_buffer(dtype='float64')
debug_print_buffer(dtype='bfloat16')
debug_print_buffer(dtype='float8_e4m3')
debug_print_buffer(dtype='float8_e4m3fn')
debug_print_buffer(dtype='float8_e4m3fnuz')
debug_print_buffer(dtype='float8_e5m2')
debug_print_buffer(dtype='float8_e5m2fnuz')
def debug_print_buffer_conditional(M=16, N=16): def debug_print_buffer_conditional(M=16, N=16):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment