builtin.py 9.66 KB
Newer Older
1
2
"""The language interface for tl programs."""

3
from tilelang import tvm as tvm
4
5
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
6
from tvm import tir
7
8
from typing import Union, Any
from tvm.tir import PrimExpr, Var, Call
9
10


11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def create_list_of_mbarrier(*args: Any) -> Call:
    """
    Build a call to the ``tl.create_list_of_mbarrier`` intrinsic.

    Parameters
    ----------
    *args : list or Any
        Either one ``list`` holding the intrinsic's arguments, or the
        arguments themselves passed positionally. A lone ``list`` is
        unpacked; anything else is forwarded as given.

    Returns
    -------
    tvm.tir.Call
        A ``"handle"``-typed call to ``tl.create_list_of_mbarrier``.

    Raises
    ------
    TypeError
        If called with no arguments at all.

    Examples
    --------
    >>> create_list_of_mbarrier([128, 128])
    >>> create_list_of_mbarrier(128, 128)
    """
    # Guard first: an empty call carries nothing to hand to the intrinsic.
    if not args:
        raise TypeError("create_list_of_mbarrier expects a list or one or more arguments.")

    # Both calling conventions collapse to one flat argument sequence.
    single_list = len(args) == 1 and isinstance(args[0], list)
    barrier_args = args[0] if single_list else args
    return tir.call_intrin("handle", tir.op.Op.get("tl.create_list_of_mbarrier"), *barrier_args)
41
42


43
def get_mbarrier(*args):
    """Build a call to the ``tl.get_mbarrier`` intrinsic.

    Args:
        *args: Positional arguments forwarded unchanged to the intrinsic
            (they select which memory barrier is retrieved).

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.get_mbarrier``.
    """
    intrin = tir.op.Op.get("tl.get_mbarrier")
    return tir.call_intrin("handle", intrin, *args)
53
54


55
def create_tma_descriptor(*args):
    """Build a call to the ``tl.create_tma_descriptor`` intrinsic.

    Args:
        *args: Positional arguments forwarded unchanged to the intrinsic
            (the TMA descriptor configuration).

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.create_tma_descriptor``.
    """
    intrin = tir.op.Op.get("tl.create_tma_descriptor")
    return tir.call_intrin("handle", intrin, *args)
65
66


67
def tma_load(*args):
    """Build a call to the ``tl.tma_load`` intrinsic.

    Args:
        *args: Positional arguments forwarded unchanged to the intrinsic
            (the TMA load parameters).

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.tma_load``.
    """
    intrin = tir.op.Op.get("tl.tma_load")
    return tir.call_intrin("handle", intrin, *args)
77
78


79
def fence_proxy_async(*args):
    """Build a call to the ``tl.fence_proxy_async`` intrinsic.

    Args:
        *args: Positional arguments forwarded unchanged to the intrinsic
            (the fence configuration).

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.fence_proxy_async``.
    """
    intrin = tir.op.Op.get("tl.fence_proxy_async")
    return tir.call_intrin("handle", intrin, *args)
89
90


91
def tma_store_arrive(*args):
    """Build a call to the ``tl.tma_store_arrive`` intrinsic.

    Args:
        *args: Positional arguments forwarded unchanged to the intrinsic
            (the store-arrive parameters).

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.tma_store_arrive``.
    """
    intrin = tir.op.Op.get("tl.tma_store_arrive")
    return tir.call_intrin("handle", intrin, *args)
101
102


103
def tma_store_wait(*args):
    """Build a call to the ``tl.tma_store_wait`` intrinsic.

    Args:
        *args: Positional arguments forwarded unchanged to the intrinsic
            (which store operations to wait for).

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.tma_store_wait``.
    """
    intrin = tir.op.Op.get("tl.tma_store_wait")
    return tir.call_intrin("handle", intrin, *args)
113
114


115
def set_max_nreg(reg_count: int, is_inc: int):
    """Build a call to the ``tl.set_max_nreg`` intrinsic.

    Detailed Documentation:
    https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

    Args:
        reg_count: int
            The register count passed to the intrinsic.
        is_inc: int
            Direction flag passed to the intrinsic:
            0 to decrement, 1 to increment.

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.set_max_nreg``.
    """
    intrin = tir.op.Op.get("tl.set_max_nreg")
    return tir.call_intrin("handle", intrin, reg_count, is_inc)
131
132


133
134
135
136
def inc_max_nreg(reg_count: int):
    """Request a register-count increase.

    Shorthand for ``set_max_nreg(reg_count, 1)``.

    Args:
        reg_count: int
            The register count forwarded to ``set_max_nreg``.

    Returns:
        Whatever ``set_max_nreg`` returns for the increment flag.
    """
    increment_flag = 1
    return set_max_nreg(reg_count, increment_flag)
137
138


139
140
141
142
143
144
145
146
def dec_max_nreg(reg_count: int):
    """Request a register-count decrease.

    Shorthand for ``set_max_nreg(reg_count, 0)``.

    Args:
        reg_count: int
            The register count forwarded to ``set_max_nreg``.

    Returns:
        Whatever ``set_max_nreg`` returns for the decrement flag.
    """
    decrement_flag = 0
    return set_max_nreg(reg_count, decrement_flag)


def no_set_max_nreg():
    """Build a call to the ``tl.no_set_max_nreg`` intrinsic.

    The intrinsic takes no arguments; it is used to disable the maximum
    register limit setting.

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.no_set_max_nreg``.
    """
    intrin = tir.op.Op.get("tl.no_set_max_nreg")
    return tir.call_intrin("handle", intrin)
149
150


151
def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
    """Wait for memory barrier parity condition.

    Args:
        mbarrier: int, PrimExpr, or tir.Call
            The memory barrier to wait on. A ``tir.Call`` is used as-is;
            any other ``PrimExpr`` or ``int`` is first wrapped with
            ``get_mbarrier``.
        parity: int or Var
            The parity value to wait for; forwarded unchanged.

    Examples:
        .. code-block:: python

            # Wait for parity 0 on barrier 0
            T.mbarrier_wait_parity(0, 0)

            # Wait for parity value in variable ko on barrier 1
            T.mbarrier_wait_parity(1, ko)

            # Wait using barrier handle
            barrier = T.get_mbarrier(0)
            T.mbarrier_wait_parity(barrier, 1)

            # Common usage in pipelined kernels:
            for ko in range(num_stages):
                # Producer waits for consumer to finish previous iteration
                T.mbarrier_wait_parity(1, ko ^ 1)
                # Producer copies data
                T.copy(A_global, A_shared)
                # Producer signals data ready
                T.mbarrier_arrive(0)

                # Consumer waits for producer data
                T.mbarrier_wait_parity(0, ko)
                # Consumer computes
                T.gemm(A_shared, B_shared, C_local)
                # Consumer signals completion
                T.mbarrier_arrive(1)

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.mbarrier_wait_parity``.

    Raises:
        TypeError: If ``mbarrier`` is not an ``int``, ``PrimExpr`` or
            ``tir.Call``.
    """
    # Reject unsupported inputs up front.
    if not isinstance(mbarrier, (tir.Call, tir.PrimExpr, int)):
        raise TypeError("mbarrier must be an integer or a tir.Call")

    # An existing call is taken as the barrier handle; everything else is
    # treated as a barrier id and resolved through get_mbarrier.
    if not isinstance(mbarrier, tir.Call):
        mbarrier = get_mbarrier(mbarrier)

    wait_op = tir.op.Op.get("tl.mbarrier_wait_parity")
    return tir.call_intrin("handle", wait_op, mbarrier, parity)


199
def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]):
    """Arrive at memory barrier.

    Args:
        mbarrier: int, PrimExpr, or tir.Call
            The memory barrier to arrive at. A ``tir.Call`` is used as-is;
            any other ``PrimExpr`` or ``int`` is first wrapped with
            ``get_mbarrier``.

    Returns:
        The result of ``ptx_arrive_barrier`` applied to the barrier.

    Raises:
        TypeError: If ``mbarrier`` is not an ``int``, ``PrimExpr`` or
            ``tir.Call``.
    """
    # Reject unsupported inputs up front.
    if not isinstance(mbarrier, (tir.Call, tir.PrimExpr, int)):
        raise TypeError("mbarrier must be an integer or a tir.Call")

    # An existing call is taken as the barrier handle; everything else is
    # treated as a barrier id and resolved through get_mbarrier.
    if not isinstance(mbarrier, tir.Call):
        mbarrier = get_mbarrier(mbarrier)

    return ptx_arrive_barrier(mbarrier)
213
214


215
def mbarrier_expect_tx(*args):
    """Build a call to the ``tl.mbarrier_expect_tx`` intrinsic.

    Args:
        *args: Positional arguments forwarded unchanged to the intrinsic
            (the expected transaction count specification).

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.mbarrier_expect_tx``.
    """
    intrin = tir.op.Op.get("tl.mbarrier_expect_tx")
    return tir.call_intrin("handle", intrin, *args)
225
226


227
def wait_wgmma(*args):
    """Build a call to the ``tl.wait_wgmma`` intrinsic.

    WGMMA stands for Warp Group Matrix Multiply-Accumulate.

    Args:
        *args: Positional arguments forwarded unchanged to the intrinsic
            (which operations to wait for).

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.wait_wgmma``.
    """
    intrin = tir.op.Op.get("tl.wait_wgmma")
    return tir.call_intrin("handle", intrin, *args)
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295


