"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "6963b07bf024f069d3a0731d72047502893b3d69"
Unverified Commit 1dfac2e8 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Use `ExprDeepEqual` instead of `StructuralEqual` when merge consecutive If stmt (#876)

* Update submodule TVM to latest commit and fix condition comparison in merge_if_stmt.cc

* Update submodule TVM to latest commit 0524f760

* lint fix
parent 15a303d2
......@@ -39,7 +39,7 @@ private:
if (const IfThenElseNode *if_node = new_stmt.as<IfThenElseNode>()) {
if (!if_node->else_case.defined()) {
if (current_condition.defined() &&
StructuralEqual()(current_condition, if_node->condition)) {
ExprDeepEqual()(current_condition, if_node->condition)) {
current_if_bodies.push_back(if_node->then_case);
continue;
} else {
......
import tilelang
from tilelang import tvm as tvm
from tvm.ir import IRModule
import tilelang.testing
import tilelang.language as T
def merge_if_test():
@T.prim_func
def main():
A = T.alloc_fragment((1,), "float16")
B = T.alloc_fragment((1,), "float16")
C = T.alloc_fragment((1,), "float16")
D = T.alloc_fragment((1,), "float16")
if A[0] == 0:
A[0] = 0
if B[0] == 0:
B[0] = 0
if C[0] == 0:
C[0] = 0
if D[0] == 0:
D[0] = 0
return main
def test_merge_if():
func = merge_if_test()
original_module = IRModule.from_expr(func)
transformed = tilelang.transform.MergeIfStmt()(original_module)
tvm.ir.assert_structural_equal(original_module["main"], transformed["main"], True)
if __name__ == "__main__":
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment