Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
5475f8e7
"...bert-large_oneflow.git" did not exist on "5988d2cc317ac8cb8e21f84ec17dbd59e805df6c"
Unverified
Commit
5475f8e7
authored
Oct 27, 2025
by
Yuqi Dong
Committed by
GitHub
Oct 27, 2025
Browse files
[Feature]:Add device assert (#1116)
* update * update
parent
17a63976
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
68 additions
and
2 deletions
+68
-2
src/tl_templates/cuda/debug.h
src/tl_templates/cuda/debug.h
+9
-0
testing/python/debug/test_device_assert.py
testing/python/debug/test_device_assert.py
+36
-0
tilelang/language/__init__.py
tilelang/language/__init__.py
+1
-1
tilelang/language/print.py
tilelang/language/print.py
+22
-1
No files found.
src/tl_templates/cuda/debug.h
View file @
5475f8e7
...
...
@@ -257,3 +257,12 @@ __device__ void debug_print_buffer_value<int16_t>(const char *msg,
msg
,
blockIdx
.
x
,
blockIdx
.
y
,
blockIdx
.
z
,
threadIdx
.
x
,
threadIdx
.
y
,
threadIdx
.
z
,
buf_name
,
index
,
(
int32_t
)
var
);
}
TL_DEVICE
void
device_assert
(
bool
cond
)
{
assert
(
cond
);
}
TL_DEVICE
void
device_assert_with_msg
(
bool
cond
,
const
char
*
msg
)
{
if
(
!
cond
)
{
printf
(
"Device assert failed: %s
\n
"
,
msg
);
assert
(
0
);
}
}
testing/python/debug/test_device_assert.py
0 → 100644
View file @
5475f8e7
# type: ignore
import
tilelang
import
tilelang.testing
import
tilelang.language
as
T
# TODO(dyq) It intentionally triggers a device-side assert so we can't include this in CI
# Please run manually when you want to verify that device_assert actually traps on GPU.
def
_manual_device_assert_triggered
():
@
T
.
prim_func
def
program
():
with
T
.
Kernel
(
threads
=
128
):
tid
=
T
.
get_thread_binding
()
T
.
device_assert
(
tid
>
0
,
"Assertion Trigger !"
)
jit_kernel
=
tilelang
.
compile
(
program
,
target
=
"cuda"
)
profiler
=
jit_kernel
.
get_profiler
()
profiler
.
run_once
()
def
test_device_assert_no_trigger
():
@
T
.
prim_func
def
program
():
with
T
.
Kernel
(
threads
=
128
):
tid
=
T
.
get_thread_binding
()
T
.
device_assert
(
tid
==
tid
)
jit_kernel
=
tilelang
.
compile
(
program
,
target
=
"cuda"
)
profiler
=
jit_kernel
.
get_profiler
()
profiler
.
run_once
()
if
__name__
==
"__main__"
:
_manual_device_assert_triggered
()
tilelang/language/__init__.py
View file @
5475f8e7
...
...
@@ -64,7 +64,7 @@ from .reduce import (
cumsum
,
# noqa: F401
finalize_reducer
,
# noqa: F401
)
from
.print
import
print
# noqa: F401
from
.print
import
print
,
device_assert
# noqa: F401
from
.customize
import
(
atomic_max
,
# noqa: F401
atomic_min
,
# noqa: F401
...
...
tilelang/language/print.py
View file @
5475f8e7
"""
This module provides macros and utilities for debugging TileLang (tl) programs.
It includes functionality to print variables, print values in buffers,
and
conditionally execute debug prints.
It includes functionality to print variables, print values in buffers, conditionally execute debug prints
and assert
.
"""
from
tvm
import
tir
...
...
@@ -133,6 +133,27 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr,
buffer
[
coords
])
from
tilelang.utils.target
import
check_cuda_availability
import
warnings
_IS_CUDA_AVAILABLE
=
check_cuda_availability
()
@
macro
def
device_assert
(
condition
:
tir
.
PrimExpr
,
msg
:
str
=
""
):
"""
Device-side assert emulation.
Emits a device-side assert call on CUDA targets when CUDA is available.
The assert is always enabled and cannot be disabled at runtime.
"""
if
_IS_CUDA_AVAILABLE
:
if
msg
==
""
:
tir
.
call_extern
(
"void"
,
"device_assert"
,
condition
)
else
:
warnings
.
warn
(
"Non-empty msg may slightly slow down the kernel"
,
stacklevel
=
2
)
tir
.
call_extern
(
"void"
,
"device_assert_with_msg"
,
condition
,
msg
)
def
print
(
obj
:
Any
,
msg
:
str
=
""
,
warp_group_id
:
int
=
0
,
warp_id
:
int
=
0
)
->
tir
.
PrimExpr
:
"""
A generic print function that handles both TIR buffers and primitive expressions.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment