Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
e2b10c58
Unverified
Commit
e2b10c58
authored
Nov 25, 2025
by
Chaofan Lin
Committed by
GitHub
Nov 25, 2025
Browse files
[Language][UX] Semantic check for parallel fragment access (#1338)
parent
2ae4f1b7
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
277 additions
and
3 deletions
+277
-3
src/transform/layout_inference.cc
src/transform/layout_inference.cc
+7
-1
testing/python/analysis/test_tilelang_fragment_loop_checker.py
...ng/python/analysis/test_tilelang_fragment_loop_checker.py
+162
-0
testing/python/analysis/test_tilelang_nested_loop_checker.py
testing/python/analysis/test_tilelang_nested_loop_checker.py
+0
-0
tilelang/analysis/__init__.py
tilelang/analysis/__init__.py
+1
-0
tilelang/analysis/fragment_loop_checker.py
tilelang/analysis/fragment_loop_checker.py
+100
-0
tilelang/analysis/nested_loop_checker.py
tilelang/analysis/nested_loop_checker.py
+4
-2
tilelang/engine/phase.py
tilelang/engine/phase.py
+3
-0
No files found.
src/transform/layout_inference.cc
View file @
e2b10c58
...
...
@@ -821,7 +821,13 @@ private:
int64_t
frag_reg_num
=
1
;
for
(
auto
i
:
frag
.
value
()
->
OutputShape
())
{
auto
pci
=
as_const_int
(
i
);
ICHECK
(
pci
!=
nullptr
);
ICHECK
(
pci
!=
nullptr
)
<<
"Can not use non-constant range to "
"iterate over a fragment/local "
"buffer. Non-constant shape expr is: "
<<
i
<<
". This is possibly because you use symbolic shape when "
"accessing a fragment/local buffer."
;
frag_reg_num
*=
*
pci
;
}
reg_num
+=
frag_reg_num
;
...
...
testing/python/analysis/test_tilelang_fragment_loop_checker.py
0 → 100644
View file @
e2b10c58
import
tilelang
import
tilelang.language
as
T
import
pytest
@tilelang.jit
def simple_invalid_loop(dtype: str = "bfloat16",
                        accum_dtype: str = "float32",
                        num_threads: int = 128):
    """Build a kernel whose ``T.Parallel(A)`` loop variable indexes a fragment.

    ``A`` comes from ``T.dynamic("A")`` and is used both as the second
    dimension of ``data`` and as the extent of the second parallel loop.
    ``test_invalid_loop`` expects calling this factory to raise ``ValueError``.

    Args:
        dtype: Element type of the input tensor ``data``.
        accum_dtype: Element type of the fragment buffer.
        num_threads: Value passed as ``threads=`` to ``T.Kernel``.

    Returns:
        The ``main`` prim_func defined below.
    """
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)

            # Constant-extent loop: the bound on A is expressed as a guard.
            for i in T.Parallel(128):
                if i < A:
                    data_frag[i] = data[tid, i]

            # Loop extent is A itself and i indexes the fragment directly;
            # this is the pattern the test expects to be rejected.
            for i in T.Parallel(A):
                data_frag[i] = 0

    return main
@tilelang.jit
def nested_invalid_loop(dtype: str = "bfloat16",
                        accum_dtype: str = "float32",
                        num_threads: int = 128):
    """Build a kernel with nested parallel loops where the outer extent uses ``A``.

    The outer loop runs over ``A // 64`` and the inner loop over the constant
    ``64``; both loop variables are combined into the fragment index.
    ``test_invalid_loop`` expects calling this factory to raise ``ValueError``.

    Args:
        dtype: Element type of the input tensor ``data``.
        accum_dtype: Element type of the fragment buffer.
        num_threads: Value passed as ``threads=`` to ``T.Kernel``.

    Returns:
        The ``main`` prim_func defined below.
    """
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)

            # Constant-extent loop: the bound on A is expressed as a guard.
            for i in T.Parallel(128):
                if i < A:
                    data_frag[i] = data[tid, i]

            # Outer extent depends on A; the outer variable i still reaches
            # the fragment index through i * 64 + j.
            for i in T.Parallel(A // 64):
                for j in T.Parallel(64):
                    data_frag[i * 64 + j] = 0

    return main
@tilelang.jit
def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
                                       accum_dtype: str = "float32",
                                       num_threads: int = 128):
    """Build a kernel where the ``T.Parallel(A)`` variable appears inside an index expression.

    The fragment index is ``64 // 2 + i % 64`` rather than the bare loop
    variable. ``test_invalid_loop`` expects calling this factory to raise
    ``ValueError``.

    Args:
        dtype: Element type of the input tensor ``data``.
        accum_dtype: Element type of the fragment buffer.
        num_threads: Value passed as ``threads=`` to ``T.Kernel``.

    Returns:
        The ``main`` prim_func defined below.
    """
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)

            # Constant-extent loop: the bound on A is expressed as a guard.
            for i in T.Parallel(128):
                if i < A:
                    data_frag[i] = data[tid, i]

            # i is wrapped in arithmetic, but it is still part of the
            # expression used to index the fragment.
            for i in T.Parallel(A):
                data_frag[64 // 2 + i % 64] = 0

    return main
@tilelang.jit
def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
                                accum_dtype: str = "float32",
                                num_threads: int = 128):
    """Build a kernel where the ``T.Parallel(A)`` variable is never used as an index.

    Only the inner constant-extent variable ``j`` indexes the fragment.
    ``test_valid_loop`` calls this factory and expects no exception.

    Args:
        dtype: Element type of the input tensor ``data``.
        accum_dtype: Element type of the fragment buffer.
        num_threads: Value passed as ``threads=`` to ``T.Kernel``.

    Returns:
        The ``main`` prim_func defined below.
    """
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)

            # Constant-extent loop: the bound on A is expressed as a guard.
            for i in T.Parallel(128):
                if i < A:
                    data_frag[i] = data[tid, i]

            # The outer variable i is intentionally unused (hence the B007
            # suppression); only j appears in the fragment index.
            for i in T.Parallel(A):  # noqa: B007
                for j in T.Parallel(64):
                    data_frag[j] = 0  # This is valid because we don't use i

    return main
@tilelang.jit
def valid_loop_not_frag(dtype: str = "bfloat16",
                        accum_dtype: str = "float32",
                        num_threads: int = 128):
    """Build a kernel where a ``T.Parallel(A)`` loop indexes a shared buffer.

    The buffer is allocated with ``T.alloc_shared`` instead of
    ``T.alloc_fragment``. ``test_valid_loop`` calls this factory and expects
    no exception.

    Args:
        dtype: Element type of the input tensor ``data``.
        accum_dtype: Element type of the shared buffer.
        num_threads: Value passed as ``threads=`` to ``T.Kernel``.

    Returns:
        The ``main`` prim_func defined below.
    """
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)

            # Constant-extent loop: the bound on A is expressed as a guard.
            for i in T.Parallel(128):
                if i < A:
                    data_shared[i] = data[tid, i]

            # Same loop shape as the invalid cases, but the target buffer
            # is shared memory rather than a fragment.
            for i in T.Parallel(A):
                data_shared[i] = 0  # Valid because this is shared memory

    return main