def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, Var, None] = None):
    """Wait for a memory barrier to complete.

    This is syntactic sugar for ``mbarrier_wait_parity``, as only parity 0
    and 1 are supported. Both arguments are forwarded unchanged, including
    the default ``parity=None``.

    Args:
        barrier_id: int, PrimExpr, or tir.Call
            The memory barrier to wait on.
        parity: int, Var, or None
            The parity value to wait for.

    Returns:
        tir.Call: The call produced by ``mbarrier_wait_parity``.
    """
    wait_call = mbarrier_wait_parity(barrier_id, parity)
    return wait_call


def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]):
    """Arrive at a memory barrier.

    Thin alias for ``mbarrier_arrive``.

    Args:
        barrier_id: int, PrimExpr, or tir.Call
            The memory barrier to arrive at.

    Returns:
        Whatever ``mbarrier_arrive`` returns for ``barrier_id``.
    """
    arrive_call = mbarrier_arrive(barrier_id)
    return arrive_call


def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
    """Perform a shuffle operation with XOR offset.

    Emits an extern call to ``__shfl_xor_sync`` whose result dtype is taken
    from ``value.dtype``.

    Args:
        value: PrimExpr or tir.Call
            The value to shuffle. It must expose a ``.dtype`` attribute,
            which a plain Python ``int`` does not.
        offset: int, PrimExpr, or tir.Call
            The offset for the shuffle operation.

    Returns:
        tir.Call: The extern call to ``__shfl_xor_sync``.
    """
    # First argument of __shfl_xor_sync is the lane mask; all bits are set.
    full_mask = 0xffffffff
    return tir.call_extern(value.dtype, "__shfl_xor_sync", full_mask, value, offset)


def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
    """Perform a shuffle operation with down offset.

    Emits an extern call to ``__shfl_down_sync`` whose result dtype is taken
    from ``value.dtype``.

    Args:
        value: PrimExpr or tir.Call
            The value to shuffle. It must expose a ``.dtype`` attribute,
            which a plain Python ``int`` does not.
        offset: int, PrimExpr, or tir.Call
            The offset for the shuffle operation.

    Returns:
        tir.Call: The extern call to ``__shfl_down_sync``.
    """
    # First argument of __shfl_down_sync is the lane mask; all bits are set.
    full_mask = 0xffffffff
    return tir.call_extern(value.dtype, "__shfl_down_sync", full_mask, value, offset)


def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
    """Perform a shuffle operation with up offset.

    Emits an extern call to ``__shfl_up_sync`` whose result dtype is taken
    from ``value.dtype``.

    Args:
        value: PrimExpr or tir.Call
            The value to shuffle. It must expose a ``.dtype`` attribute,
            which a plain Python ``int`` does not.
        offset: int, PrimExpr, or tir.Call
            The offset for the shuffle operation.

    Returns:
        tir.Call: The extern call to ``__shfl_up_sync``.
    """
    # First argument of __shfl_up_sync is the lane mask; all bits are set.
    full_mask = 0xffffffff
    return tir.call_extern(value.dtype, "__shfl_up_sync", full_mask, value, offset)
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314


def sync_threads():
    """Emit a storage synchronization over the ``"shared"`` scope.

    NOTE(review): the previous docstring described this as a warp-level
    sync; a ``"shared"``-scope storage sync presumably lowers to a
    block-wide barrier (``__syncthreads()`` on CUDA) -- confirm.

    Returns:
        The result of ``tir.op.tvm_storage_sync("shared")``.
    """
    scope = "shared"
    return tir.op.tvm_storage_sync(scope)


def sync_thread_partial(barrier_id: Union[int, PrimExpr, tir.Call]):
    """Build a call to the ``tl.sync_thread_partial`` intrinsic.

    Args:
        barrier_id: int, PrimExpr, or tir.Call
            The barrier to synchronize on; forwarded unchanged.

    Returns:
        tir.Call: A ``"handle"``-typed call to ``tl.sync_thread_partial``.
    """
    intrin = tir.op.Op.get("tl.sync_thread_partial")
    return tir.call_intrin("handle", intrin, barrier_id)
315
316
317
318
319
320
321
322
323
324


def sync_global():
    """Emit a ``"global"``-scope storage synchronization.

    Builds a ``tir.tvm_storage_sync`` call with three arguments: the scope
    string ``"global"``, a predicate that holds only when all three thread
    bindings are zero, and the product of the three block extents.

    NOTE(review): the previous docstring described this as a block-level
    sync; a ``"global"``-scope sync presumably spans all blocks -- confirm.

    Returns:
        The result of ``evaluate`` applied to the storage-sync call.
    """
    tx, ty, tz = get_thread_bindings()
    ex, ey, ez = get_block_extents()

    # Python's `and` cannot be overloaded, so chaining TIR comparisons with
    # it returns one of the operands instead of a logical AND expression.
    # tir.all builds the real conjunction over all three thread indices.
    is_first_thread = tir.all(tx == 0, ty == 0, tz == 0)

    args = ["global", is_first_thread, ex * ey * ez]
    return evaluate(tir.Call("handle", "tir.tvm_storage_sync", args))