import functools
from typing import Callable, Union

from tilelang import tvm as tvm
from tvm import IRModule
from tvm.tir import PrimFunc

from . import _ffi_api


def LetInline():
    """Create the ``LetInline`` transformation pass.

    The pass object is produced by the native ``_ffi_api.LetInline`` entry
    point. Judging by the name it presumably inlines ``Let`` bindings; the
    exact semantics live in the C++ implementation (TODO confirm).

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    make_pass = _ffi_api.LetInline
    return make_pass()  # type: ignore


def Simplify(simplify_arguments: bool = False):
    """Create the TileLang ``Simplify`` transformation pass.

    Parameters
    ----------
    simplify_arguments : bool
        Flag forwarded unchanged to the native ``_ffi_api.Simplify``
        constructor. Its precise effect is defined on the C++ side
        (presumably whether function arguments are simplified as well --
        TODO confirm).

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    make_pass = _ffi_api.Simplify
    return make_pass(simplify_arguments)  # type: ignore


def _Simplify(stmt: Union[PrimFunc, IRModule],
              inline_let: bool = False) -> Union[PrimFunc, IRModule]:
    """Run the simplification pipeline on ``stmt`` and return its function.

    Parameters
    ----------
    stmt : Union[PrimFunc, IRModule]
        The code to simplify. A ``PrimFunc`` is first wrapped in a fresh
        ``IRModule`` via ``IRModule.from_expr``; an ``IRModule`` is
        transformed as-is.
    inline_let : bool
        When True, run the ``LetInline`` pass before ``Simplify``.

    Returns
    -------
    PrimFunc
        The single function contained in the simplified module. This is
        what is returned for both input kinds, including ``IRModule``.

    Raises
    ------
    ValueError
        If ``stmt`` is neither a ``PrimFunc`` nor an ``IRModule``.
    AssertionError
        If the simplified module does not contain exactly one function.
    """
    # Normalize the input to an IRModule so the pass pipeline below is
    # written once instead of being duplicated per input type.
    if isinstance(stmt, PrimFunc):
        mod = IRModule.from_expr(stmt)
    elif isinstance(stmt, IRModule):
        mod = stmt
    else:
        raise ValueError(f"Unsupported type: {type(stmt)}")

    if inline_let:
        mod = LetInline()(mod)
    mod = Simplify(simplify_arguments=True)(mod)

    functions = list(mod.functions.values())
    # Explicit check rather than ``assert`` so it still runs under
    # ``python -O``; AssertionError is kept for backward compatibility.
    if len(functions) != 1:
        raise AssertionError("Simplify should return a single function")
    return functions[0]


def simplify_prim_func(func: Callable) -> Callable:
    """Decorate ``func`` so that its result is simplified before returning.

    The decorated callable forwards all arguments to ``func`` and passes
    the returned ``PrimFunc``/``IRModule`` through ``_Simplify`` with the
    default ``inline_let=False``.

    Parameters
    ----------
    func : Callable
        A function returning a ``PrimFunc`` or ``IRModule``.

    Returns
    -------
    Callable
        The wrapping function. ``functools.wraps`` preserves ``func``'s
        name, docstring and ``__wrapped__`` for introspection and debugging.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        stmt: Union[PrimFunc, IRModule] = func(*args, **kwargs)
        return _Simplify(stmt)

    return wrapper


def apply_simplify(stmt: Union[PrimFunc, IRModule],
                   inline_let: bool = False) -> Union[PrimFunc, IRModule]:
    """Apply the Simplify pass to a PrimFunc or IRModule.

    Public entry point that delegates to ``_Simplify``.

    Parameters
    ----------
    stmt : Union[PrimFunc, IRModule]
        The code to simplify.
    inline_let : bool
        When True, ``LetInline`` is run before ``Simplify``.

    Returns
    -------
    Union[PrimFunc, IRModule]
        Whatever ``_Simplify`` returns for ``stmt``.
    """
    simplified = _Simplify(stmt, inline_let=inline_let)
    return simplified