Commit bb60f6ce authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Support debug print for unsigned char datatype (#145)

* Fix debug print buffer template for unsigned char type

- Update debug_print_buffer_value template specialization for unsigned char
- Modify test_tilelang_debug_print.py to include additional dtype tests
- Add test case for uint8 dtype in debug print buffer function

* Refactor debug print buffer template formatting for unsigned char

- Improve code formatting for debug_print_buffer_value template specialization
- Adjust line breaks and indentation for better readability
- Maintain consistent code style with other template specializations
parent 37d44f24
...@@ -101,9 +101,9 @@ debug_print_buffer_value<signed char>(const char *msg, const char *buf_name, ...@@ -101,9 +101,9 @@ debug_print_buffer_value<signed char>(const char *msg, const char *buf_name,
// Specialization for unsiged char type // Specialization for unsiged char type
template <> template <>
__device__ void debug_print_buffer_value<char>(const char *msg, __device__ void
const char *buf_name, int index, debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,
char var) { int index, unsigned char var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, " printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=char value=%d\n", "index=%d, dtype=char value=%d\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
......
...@@ -5,8 +5,7 @@ import tilelang.testing ...@@ -5,8 +5,7 @@ import tilelang.testing
import tilelang.language as T import tilelang.language as T
def debug_print_buffer(M=16, N=16): def debug_print_buffer(M=16, N=16, dtype="float16"):
dtype = "float16"
@T.prim_func @T.prim_func
def program(Q: T.Buffer((M, N), dtype)): def program(Q: T.Buffer((M, N), dtype)):
...@@ -20,7 +19,9 @@ def debug_print_buffer(M=16, N=16): ...@@ -20,7 +19,9 @@ def debug_print_buffer(M=16, N=16):
def test_debug_print_buffer(): def test_debug_print_buffer():
debug_print_buffer(16, 16) debug_print_buffer(16, 16, dtype="float")
debug_print_buffer(16, 16, dtype="float16")
debug_print_buffer(16, 16, dtype="uint8")
def debug_print_buffer_conditional(M=16, N=16): def debug_print_buffer_conditional(M=16, N=16):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment