Commit 8cdc185b authored by Lei Wang, committed by GitHub
Browse files

[Debug] Introduce `T.print` for buffer and variables logging on frontend (#45)

* [Doc] Update documentation structure and content: add overview section, revise project name, and change theme to Furo

* [Feature] Add device-side debug printing functions and integrate into kernel interface

* lint fix

* remove debug print

* implement test for debug

* lint fix

* add some comments

* Enhance fragment design and assert fragment print

* enhance debug print

* add test for msg

* lint fix
parent 22246c65
......@@ -76,3 +76,6 @@ models/frozenmodels/
# build sdist
build_sdist/
# do not ignore the debug testing folder
!testing/python/debug
......@@ -83,6 +83,7 @@ std::string CodeGenTileLangCUDA::Finish() {
decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
decl_stream << "#include <tl_templates/cuda/debug.h>\n";
decl_stream << "\n";
return CodeGenC::Finish();
}
......
#pragma once
#include "common.h"
#include <stdio.h>
// Device-side debug printing of a single scalar value.
//
// Each calling thread emits one line through device `printf` containing the
// caller-supplied message, its blockIdx/threadIdx triple, a dtype tag and the
// value. Only the explicit specializations below are defined; the primary
// template is declaration-only.
//
// `msg` is taken as `const char *` so that string literals can be passed
// without relying on the deprecated literal-to-`char *` conversion. Callers
// that hold a mutable `char *` still convert implicitly.
template <typename T> __device__ void debug_print_var(const char *msg, T var);

// Specialization for integer type
template <> __device__ void debug_print_var<int>(const char *msg, int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=int "
         "value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for float type
template <> __device__ void debug_print_var<float>(const char *msg, float var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}

// Specialization for half type (value is widened to float for printing)
template <> __device__ void debug_print_var<half>(const char *msg, half var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for half_t type (value is widened to float for printing)
template <>
__device__ void debug_print_var<half_t>(const char *msg, half_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=half_t "
         "value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for bfloat16_t type (value is widened to float for printing)
template <>
__device__ void debug_print_var<bfloat16_t>(const char *msg, bfloat16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): "
         "dtype=bfloat16_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, (float)var);
}

// Specialization for double type
template <>
__device__ void debug_print_var<double>(const char *msg, double var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=double "
         "value=%lf\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, var);
}
#pragma once
#include "common.h"
#include <stdio.h>
// Device-side debug printing of one buffer element.
//
// Each calling thread emits one line through device `printf` containing the
// caller-supplied message, its blockIdx/threadIdx triple, the buffer name,
// the element index, a dtype tag and the element value. Only the explicit
// specializations below are defined; the primary template is
// declaration-only.
//
// `msg` and `buf_name` are taken as `const char *` so that string literals
// can be passed without relying on the deprecated literal-to-`char *`
// conversion. Callers that hold a mutable `char *` still convert implicitly.
template <typename T>
__device__ void debug_print_buffer_value(const char *msg, const char *buf_name,
                                         int index, T var);

// Specialization for integer type
template <>
__device__ void debug_print_buffer_value<int>(const char *msg,
                                              const char *buf_name, int index,
                                              int var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=int value=%d\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for float type
template <>
__device__ void debug_print_buffer_value<float>(const char *msg,
                                                const char *buf_name, int index,
                                                float var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=float value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}

// Specialization for half type (value is widened to float for printing)
template <>
__device__ void debug_print_buffer_value<half>(const char *msg,
                                               const char *buf_name, int index,
                                               half var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for half_t type (value is widened to float for printing)
template <>
__device__ void debug_print_buffer_value<half_t>(const char *msg,
                                                 const char *buf_name,
                                                 int index, half_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=half_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for bfloat16_t type (value is widened to float for printing)
template <>
__device__ void debug_print_buffer_value<bfloat16_t>(const char *msg,
                                                     const char *buf_name,
                                                     int index,
                                                     bfloat16_t var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=bfloat16_t value=%f\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, (float)var);
}

// Specialization for double type
template <>
__device__ void debug_print_buffer_value<double>(const char *msg,
                                                 const char *buf_name,
                                                 int index, double var) {
  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
         "index=%d, dtype=double value=%lf\n",
         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
         threadIdx.z, buf_name, index, var);
}
# type: ignore
import tilelang
import tilelang.testing
import tilelang.language as T
def debug_print_buffer(M=16, N=16):
    """Compile and launch once a kernel that calls `T.print` on a shared buffer.

    The kernel uses a 4 x 4 x 2 grid with 256 threads per block, allocates an
    (M, N) float16 shared buffer and prints it with no block-level guard.
    """
    dtype = "float16"

    # The prim_func body is kept verbatim: the DSL front end presumably derives
    # kernel/buffer names from these identifiers -- verify before renaming.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            T.print(shared_buf)

    kernel = tilelang.JITKernel(program, target="cuda")
    kernel.get_profiler().run_once()
def test_debug_print_buffer():
    """Run the unguarded shared-buffer print kernel on a 16 x 16 buffer."""
    debug_print_buffer(M=16, N=16)
def debug_print_buffer_conditional(M=16, N=16):
    """Compile and launch once a kernel that prints a shared buffer from block (0, 0, 0) only.

    Same launch shape as the unguarded variant (4 x 4 x 2 grid, 256 threads),
    but the `T.print` call sits behind a `bx == 0 and by == 0 and bz == 0` check.
    """
    dtype = "float16"

    # The prim_func body is kept verbatim: the DSL front end presumably derives
    # kernel/buffer names from these identifiers -- verify before renaming.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_shared([M, N], dtype)
            if bx == 0 and by == 0 and bz == 0:
                T.print(shared_buf)

    kernel = tilelang.JITKernel(program, target="cuda")
    kernel.get_profiler().run_once()
def test_debug_print_buffer_conditional():
    """Run the block-guarded shared-buffer print kernel on a 16 x 16 buffer."""
    debug_print_buffer_conditional(M=16, N=16)
def debug_print_value_conditional(M=16, N=16):
    """Compile and launch once a kernel that prints a scalar expression from one thread per block.

    Inside a 4 x 4 x 2 grid with 256 threads, the thread whose binding
    (`T.get_thread_binding()`) equals 0 prints `bx + by + bz`.
    """
    dtype = "float16"

    # The prim_func body is kept verbatim because the DSL front end consumes it.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz)

    kernel = tilelang.JITKernel(program, target="cuda")
    kernel.get_profiler().run_once()
def test_debug_print_value_conditional():
    """Run the thread-guarded scalar print kernel with M = N = 16."""
    debug_print_value_conditional(M=16, N=16)
def debug_print_register_files(M=16, N=16):
    """Compile and launch once a kernel that prints every element of a fragment buffer.

    The buffer comes from `T.alloc_fragment` (despite its `shared_buf` name) and
    is printed element by element inside a `T.Parallel(M, N)` loop, rather than
    by passing the whole buffer to `T.print`.
    """
    dtype = "float16"

    # The prim_func body is kept verbatim: the DSL front end presumably derives
    # kernel/buffer names from these identifiers -- verify before renaming.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            shared_buf = T.alloc_fragment([M, N], dtype)
            for i, j in T.Parallel(M, N):
                T.print(shared_buf[i, j])

    kernel = tilelang.JITKernel(program, target="cuda")
    kernel.get_profiler().run_once()
def test_debug_print_register_files():
    """Run the per-element fragment print kernel on a 16 x 16 fragment."""
    debug_print_register_files(M=16, N=16)
def debug_print_msg(M=16, N=16):
    """Compile and launch once a kernel that prints a scalar with a custom message.

    Identical to the thread-guarded scalar print, except that `T.print` is
    given `msg="hello world"`.
    """
    dtype = "float16"

    # The prim_func body is kept verbatim because the DSL front end consumes it.
    @T.prim_func
    def program(Q: T.Buffer((M, N), dtype)):
        with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(bx + by + bz, msg="hello world")

    kernel = tilelang.JITKernel(program, target="cuda")
    kernel.get_profiler().run_once()
def test_debug_print_msg():
    """Run the custom-message scalar print kernel with M = N = 16."""
    debug_print_msg(M=16, N=16)
# When executed as a script, hand control to tilelang.testing.main().
if __name__ == "__main__":
    tilelang.testing.main()
......@@ -8,7 +8,7 @@ from .parser import *
from tilelang.layout import Layout, Fragment # noqa: F401
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .kernel import Kernel, KernelLaunchFrame # noqa: F401
from .kernel import Kernel, KernelLaunchFrame, get_thread_binding # noqa: F401
from .allocate import (
alloc_local, # noqa: F401
alloc_shared, # noqa: F401
......@@ -24,6 +24,7 @@ from .reduce import (
reduce_sum, # noqa: F401
reduce_abssum, # noqa: F401
)
from .print import print # noqa: F401
from .customize import (
atomic_add, # noqa: F401
atomic_addx2, # noqa: F401
......
......@@ -132,6 +132,13 @@ class KernelLaunchFrame(TIRFrame):
"""
return self.frames[-4 + dim].iter_var.var
def get_thread_bindings(self) -> List[Var]:
    """
    Returns the thread binding variables for all three dimensions.

    The variables are read from the frames at positions -4, -3 and -2 of
    `self.frames`, i.e. in the order threadIdx.x, threadIdx.y, threadIdx.z.
    """
    bindings = []
    for frame in self.frames[-4:-1]:
        bindings.append(frame.iter_var.var)
    return bindings
def get_num_threads(self) -> int:
"""
Returns the thread indices from the topmost frame.
......@@ -213,3 +220,15 @@ def Kernel(
attrs["pragma_import_c"] = prelude
return _ffi_api.KernelLaunch(blocks, threads, attrs)
def get_thread_binding(dim: int = 0) -> Var:
    """Returns the thread binding of the current kernel launch frame for `dim`.

    Delegates to `KernelLaunchFrame.Current().get_thread_binding(dim)`.
    """
    current_frame = KernelLaunchFrame.Current()
    return current_frame.get_thread_binding(dim)
def get_thread_bindings() -> List[Var]:
    """Returns all three thread bindings of the current kernel launch frame.

    Delegates to `KernelLaunchFrame.Current().get_thread_bindings()`.
    """
    current_frame = KernelLaunchFrame.Current()
    return current_frame.get_thread_bindings()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This module provides macros and utilities for debugging TileLang (tl) programs.
It includes functionality to print variables, print values in buffers, and conditionally execute debug prints.
"""
from tvm import tir
from typing import Any
from tilelang.language.kernel import get_thread_bindings
from tilelang.language import macro, serial
@macro
def print_var(var: tir.PrimExpr, msg: str = "") -> tir.PrimExpr:
    """
    Prints the value of a TIR primitive expression (PrimExpr) for debugging purposes.

    Parameters:
        var (tir.PrimExpr): The variable or expression to be printed.
        msg (str): Message passed as the first argument of the `debug_print_var` extern call.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.
        NOTE(review): the body has no explicit `return`; what the `@macro`
        wrapper hands back is not visible here -- confirm.
    """
    # Emit an extern call `debug_print_var(msg, var)`.
    tir.call_extern("handle", "debug_print_var", msg, var)
@macro
def print_var_with_condition(condition: tir.PrimExpr,
                             var: tir.PrimExpr,
                             msg: str = "") -> tir.PrimExpr:
    """
    Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.

    Parameters:
        condition (tir.PrimExpr): A TIR expression representing the condition to check.
        var (tir.PrimExpr): The variable or expression to be printed.
        msg (str): Message passed as the first argument of the `debug_print_var` extern call.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True.
        NOTE(review): the body has no explicit `return`; what the `@macro`
        wrapper hands back is not visible here -- confirm.
    """
    # The extern call `debug_print_var(msg, var)` is placed under the condition.
    if condition:
        tir.call_extern("handle", "debug_print_var", msg, var)
@macro
def print_flat_buffer_with_condition(condition: tir.PrimExpr,
                                     buffer: tir.Buffer,
                                     elems: int,
                                     msg: str = "") -> tir.PrimExpr:
    """
    Conditionally prints the values of a flattened TIR buffer if the condition is True.

    Parameters:
        condition (tir.PrimExpr): A TIR expression representing the condition to check.
        buffer (tir.Buffer): The buffer whose values need to be printed. It is indexed
            with a single index, so it is expected to be one-dimensional.
        elems (int): The number of elements in the buffer to print.
        msg (str): Message passed as the first argument of each
            `debug_print_buffer_value` extern call.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.
        NOTE(review): the body has no explicit `return`; what the `@macro`
        wrapper hands back is not visible here -- confirm.
    """
    if condition:
        # Iterate through the buffer elements and print each one.
        # Each iteration emits `debug_print_buffer_value(msg, buffer.name, i, buffer[i])`.
        for i in serial(elems):
            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[i])
def print(obj: Any, msg: str = "") -> tir.PrimExpr:
    """
    A generic print function that handles both TIR buffers and primitive expressions.

    - If the input is a TIR buffer, it prints its values, but only on the first thread (tx=0, ty=0, tz=0).
    - If the input is a TIR primitive expression, it prints its value directly.

    Parameters:
        obj (Any): The object to print. It can be either a tir.Buffer or tir.PrimExpr.
        msg (str): An optional message to include in the print statement. When empty,
            a default of the form ``buffer<name, dtype>`` or ``expr<...>`` is used.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.

    Raises:
        NotImplementedError: If the buffer lives in the "local.fragment" scope.
        ValueError: If the input object type is unsupported.
    """
    if isinstance(obj, tir.Buffer):
        # Buffers must be printed in just one thread to avoid duplicate outputs.
        # Retrieve the thread bindings for thread x, y, and z.
        tx, ty, tz = get_thread_bindings()

        # Flatten the buffer for consistent printing. This assumes a 1D flattened buffer.
        buffer = obj.get_flattened_buffer()
        if buffer.scope() == "local.fragment":
            raise NotImplementedError("Printing fragment buffers currently is not supported.")

        assert len(buffer.shape) == 1, "Buffer must be flattened into a 1D shape."

        # Get the number of elements in the buffer.
        elems = buffer.shape[-1]

        # Ensure only the first thread (tx=0, ty=0, tz=0) executes the print.
        # Python's `and` cannot be used here: it short-circuits on the truth
        # value of the first comparison object and returns a single operand,
        # which drops the remaining checks. `tir.all` builds a real logical
        # AND over all three comparisons.
        condition = tir.all(tx == 0, ty == 0, tz == 0)
        if not msg:
            msg = f"buffer<{buffer.name}, {buffer.dtype}>"
        return print_flat_buffer_with_condition(condition, buffer, elems, msg)

    elif isinstance(obj, tir.PrimExpr):
        if not msg:
            msg = f"expr<{obj}>"
        # Directly print primitive expressions.
        return print_var(obj, msg)

    else:
        # Unsupported object type.
        raise ValueError(
            f"Unexpected type: {type(obj)}. Supported types are tir.Buffer and tir.PrimExpr.")
......@@ -13,17 +13,31 @@ from tilelang.layout import Layout
@tvm._ffi.register_object("tl.Fragment")
class Fragment(Layout):
# pylint: disable=super-init-not-called
def __init__(self, shape, forward_thread_fn, replicate=1, forward_index_fn=None):
def __init__(self,
shape,
forward_fn=None,
forward_thread_fn=None,
replicate=1,
forward_index_fn=None):
forward_vars = []
for idx, size in enumerate(shape):
iv = IterVar(Range(0, size), Var(f"i{idx}", "int32"), 0)
forward_vars.append(iv)
vars = [iv.var for iv in forward_vars]
forward_index = forward_index_fn(*vars) if forward_index_fn else None
if not isinstance(forward_index, tvm.ir.container.Array):
forward_index = [forward_index]
forward_thread: IterVar = None
forward_index: tvm.ir.container.Array = None
thread_replicate: IterVar = None
if forward_fn is not None:
if replicate > 1:
thread_replicate = IterVar(Range(0, replicate), Var("rep", "int32"), 0)
forward_thread, forward_index = forward_fn(*vars, thread_replicate)
else:
thread_replicate = None
forward_thread, forward_index = forward_fn(*vars)
else:
forward_index = forward_index_fn(*vars) if forward_index_fn else None
if replicate > 1:
thread_replicate = IterVar(Range(0, replicate), Var("rep", "int32"), 0)
forward_thread = forward_thread_fn(*vars, thread_replicate.var)
......@@ -31,6 +45,9 @@ class Fragment(Layout):
thread_replicate = None
forward_thread = forward_thread_fn(*vars)
if not isinstance(forward_index, tvm.ir.container.Array):
forward_index = [forward_index]
self.__init_handle_by_constructor__(
_ffi_api.Fragment,
forward_vars,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment