Unverified commit e2b10c58, authored by Chaofan Lin, committed by GitHub
Browse files

[Language][UX] Semantic check for parallel fragment access (#1338)

parent 2ae4f1b7
......@@ -821,7 +821,13 @@ private:
int64_t frag_reg_num = 1;
for (auto i : frag.value()->OutputShape()) {
auto pci = as_const_int(i);
ICHECK(pci != nullptr);
ICHECK(pci != nullptr)
<< "Can not use non-constant range to "
"iterate over a fragment/local "
"buffer. Non-constant shape expr is: "
<< i
<< ". This is possibly because you use symbolic shape when "
"accessing a fragment/local buffer.";
frag_reg_num *= *pci;
}
reg_num += frag_reg_num;
......
import tilelang
import tilelang.language as T
import pytest
# Negative case: the second loop is a T.Parallel over the dynamic dimension A
# and uses its loop variable to index the fragment buffer `data_frag`.
# test_invalid_loop expects calling this factory to raise ValueError.
@tilelang.jit
def simple_invalid_loop(dtype: str = "bfloat16",
                        accum_dtype: str = "float32",
                        num_threads: int = 128):
    # Dynamic (symbolic) dimension used for the tensor's second axis and as a loop bound.
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)

            # Constant-extent parallel loop; the guard keeps reads within A columns.
            for i in T.Parallel(128):
                if i < A:
                    data_frag[i] = data[tid, i]

            # Symbolic-extent parallel loop indexing the fragment with `i`:
            # this is the pattern under test.
            for i in T.Parallel(A):
                data_frag[i] = 0

    return main
# Negative case: two directly nested T.Parallel loops where the outer extent
# (A // 64) is symbolic and the outer loop variable takes part in the fragment index.
# test_invalid_loop expects calling this factory to raise ValueError.
@tilelang.jit
def nested_invalid_loop(dtype: str = "bfloat16",
                        accum_dtype: str = "float32",
                        num_threads: int = 128):
    # Dynamic (symbolic) dimension used for the tensor's second axis and in a loop bound.
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)

            # Constant-extent parallel loop; the guard keeps reads within A columns.
            for i in T.Parallel(128):
                if i < A:
                    data_frag[i] = data[tid, i]

            # Outer loop has a symbolic extent, inner loop a constant one;
            # the index `i * 64 + j` depends on the symbolic loop's variable `i`.
            for i in T.Parallel(A // 64):
                for j in T.Parallel(64):
                    data_frag[i * 64 + j] = 0

    return main
# Negative case: the symbolic-extent loop variable reaches the fragment index
# through a compound expression rather than being used directly.
# test_invalid_loop expects calling this factory to raise ValueError.
@tilelang.jit
def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16",
                                       accum_dtype: str = "float32",
                                       num_threads: int = 128):
    # Dynamic (symbolic) dimension used for the tensor's second axis and as a loop bound.
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)

            # Constant-extent parallel loop; the guard keeps reads within A columns.
            for i in T.Parallel(128):
                if i < A:
                    data_frag[i] = data[tid, i]

            # `i` is still part of the index expression, so this is expected to be rejected.
            for i in T.Parallel(A):
                data_frag[64 // 2 + i % 64] = 0

    return main
# Positive case: a symbolic-extent T.Parallel loop wraps a constant-extent one,
# but only the constant-extent loop's variable indexes the fragment buffer.
# test_valid_loop calls this factory and expects no exception.
@tilelang.jit
def valid_loop_not_use_loop_var(dtype: str = "bfloat16",
                                accum_dtype: str = "float32",
                                num_threads: int = 128):
    # Dynamic (symbolic) dimension used for the tensor's second axis and as a loop bound.
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_frag = T.alloc_fragment([128], accum_dtype)

            # Constant-extent parallel loop; the guard keeps reads within A columns.
            for i in T.Parallel(128):
                if i < A:
                    data_frag[i] = data[tid, i]

            # `i` is intentionally unused (hence the B007 suppression).
            for i in T.Parallel(A):  # noqa: B007
                for j in T.Parallel(64):
                    data_frag[j] = 0  # This is valid because we don't use i

    return main
# Positive case: same loop shape as simple_invalid_loop, but the buffer comes
# from T.alloc_shared instead of T.alloc_fragment.
# test_valid_loop calls this factory and expects no exception.
@tilelang.jit
def valid_loop_not_frag(dtype: str = "bfloat16",
                        accum_dtype: str = "float32",
                        num_threads: int = 128):
    # Dynamic (symbolic) dimension used for the tensor's second axis and as a loop bound.
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)

            # Constant-extent parallel loop; the guard keeps reads within A columns.
            for i in T.Parallel(128):
                if i < A:
                    data_shared[i] = data[tid, i]

            for i in T.Parallel(A):
                data_shared[i] = 0  # Valid because this is shared memory

    return main
# Positive case: the symbolic-extent loop is a T.serial loop rather than T.Parallel.
# test_valid_loop calls this factory and expects no exception.
# NOTE(review): the buffer here is allocated with T.alloc_shared, not
# T.alloc_fragment, so this case does not exercise a serial loop over a
# fragment buffer — confirm whether alloc_fragment was intended.
@tilelang.jit
def valid_loop_serial(dtype: str = "bfloat16",
                      accum_dtype: str = "float32",
                      num_threads: int = 128):
    # Dynamic (symbolic) dimension used for the tensor's second axis and as a loop bound.
    A = T.dynamic("A")

    @T.prim_func
    def main(
            data: T.Tensor((128, A), dtype),  # type: ignore
    ):
        with T.Kernel(128, threads=num_threads) as (tid,):
            data_shared = T.alloc_shared([128], accum_dtype)

            # Constant-extent parallel loop; the guard keeps reads within A columns.
            for i in T.Parallel(128):
                if i < A:
                    data_shared[i] = data[tid, i]

            for i in T.serial(A):
                data_shared[i] = 0  # Valid because this is serial

    return main
def test_invalid_loop():
    """Each invalid kernel factory must raise ValueError when invoked."""
    rejected_factories = (
        simple_invalid_loop,
        nested_invalid_loop,
        invalid_loop_with_complex_dataflow,
    )
    # Checked one after another, in the same order as the factories are defined.
    for build_kernel in rejected_factories:
        with pytest.raises(ValueError):
            build_kernel()
def test_valid_loop():
    """Each valid kernel factory must run without raising."""
    accepted_factories = (
        valid_loop_not_use_loop_var,
        valid_loop_not_frag,
        valid_loop_serial,
    )
    for build_kernel in accepted_factories:
        build_kernel()
# Script entry point: when the file is executed directly, hand control to
# tilelang.testing.main().
# NOTE(review): this file only does `import tilelang`; presumably the package
# itself makes `tilelang.testing` available — confirm, or import it explicitly.
if __name__ == "__main__":
    tilelang.testing.main()
......@@ -2,3 +2,4 @@
from .ast_printer import ASTPrinter # noqa: F401
from .nested_loop_checker import NestedLoopChecker # noqa: F401
from .fragment_loop_checker import FragmentLoopChecker # noqa: F401
from __future__ import annotations
from tvm import tir
from tvm.tir import (PyStmtExprVisitor, BufferStore, For, Var, PrimFunc, BufferLoad, IntImm)
from tvm.tir.transform import prim_func_pass
from tvm.tir.stmt_functor import post_order_visit
@tir.functor.visitor
class _LoopVarUseAnalyzer(PyStmtExprVisitor):
    """Record whether a specific loop variable occurs in the visited expression(s).

    After one or more ``visit_expr`` calls, ``used`` is True if any visited
    ``Var`` compared equal to the tracked variable. The flag is never reset,
    so a single instance accumulates the result across several expressions.
    """

    def __init__(self, var: Var) -> None:
        super().__init__()
        # The flag starts cleared and is only ever set by visit_var_.
        self.used = False
        self.var = var

    def visit_var_(self, op: Var) -> None:
        # Nothing to record unless this occurrence is the tracked variable.
        # This handler deliberately triggers no further traversal.
        if not (op == self.var):
            return
        self.used = True
def collect_local_buffer_accesses(statement) -> list[BufferLoad | BufferStore]:
    """
    Gather the local-scope buffer reads and writes found under a statement.

    A node is kept when it is a ``BufferLoad`` or ``BufferStore`` and the
    scope string of its buffer starts with ``"local"``.

    Args:
        statement: The TIR statement to analyze

    Returns:
        List of the matching nodes, in the order ``post_order_visit``
        reports them.
    """
    local_accesses = []

    def _record_if_local(node):
        # Only buffer reads/writes are of interest.
        if not isinstance(node, (BufferLoad, BufferStore)):
            return
        scope = node.buffer.scope()
        if scope.startswith("local"):
            local_accesses.append(node)

    post_order_visit(statement, _record_if_local)
    return local_accesses
