Commit 951c2300 authored by Lei Wang, committed by GitHub
Browse files

[Bugfix] Reorder Passes: Place Vectorize Loop Before StorageFlatten and...

[Bugfix] Reorder Passes: Place Vectorize Loop Before StorageFlatten and FlattenBuffer to Prevent Redundant Allocations (#37)

* installation script fix

* readme typo fix

* doc fix for dequantize gemm

* [Doc] remove CODE_OF_CONDUCT.md and SECURITY.md; update references in CONTRIBUTING.md

* [Doc] add unit tests for AnnotateDeviceRegions transform; remove SUPPORT.md

* update license

* [Enhancement] add tensor supply handling for unsigned integers; improve error message for execution backend assertion

* [Refactor] improve code readability by reformatting function signatures and assertions

* [Refactor] replace torch.manual_seed with tilelang.testing.set_random_seed for consistency in random seed handling

* [Refactor] unify thread binding variable naming across kernel and example files

* [Refactor] remove unused thread binding parameter from matrix multiplication functions

* [Refactor] remove unused thread binding parameter from matrix multiplication functions

* [Refactor] enable main testing function in tilelang kernel gemm test

* bug fix

* lint fix

* [Refactor] reorder vectorize loop
parent 362b3520
...@@ -144,6 +144,8 @@ def lower( ...@@ -144,6 +144,8 @@ def lower(
mod = tl.transform.LegalizeSafeMemoryAccess()(mod) mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
# Inject Simplify to remove the duplicated conditions # Inject Simplify to remove the duplicated conditions
mod = tir.transform.Simplify()(mod) mod = tir.transform.Simplify()(mod)
mod = tir.transform.VectorizeLoop()(mod)
# which may be introduced by the LegalizeSafeMemoryAccess # which may be introduced by the LegalizeSafeMemoryAccess
if target.arch == "sm_90": if target.arch == "sm_90":
mod = tl.transform.MultiVersionBuffer()(mod) mod = tl.transform.MultiVersionBuffer()(mod)
...@@ -161,7 +163,6 @@ def lower( ...@@ -161,7 +163,6 @@ def lower(
mod = tir.transform.FlattenBuffer()(mod) mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod) mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod) mod = tir.transform.Simplify()(mod)
mod = tir.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod) mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod) mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod) mod = tir.transform.RenormalizeSplitPattern()(mod)
......
...@@ -463,6 +463,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: ...@@ -463,6 +463,8 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
elif isinstance(res, str): elif isinstance(res, str):
# Ignore docstrings # Ignore docstrings
pass pass
elif isinstance(res, tvm.tir.stmt.BufferStore):
T.buffer_store(res.buffer, res.value, res.indices, res.predicate)
else: else:
self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") self.report_error(node, f"Parsing resulted in unexpected type {type(res)}")
...@@ -480,13 +482,8 @@ def visit_if(self: Parser, node: doc.If) -> None: ...@@ -480,13 +482,8 @@ def visit_if(self: Parser, node: doc.If) -> None:
The doc AST if node. The doc AST if node.
""" """
with self.var_table.with_frame(): with self.var_table.with_frame():
condition = self.eval_expr(node.test) predicate = self.eval_expr(node.test)
if isinstance(condition, bool): if isinstance(predicate, (PrimExpr, tvm.tir.expr.ExprOp)):
if condition:
self.visit_body(node.body)
elif node.orelse:
self.visit_body(node.orelse)
else:
with T.If(self.eval_expr(node.test)): with T.If(self.eval_expr(node.test)):
with T.Then(): with T.Then():
with self.var_table.with_frame(): with self.var_table.with_frame():
...@@ -495,6 +492,16 @@ def visit_if(self: Parser, node: doc.If) -> None: ...@@ -495,6 +492,16 @@ def visit_if(self: Parser, node: doc.If) -> None:
with T.Else(): with T.Else():
with self.var_table.with_frame(): with self.var_table.with_frame():
self.visit_body(node.orelse) self.visit_body(node.orelse)
elif isinstance(predicate, bool):
if predicate:
with self.var_table.with_frame():
self.visit_body(node.body)
elif node.orelse:
with self.var_table.with_frame():
self.visit_body(node.orelse)
else:
self.report_error(node.test,
f"If condition must be a boolean expression, but got {predicate}")
@dispatch.register(token="tir", type_name="Assert") @dispatch.register(token="tir", type_name="Assert")
...@@ -529,6 +536,8 @@ def visit_return(self: Parser, node: doc.Return) -> None: ...@@ -529,6 +536,8 @@ def visit_return(self: Parser, node: doc.Return) -> None:
The doc AST return node. The doc AST return node.
""" """
value = self.eval_expr(node.value) value = self.eval_expr(node.value)
if value is None:
self.report_error(node, "Expression to be returned must be a PrimExpr")
T.evaluate(tvm.tir.ret(value)) T.evaluate(tvm.tir.ret(value))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment