Unverified Commit 5eb30a4f authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Language] Add missing while statement (#1254)

* add typing stub for tir.ir

* remove idents

* minor update

* [Language] Add missing while statement

* add test
parent 2c0072a8
...@@ -342,5 +342,23 @@ def test_swap_logic(): ...@@ -342,5 +342,23 @@ def test_swap_logic():
torch.testing.assert_close(data, ref) torch.testing.assert_close(data, ref)
def test_while_loop():
@tilelang.jit(out_idx=-1)
@T.prim_func
def test_while_loop(A: T.Tensor((1,), T.int32)):
with T.Kernel(1) as _:
i = T.alloc_var(T.int32, 0)
sum = T.alloc_var(T.int32)
while i < 10:
sum += i
i += 1
A[0] = sum
ker = test_while_loop()
A = ker()
assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}"
if __name__ == '__main__': if __name__ == '__main__':
tilelang.testing.main() tilelang.testing.main()
...@@ -469,6 +469,7 @@ class DSLMutator(ast.NodeTransformer): ...@@ -469,6 +469,7 @@ class DSLMutator(ast.NodeTransformer):
return self._emit_assign_target(node.target, rval, annot=node.annotation) return self._emit_assign_target(node.target, rval, annot=node.annotation)
def visit_While(self, node): def visit_While(self, node):
node = self.generic_visit(node)
return quote1( return quote1(
"for _ in __tb.ctx_while(lambda: cond):\n pass", "for _ in __tb.ctx_while(lambda: cond):\n pass",
cond=node.test, cond=node.test,
......
...@@ -292,7 +292,22 @@ class Builder(BaseBuilder): ...@@ -292,7 +292,22 @@ class Builder(BaseBuilder):
def ctx_while(self, cond): def ctx_while(self, cond):
self.check_continue_break() self.check_continue_break()
raise RuntimeError("while loops are not supported in TileLang builder") cond_v = cond()
cond_v_unwrap = unwrap_cond(cond_v)
if not isinstance(cond_v_unwrap, PrimExpr):
if cond_v_unwrap:
raise RuntimeError(
f'Infinite while loop detected in TileLang\n'
f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n'
)
else:
logger.warning(
'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n',
f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n',
stack_info=True,
stacklevel=2)
with self.with_frame(tir.While(cond_v_unwrap)):
yield None
def bind(self, name, value, annot=BaseBuilder.empty): def bind(self, name, value, annot=BaseBuilder.empty):
self.check_continue_break() self.check_continue_break()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment