"...bert-large_oneflow.git" did not exist on "5988d2cc317ac8cb8e21f84ec17dbd59e805df6c"
Unverified Commit 5475f8e7 authored by Yuqi Dong's avatar Yuqi Dong Committed by GitHub
Browse files

[Feature]:Add device assert (#1116)

* update

* update
parent 17a63976
......@@ -257,3 +257,12 @@ __device__ void debug_print_buffer_value<int16_t>(const char *msg,
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (int32_t)var);
}
TL_DEVICE void device_assert(bool cond) { assert(cond); }
TL_DEVICE void device_assert_with_msg(bool cond, const char *msg) {
if (!cond) {
printf("Device assert failed: %s\n", msg);
assert(0);
}
}
# type: ignore
import tilelang
import tilelang.testing
import tilelang.language as T
# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI
# Please run manually when you want to verify that device_assert actually traps on GPU.
def _manual_device_assert_triggered():
@T.prim_func
def program():
with T.Kernel(threads=128):
tid = T.get_thread_binding()
T.device_assert(tid > 0, "Assertion Trigger !")
jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()
def test_device_assert_no_trigger():
@T.prim_func
def program():
with T.Kernel(threads=128):
tid = T.get_thread_binding()
T.device_assert(tid == tid)
jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()
if __name__ == "__main__":
_manual_device_assert_triggered()
......@@ -64,7 +64,7 @@ from .reduce import (
cumsum, # noqa: F401
finalize_reducer, # noqa: F401
)
from .print import print # noqa: F401
from .print import print, device_assert # noqa: F401
from .customize import (
atomic_max, # noqa: F401
atomic_min, # noqa: F401
......
"""
This module provides macros and utilities for debugging TileLang (tl) programs.
It includes functionality to print variables, print values in buffers, and conditionally execute debug prints.
It includes functionality to print variables, print values in buffers, conditionally execute debug prints and assert.
"""
from tvm import tir
......@@ -133,6 +133,27 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
buffer[coords])
from tilelang.utils.target import check_cuda_availability
import warnings
_IS_CUDA_AVAILABLE = check_cuda_availability()
@macro
def device_assert(condition: tir.PrimExpr, msg: str = ""):
"""
Device-side assert emulation.
Emits a device-side assert call on CUDA targets when CUDA is available.
The assert is always enabled and cannot be disabled at runtime.
"""
if _IS_CUDA_AVAILABLE:
if msg == "":
tir.call_extern("void", "device_assert", condition)
else:
warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
tir.call_extern("void", "device_assert_with_msg", condition, msg)
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
"""
A generic print function that handles both TIR buffers and primitive expressions.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment