"src/onnx/parse_unique.cpp" did not exist on "44463b94989bfe3f3849ed29629576abd53a9976"
builtin.py 6.48 KB
Newer Older
1
2
"""The language interface for tl programs."""

3
4
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier
5
from tvm import tir
6
7
from typing import Union, Any
from tvm.tir import PrimExpr, Var, Call
8
9


10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def create_list_of_mbarrier(*args: Any) -> Call:
    """
    Create a list of memory barrier handles.

    Builds a call to the ``tl.create_list_of_mbarrier`` intrinsic with a
    ``"handle"`` result dtype.

    Parameters
    ----------
    *args : list or Any
        Either a single list of arguments, or multiple arguments directly.
        A lone ``list`` argument is unpacked, so both call forms shown in the
        examples forward the same operands to the intrinsic.

    Returns
    -------
    tvm.tir.Call
        Handle to the created list of memory barriers.

    Raises
    ------
    TypeError
        If called with no arguments at all.

    Examples
    --------
    >>> create_list_of_mbarrier([128, 128])
    >>> create_list_of_mbarrier(128, 128)
    """
    # Reject the empty call up front; everything below has at least one argument.
    if not args:
        raise TypeError("create_list_of_mbarrier expects a list or one or more arguments.")

    # A single list carries the operands itself; otherwise the positional
    # arguments are the operands.
    packed_as_list = len(args) == 1 and isinstance(args[0], list)
    operands = args[0] if packed_as_list else args
    return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *operands)
40
41


42
def get_mbarrier(*args):
    """Retrieve a memory barrier operation.

    Builds a call to the ``tl.get_mbarrier`` intrinsic with a ``"handle"``
    result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments to specify which memory barrier to retrieve

    Returns:
        tir.Call: A handle to the requested memory barrier
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.get_mbarrier"), *args)
52
53


54
def create_tma_descriptor(*args):
    """Create a Tensor Memory Access (TMA) descriptor.

    Builds a call to the ``tl.create_tma_descriptor`` intrinsic with a
    ``"handle"`` result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments defining the TMA descriptor configuration

    Returns:
        tir.Call: A handle to the created TMA descriptor
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.create_tma_descriptor"), *args)
64
65


66
def tma_load(*args):
    """Perform a Tensor Memory Access (TMA) load operation.

    Builds a call to the ``tl.tma_load`` intrinsic with a ``"handle"`` result
    dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments specifying the TMA load parameters

    Returns:
        tir.Call: A handle to the TMA load operation
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.tma_load"), *args)
76
77


78
def fence_proxy_async(*args):
    """Create a fence for asynchronous proxy operations.

    Builds a call to the ``tl.fence_proxy_async`` intrinsic with a
    ``"handle"`` result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments for fence configuration

    Returns:
        tir.Call: A handle to the fence operation
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.fence_proxy_async"), *args)
88
89


90
def tma_store_arrive(*args):
    """Signal the arrival of a TMA store operation.

    Builds a call to the ``tl.tma_store_arrive`` intrinsic with a
    ``"handle"`` result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments for the store arrival operation

    Returns:
        tir.Call: A handle to the store arrive operation
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_arrive"), *args)
100
101


102
def tma_store_wait(*args):
    """Wait for completion of TMA store operations.

    Builds a call to the ``tl.tma_store_wait`` intrinsic with a ``"handle"``
    result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments specifying which store operations to wait for

    Returns:
        tir.Call: A handle to the store wait operation
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.tma_store_wait"), *args)
112
113


114
def set_max_nreg(*args):
    """Set the maximum number of registers to use.

    Builds a call to the ``tl.set_max_nreg`` intrinsic with a ``"handle"``
    result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments specifying register allocation limits

    Returns:
        tir.Call: A handle to the register setting operation
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.set_max_nreg"), *args)
124
125


126
def no_set_max_nreg(*args):
    """Disable the maximum register limit setting.

    Builds a call to the ``tl.no_set_max_nreg`` intrinsic with a ``"handle"``
    result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments for the operation

    Returns:
        tir.Call: A handle to the register limit disable operation
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"), *args)
136
137


138
def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
    """Wait for memory barrier parity condition.

    Args:
        mbarrier: Union[int, PrimExpr, tir.Call]
            The memory barrier to wait on. A ``tir.Call`` is used as given;
            an ``int`` or any other ``PrimExpr`` is first passed through
            ``get_mbarrier`` to obtain the barrier.
        parity: Union[int, Var]
            The parity value to wait for

    Examples:
        .. code-block:: python

            # Wait for parity 0 on barrier 0
            T.mbarrier_wait_parity(0, 0)

            # Wait for parity value in variable ko on barrier 1
            T.mbarrier_wait_parity(1, ko)

            # Wait using barrier handle
            barrier = T.get_mbarrier(0)
            T.mbarrier_wait_parity(barrier, 1)

            # Common usage in pipelined kernels:
            for ko in range(num_stages):
                # Producer waits for consumer to finish previous iteration
                T.mbarrier_wait_parity(1, ko ^ 1)
                # Producer copies data
                T.copy(A_global, A_shared)
                # Producer signals data ready
                T.mbarrier_arrive(0)

                # Consumer waits for producer data
                T.mbarrier_wait_parity(0, ko)
                # Consumer computes
                T.gemm(A_shared, B_shared, C_local)
                # Consumer signals completion
                T.mbarrier_arrive(1)

    Returns:
        tir.Call: A handle to the barrier wait operation

    Raises:
        TypeError: If ``mbarrier`` is not an int, a PrimExpr, or a tir.Call.
    """
    # An existing tir.Call (e.g. the result of get_mbarrier) is forwarded
    # untouched; only the remaining accepted kinds are wrapped.
    if not isinstance(mbarrier, tir.Call):
        if not isinstance(mbarrier, (tir.PrimExpr, int)):
            raise TypeError("mbarrier must be an integer or a tir.Call")
        mbarrier = get_mbarrier(mbarrier)
    return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)


186
def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]):
    """Arrive at memory barrier.

    Args:
        mbarrier: Union[int, PrimExpr, tir.Call]
            The memory barrier to arrive at. A ``tir.Call`` is used as given;
            an ``int`` or any other ``PrimExpr`` is first passed through
            ``get_mbarrier`` to obtain the barrier.

    Returns:
        The result of ``ptx_arrive_barrier`` applied to the resolved barrier.

    Raises:
        TypeError: If ``mbarrier`` is not an int, a PrimExpr, or a tir.Call.
    """
    # An existing tir.Call (e.g. the result of get_mbarrier) is forwarded
    # untouched; only the remaining accepted kinds are wrapped.
    if not isinstance(mbarrier, tir.Call):
        if not isinstance(mbarrier, (tir.PrimExpr, int)):
            raise TypeError("mbarrier must be an integer or a tir.Call")
        mbarrier = get_mbarrier(mbarrier)
    return ptx_arrive_barrier(mbarrier)
200
201


202
def mbarrier_expect_tx(*args):
    """Set expected transaction count for memory barrier.

    Builds a call to the ``tl.mbarrier_expect_tx`` intrinsic with a
    ``"handle"`` result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments specifying the expected transaction count

    Returns:
        tir.Call: A handle to the barrier expectation operation
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_expect_tx"), *args)
212
213


214
def wait_wgmma(*args):
    """Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

    Builds a call to the ``tl.wait_wgmma`` intrinsic with a ``"handle"``
    result dtype, forwarding every positional argument unchanged.

    Args:
        *args: Variable arguments specifying which operations to wait for

    Returns:
        tir.Call: A handle to the WGMMA wait operation
    """
    return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), *args)