@tir.functor.visitor
class _FragmentLoopCheckVisitor(PyStmtExprVisitor):
    """Reject parallel loops with non-constant ranges that index local buffers.

    For every parallel ``For`` reached by the traversal, the directly nested
    chain of parallel loops is treated as one unit. If any loop in that chain
    has a ``min`` or ``extent`` that is not an ``IntImm`` and its loop
    variable appears in the indices of a local-scope buffer access inside the
    chain's innermost body, a ``ValueError`` is raised.
    """

    def __init__(self) -> None:
        super().__init__()

    def visit_for_(self, op: For) -> None:
        """Check one ``For`` node; non-parallel loops are simply descended into.

        Raises:
            ValueError: if a symbolic-range parallel loop variable is used to
                index a local/fragment buffer.
        """
        is_parallel = op.kind == tir.ForKind.PARALLEL
        if not is_parallel:
            self.visit_stmt(op.body)
            return

        # Fuse consecutive parallel loops
        # Other nested cases are all invalid in TileLang.
        parallel_chain = [op]
        innermost_body = op.body
        while isinstance(innermost_body, For) and innermost_body.kind == tir.ForKind.PARALLEL:
            parallel_chain.append(innermost_body)
            innermost_body = innermost_body.body

        # A range counts as symbolic unless both bounds are integer immediates.
        symbolic_loops = [
            candidate for candidate in parallel_chain
            if not (isinstance(candidate.min, IntImm) and isinstance(candidate.extent, IntImm))
        ]
        if not symbolic_loops:
            # The body of a parallel chain is not traversed any further.
            return

        local_accesses = collect_local_buffer_accesses(innermost_body)
        for loop in symbolic_loops:
            for access in local_accesses:
                # Fresh analyzer per (loop, access) pair; it accumulates over all indices.
                analyzer = _LoopVarUseAnalyzer(loop.loop_var)
                for index in access.indices:
                    analyzer.visit_expr(index)
                if analyzer.used:
                    raise ValueError(
                        "[Tilelang Semantic Check] "
                        f"Loop variable {loop.loop_var} in a T.Parallel loop with symbolic range (min={loop.min}, extent={loop.extent}) is used to index "
                        "a local/fragment buffer, which is not allowed in Tilelang.")
def FragmentLoopChecker():
    """
    Build the semantic-check pass for T.Parallel loops over local/fragment buffers.

    The pass runs ``_FragmentLoopCheckVisitor`` over each PrimFunc body and
    returns the function unchanged. The restriction currently enforced is:

    1. A parallel loop whose range is symbolic (min or extent is not an
       ``IntImm``) must not use its loop variable to index a buffer whose
       scope starts with ``"local"``.

    Returns:
        A prim_func_pass (``opt_level=0``) that performs the check; a
        ``ValueError`` from the visitor propagates out of the pass.
    """

    # The inner function keeps the name `pass_fn` because the callable itself
    # is handed to prim_func_pass.
    def pass_fn(func: PrimFunc, mod, ctx):
        checker = _FragmentLoopCheckVisitor()
        checker.visit_stmt(func.body)
        # Analysis only: the function is returned as-is.
        return func

    checker_pass = prim_func_pass(pass_fn, opt_level=0)
    return checker_pass
......@@ -35,7 +35,8 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
# Otherwise
if self.in_parallel_context:
raise ValueError("Nested parallel loops are not allowed. "
raise ValueError("[Tilelang Semantic Check] "
"Nested parallel loops are not allowed. "
"Please check your loop structure.")
self.in_parallel_context = True
self.visit_stmt(child)
......@@ -43,7 +44,8 @@ class _NestedLoopCheckVisitor(PyStmtExprVisitor):
return
elif is_pipelined_for(op):
if self.in_parallel_context:
raise ValueError("Pipelined loop cannot be nested inside a parallel loop. "
raise ValueError("[Tilelang Semantic Check] "
"Pipelined loop cannot be nested inside a parallel loop. "
"Please check your loop structure.")
self.visit_stmt(op.body)
......
......@@ -80,6 +80,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:
# Check if there are any invalid nested loops.
tilelang.analysis.NestedLoopChecker()(mod)
# Check if there are any invalid symbolic T.Parallel + fragment access.
tilelang.analysis.FragmentLoopChecker()(mod)
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Bind the target device information to the module
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment