"vscode:/vscode.git/clone" did not exist on "4bfdd7692ea67acf074ce661a2df16de4e4f0890"
test_tilelang_issue_merge_if.py 869 Bytes
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import tilelang
from tilelang import tvm as tvm
from tvm.ir import IRModule
import tilelang.testing
import tilelang.language as T


def merge_if_test():
    """Build the PrimFunc used by ``test_merge_if``.

    The kernel allocates four one-element ``float16`` buffers via
    ``T.alloc_fragment`` and then runs four consecutive ``if`` statements.
    Each ``if`` tests a *different* buffer (``A[0] == 0``, ``B[0] == 0``, ...)
    and writes ``0`` back to that same element.

    ``test_merge_if`` asserts that ``MergeIfStmt`` leaves this function
    structurally unchanged. Presumably this is because the four conditions are
    distinct and so must not be merged into one ``if``. NOTE(review): confirm
    this is the intended regression scenario.

    Returns:
        The object produced by decorating ``main`` with ``T.prim_func``.
    """

    @T.prim_func
    def main():
        # Four independent one-element fragment buffers; the buffers are never
        # initialised here, only read and conditionally written.
        A = T.alloc_fragment((1,), "float16")
        B = T.alloc_fragment((1,), "float16")
        C = T.alloc_fragment((1,), "float16")
        D = T.alloc_fragment((1,), "float16")
        # Back-to-back ifs, each guarded by a condition on a different buffer.
        if A[0] == 0:
            A[0] = 0
        if B[0] == 0:
            B[0] = 0
        if C[0] == 0:
            C[0] = 0
        if D[0] == 0:
            D[0] = 0

    return main


def test_merge_if():
    """Check that ``MergeIfStmt`` leaves ``merge_if_test``'s kernel unchanged.

    The module is built from ``merge_if_test()`` and run through
    ``tilelang.transform.MergeIfStmt``. The ``"main"`` function before and
    after the pass must then be structurally equal.
    """
    before = IRModule.from_expr(merge_if_test())
    after = tilelang.transform.MergeIfStmt()(before)
    # map_free_vars=True: free variables on the two sides may be mapped onto
    # each other instead of having to be the identical Var objects.
    tvm.ir.assert_structural_equal(before["main"], after["main"], map_free_vars=True)


# Allow running this file directly; tilelang.testing.main() presumably
# dispatches to the project's test runner (e.g. pytest) — confirm.
if __name__ == "__main__":
    tilelang.testing.main()