@tilelang.jit
def valid_loop_serial(dtype: str = "bfloat16",
                      accum_dtype: str = "float32",
                      num_threads: int = 128):
    """Build a kernel where the loop over ``A`` is ``T.serial`` instead of ``T.Parallel``.

    ``test_valid_loop`` calls this factory and expects no exception.

    NOTE(review): the buffer here is allocated with ``T.alloc_shared``, so this
    case does not exercise a serial loop over a *fragment* buffer. Confirm
    whether ``T.alloc_fragment`` was intended for this scenario.

    Args:
        dtype: Element type of the input tensor ``data``.
        accum_dtype: Element type of the shared buffer.
        num_threads: Value passed as ``threads=`` to ``T.Kernel``.

    Returns:
        The ``main`` prim_func defined below.
    """
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)

            # Constant-extent loop: the bound on A is expressed as a guard.
            for i in T.Parallel(128):
                if i < A:
                    data_shared[i] = data[tid, i]

            # The extent is still A, but the loop kind is serial.
            for i in T.serial(A):
                data_shared[i] = 0  # Valid because this is serial

    return main
def test_invalid_loop():
    """Each invalid kernel factory must raise ``ValueError`` when called."""
    rejected_kernels = (
        simple_invalid_loop,
        nested_invalid_loop,
        invalid_loop_with_complex_dataflow,
    )
    # Checked one at a time, in order, so the first factory that fails to
    # raise is the one reported.
    for build_kernel in rejected_kernels:
        with pytest.raises(ValueError):
            build_kernel()
def test_valid_loop():
    """Each valid kernel factory must be callable without raising."""
    accepted_kernels = (
        valid_loop_not_use_loop_var,
        valid_loop_not_frag,
        valid_loop_serial,
    )
    for build_kernel in accepted_kernels:
        build_kernel()
# Allow running this test module directly as a script.
# NOTE(review): only `tilelang` and `tilelang.language` are imported in this
# file; this call relies on `tilelang.testing` being reachable as an attribute
# of the package. Confirm, or add an explicit `import tilelang.testing`.
if __name__ == "__main__":
    tilelang.testing.main()
testing/python/
l
an
guage
/test_tilelang_
language_
nested_loop.py
→
testing/python/an
alysis
/test_tilelang_nested_loop
_checker
.py
View file @
e2b10c58
File moved
tilelang/analysis/__init__.py
View file @
e2b10c58
...
...
@@ -2,3 +2,4 @@
from
.ast_printer
import
ASTPrinter
# noqa: F401
from
.nested_loop_checker
import
NestedLoopChecker
# noqa: F401
from
.fragment_loop_checker
import
FragmentLoopChecker
# noqa: F401
tilelang/analysis/fragment_loop_checker.py
0 → 100644
View file @
e2b10c58
from
__future__
import
annotations
from
tvm
import
tir
from
tvm.tir
import
(
PyStmtExprVisitor
,
BufferStore
,
For
,
Var
,
PrimFunc
,
BufferLoad
,
IntImm
)
from
tvm.tir.transform
import
prim_func_pass
from
tvm.tir.stmt_functor
import
post_order_visit
@tir.functor.visitor
class _LoopVarUseAnalyzer(PyStmtExprVisitor):
    """Record whether a particular loop variable occurs in visited expressions.

    Attributes are plain instance fields:
      * ``var``  - the variable being searched for.
      * ``used`` - starts ``False`` and is set to ``True`` once a visited
        ``Var`` compares equal to ``var``; it is never reset.
    """

    def __init__(self, var: Var) -> None:
        super().__init__()
        self.used = False
        self.var = var

    def visit_var_(self, op: Var) -> None:
        # Intentionally no further traversal from here: recursing into
        # children at this point would risk infinite recursion.
        if not (op == self.var):
            return
        self.used = True
def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
    """Gather the loads and stores in ``statement`` that target local-scope buffers.

    A node qualifies when it is a ``BufferLoad`` or ``BufferStore`` and the
    scope string of its buffer starts with ``"local"``.

    Args:
        statement: The TIR statement to scan.

    Returns:
        A list of the matching nodes, in the order ``post_order_visit``
        reports them.
    """
    found = []

    def _record_if_local(node):
        if not isinstance(node, (BufferLoad, BufferStore)):
            return
        if node.buffer.scope().startswith("local"):
            found.append(node)

    post_order_visit(statement, _record_if_local)
    return found
@tir.functor.visitor
class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
    """Reject parallel loops with non-constant ranges whose variable indexes a local buffer.

    For every parallel ``For`` reached by the traversal, the directly nested
    chain of parallel loops is treated as one fused nest. If any loop in the
    nest has a ``min`` or ``extent`` that is not an ``IntImm`` and its loop
    variable appears in the indices of a local-scope buffer access inside the
    nest body, a ``ValueError`` is raised.
    """

    def __init__(self) -> None:
        super().__init__()

    def visit_for_(self, op: For) -> None:
        # Non-parallel loops are transparent: just keep walking their body.
        if op.kind != tir.ForKind.PARALLEL:
            self.visit_stmt(op.body)
            return

        # Fuse consecutive parallel loops.
        # Other nested cases are all invalid in TileLang.
        nest = [op]
        innermost_body = op.body
        while isinstance(innermost_body, For) and innermost_body.kind == tir.ForKind.PARALLEL:
            nest.append(innermost_body)
            innermost_body = innermost_body.body

        symbolic_loops = [
            loop for loop in nest
            if not (isinstance(loop.min, IntImm) and isinstance(loop.extent, IntImm))
        ]
        # The body of a parallel nest is not traversed any further, whether
        # or not a symbolic range was found.
        if not symbolic_loops:
            return

        accesses = collect_local_buffer_accesses(innermost_body)
        for loop in symbolic_loops:
            for access in accesses:
                # Fresh analyzer per access so `used` reflects this access only.
                analyzer = _LoopVarUseAnalyzer(loop.loop_var)
                for index in access.indices:
                    analyzer.visit_expr(index)
                if analyzer.used:
                    raise ValueError(
                        "[Tilelang Semantic Check] "
                        f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index "
                        "a local/fragment buffer, which is not allowed in Tilelang.")
def FragmentLoopChecker():
    """Create a pass that validates ``T.Parallel`` loops over local/fragment buffers.

    To keep the parallelization valid, the following restriction is enforced:

    1. The range of the loop cannot be symbolic.

    Returns:
        A ``prim_func_pass`` (``opt_level=0``) that runs the check on the body
        of each PrimFunc and returns the function object it was given.
    """

    def pass_fn(func: PrimFunc, mod, ctx):
        # Analysis only: the visitor raises on violation, otherwise the
        # function is handed back as received.
        checker = _FragmentLoopCheckVisitor()
        checker.visit_stmt(func.body)
        return func

    return prim_func_pass(pass_fn, opt_level=0)
tilelang/analysis/nested_loop_checker.py
View file @
e2b10c58
...
...
@@ -35,7 +35,8 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
# Otherwise
if
self
.
in_parallel_context
:
raise
ValueError
(
"Nested parallel loops are not allowed. "
raise
ValueError
(
"[Tilelang Semantic Check] "
"Nested parallel loops are not allowed. "
"Please check your loop structure."
)
self
.
in_parallel_context
=
True
self
.
visit_stmt
(
child
)
...
...
@@ -43,7 +44,8 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
return
elif
is_pipelined_for
(
op
):
if
self
.
in_parallel_context
:
raise
ValueError
(
"Pipelined loop cannot be nested inside a parallel loop. "
raise
ValueError
(
"[Tilelang Semantic Check] "
"Pipelined loop cannot be nested inside a parallel loop. "
"Please check your loop structure."
)
self
.
visit_stmt
(
op
.
body
)
...
...
tilelang/engine/phase.py
View file @
e2b10c58
...
...
@@ -80,6 +80,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:
# Check if there are any invalid nested loops.
tilelang
.
analysis
.
NestedLoopChecker
()(
mod
)
# Check if there are any invalid symbolic T.Parallel + fragment access.
tilelang
.
analysis
.
FragmentLoopChecker
()(
mod
)
def
LowerAndLegalize
(
mod
:
IRModule
,
target
:
Target
)
->
IRModule
:
# Bind the target device information to the module
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment