builtin.h 8.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
/*!
 * \file tl/op/builtin.h
 * \brief Builtin intrinsics.
 *
 */

#ifndef TVM_TL_OP_BUILTIN_H_
#define TVM_TL_OP_BUILTIN_H_

10
#include "operator.h"
11
#include <tvm/ir/transform.h>
12
13

namespace tvm {
14
15
16
17
18
19
20
21
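/*!
 * \brief TileLang (tl) attribute keys, pass-config keys and builtin
 * intrinsics.
 *
 * NOTE(review): the remaining text describes ptx_fence_barrier_init(), which
 * is declared further down in this header, rather than the namespace itself.
 *
 * Create the TVM intrinsic that initializes a PTX fence barrier.
 *
 * Initializes a PTX fence-style barrier used to coordinate asynchronous memory
 * operations (for example, TMA/TMA_STORE). Returns the Op representing this
 * intrinsic for use in TIR lowering and code generation.
 *
 */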
22
namespace tl {
23
24
25

namespace attr {
static constexpr const char *kPaddingMap = "padding_map";
26
27
static constexpr const char *kWarpSpecializationScope =
    "kWarpSpecializationScope";
28
29
static constexpr const char *kCustomWarpSpecialization =
    "kCustomWarpSpecialization";
30
31
} // namespace attr

32
33
static constexpr const char *kDebugMergeSharedMemoryAllocations =
    "tl.debug_merge_shared_memory_allocations";
34
static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
35
36
static constexpr const char *kDisableSafeMemoryLegalize =
    "tl.disable_safe_memory_legalize";
37
38
static constexpr const char *kDisableWarpSpecialized =
    "tl.disable_warp_specialized";
39
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
40
41
static constexpr const char *kEnableAggressiveSharedMemoryMerge =
    "tl.enable_aggressive_shared_memory_merge";
42
static constexpr const char *kDisableFastMath = "tl.disable_fast_math";
43
static constexpr const char *kEnableFastMath = "tl.enable_fast_math";
44
45
static constexpr const char *kPtxasRegisterUsageLevel =
    "tl.ptxas_register_usage_level";
46
47
static constexpr const char *kEnablePTXASVerboseOutput =
    "tl.enable_ptxas_verbose_output";
48
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
49
50
51
52
53
54
55
56
57
/*!
 * \brief Whether to disable dynamic tail split
 *
 * kDisableDynamicTailSplit = "tl.disable_dynamic_tail_split"
 *
 */
static constexpr const char *kDisableDynamicTailSplit =
    "tl.disable_dynamic_tail_split";

58
59
60
61
62
63
64
65
66
67
68
69
70
71
/*!
 * \brief Whether to disable thread storage synchronization
 *
 * When enabled, disables the automatic insertion of thread synchronization
 * barriers (e.g., __syncthreads()) for shared memory access coordination.
 * This can be useful for performance optimization in cases where manual
 * synchronization is preferred or when synchronization is not needed.
 *
 * kDisableThreadStorageSync = "tl.disable_thread_storage_sync"
 *
 */
static constexpr const char *kDisableThreadStorageSync =
    "tl.disable_thread_storage_sync";

72
73
74
75
76
77
78
79
80
81
82
83
/*!
 * \brief The size of the vectorized dimension in buffer, specified by the user
 *
 * For example, if the vectorized dimension is 128 bits and the dtype of buffer
 * A[m, k] is float16, the size of the vectorized dimension (i.e. k) in buffer A
 * should be divisible by 8 (8 = 128 / 16).
 *
 * kDynamicAlignment = "tl.dynamic_alignment"
 *
 */
static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";

84
85
86
87
88
89
90
91
/*!
 * \brief Get the type of the CUDA tensor map
 *
 * DataType cuTensorMapType()
 *
 */
DataType cuTensorMapType();

92
93
94
/*!
 * \brief tvm intrinsics for TMADescriptor creation for tiled load
 *
95
 * CuTensorMap* create_tma_descriptor(data_type, rank, global_addr,
96
97
 * global_shape..., global_stride..., smem_box..., smem_stride..., interleave,
 * swizzle, l2_promotion, oob_fill)
98
99
 *
 */
100
TVM_DLL const Op &create_tma_descriptor();
101
102
103
104

/*!
 * \brief tvm intrinsics for TMADescriptor creation for image to column load
 *
105
 * CuTensorMap* create_tma_im2col_descriptor(data_type, rank, global_addr,
106
107
108
 * global_shape..., global_stride..., elem_stride..., lower_corner...,
 * upper_corner..., smem_box_pixel, smem_box_channel, interleave, swizzle,
 * l2_promotion, oob_fill)
109
110
 *
 */
111
TVM_DLL const Op &create_tma_im2col_descriptor();
112
113
114
115

/*!
 * \brief Create a list of mbarrier with num_threads
 *
116
 * create_list_of_mbarrier(num_threads0, num_threads1, ...)
117
118
 *
 */
119
TVM_DLL const Op &create_list_of_mbarrier();
120
121
122
123
124
125
126

/*!
 * \brief Get the mbarrier with barrier_id
 *
 * int64_t* get_mbarrier(barrier_id)
 *
 */
127
TVM_DLL const Op &get_mbarrier();
128
129

/*!
130
131
 * \brief tvm intrinsics for loading data from global tensor descriptor to
 * shared memory
132
 *
133
 * tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
134
135
 *
 */
136
TVM_DLL const Op &tma_load();
137
138

/*!
139
140
 * \brief tvm intrinsics for loading image from global tensor to columns in
 * shared memory
141
 *
142
 * tma_load_im2col(descriptor, mbarrier, smem_data, coord_0, coord_1, ...,
143
 * image_offset, ...)
144
145
 *
 */
146
TVM_DLL const Op &tma_load_im2col();
147
148

/*!
149
150
 * \brief tvm intrinsics for storing data from shared memory to global tensor
 * descriptor
151
 *
152
 * tma_store(descriptor, smem_data, coord_0, coord_1, ...)
153
154
 *
 */
155
TVM_DLL const Op &tma_store();
156

157
158
159
160
161
162
163
164
/*!
 * \brief tvm intrinsics for barrier initialization fence
 *
 * ptx_fence_barrier_init()
 *
 */
const Op &ptx_fence_barrier_init();

165
166
167
/*!
 * \brief tvm intrinsics for mbarrier wait with parity bit
 *
168
 * mbarrier_wait_parity(mbarrier, parity)
169
170
 *
 */
171
TVM_DLL const Op &mbarrier_wait_parity();
172
173
174
175

/*!
 * \brief tvm intrinsics for mbarrier expect tx
 *
176
 * mbarrier_expect_tx(mbarrier, transaction_bytes)
177
178
 *
 */
179
TVM_DLL const Op &mbarrier_expect_tx();
180
181
182
183

/*!
 * \brief tvm intrinsics for ldmatrix
 *
184
 * ptx_ldmatrix(transposed, num, shared_addr, local_addr)
185
186
 *
 */
187
TVM_DLL const Op &ptx_ldmatrix();
188
189
190
191

/*!
 * \brief tvm intrinsics for stmatrix
 *
192
 * ptx_stmatrix(transposed, num, shared_addr, int32_values...)
193
194
 *
 */
195
TVM_DLL const Op &ptx_stmatrix();
196

197
198
199
200
201
202
203
204
/*!
 * \brief tvm intrinsic for ptx async copy barrier using
 * cp.async.mbarrier.arrive.noinc
 *
 *  This op is used to represent a ptx async copy barrier operation in tilelang.
 */
TVM_DLL const Op &ptx_cp_async_barrier_noinc();

205
206
207
/*!
 * \brief Pack two b16 values into one b32 value
 *
208
 * int32 pack_b16(b16_value, b16_value)
209
210
 *
 */
211
TVM_DLL const Op &pack_b16();
212
213
214
215
216
217
218

/*!
 * \brief Issue a shared memory fence for async operations
 *
 * fence_proxy_async()
 *
 */
219
TVM_DLL const Op &fence_proxy_async();
220

221
222
223
/*!
 * \brief Indicate arrival of warp issuing TMA_STORE
 *
224
 * tma_store_arrive()
225
226
 *
 */
227
TVM_DLL const Op &tma_store_arrive();
228
229
230
231

/*!
 * \brief Wait for TMA_STORE to finish
 *
232
 * tma_store_wait()
233
234
 *
 */
235
TVM_DLL const Op &tma_store_wait();
236

237
238
239
240
241
242
/*!
 * \brief Set the register hint for a warp-specialized branch
 *
 * set_max_nreg(num_reg, is_inc)
 *
 */
243
TVM_DLL const Op &set_max_nreg();
244

245
246
247
/*!
 * \brief Do not set a register hint for warp-specialized branches
 *
248
 * no_set_max_nreg()
249
250
 *
 */
251
TVM_DLL const Op &no_set_max_nreg();
252

253
254
255
/*!
 * \brief Wait for the previous wgmma to finish
 *
256
 * wait_wgmma(num_mma)
257
258
 *
 */
259
TVM_DLL const Op &wait_wgmma();
260

261
262
263
264
265
266
/*!
 * \brief Synchronize all threads in a grid
 *
 * sync_grid()
 *
 */
267
TVM_DLL const Op &sync_grid();
268
269
270
271
272
273
274

/*!
 * \brief tvm intrinsic for loop break
 *
 * loop_break()
 *
 */
275
TVM_DLL const Op &loop_break();
276

277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
/*!
 * \brief tvm intrinsic for amd matrix core mfma instructions.
 *
 *  void tvm_mfma(StringImm shape, StringImm A_layout, StringImm B_layout,
 *               StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
 *               Var multiplicand_a, Expr a_index,
 *               Var multiplicand_b, Expr b_index,
 *               Var accumulator, Expr c_index);
 */
TVM_DLL const Op &tvm_mfma();

/*!
 * \brief tvm intrinsic for storing the result of AMD MFMA into a destination
 * pointer.
 *
 *        There is no real instruction that does that, but we want to hide
 * details of complex index manipulation behind this intrinsic to simplify TIR
 * lowering passes (e.g. LowerWarpMemory) like cuda ptx backend does.
 *
 * void tvm_mfma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr
 * src_offset, Var dst_stride);
 */
TVM_DLL const Op &tvm_mfma_store();

/*!
 * \brief tvm intrinsic for amd rdna matrix core instructions.
 *
 *  void tvm_rdna_wmma(StringImm shape, StringImm A_layout, StringImm B_layout,
 *               StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
 *               Var multiplicand_a, Expr a_index,
 *               Var multiplicand_b, Expr b_index,
 *               Var accumulator, Expr c_index);
 */
TVM_DLL const Op &tvm_rdna_wmma();

/*!
 * \brief tvm intrinsic for storing the result of AMD RDNA WMMA into a
 * destination pointer.
 *
 *        There is no real instruction that does that, but we want to hide
 * details of complex index manipulation behind this intrinsic to simplify TIR
 * lowering passes (e.g. LowerWarpMemory) like cuda ptx backend does.
 *
 * void tvm_rdna_wmma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr
 * src_offset, Var dst_stride);
 */
TVM_DLL const Op &tvm_rdna_wmma_store();

325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
/*!
 * \brief tilelang intrinsic for general matrix multiplication (GEMM).
 *
 *  This op is used to represent a generic GEMM operation in tilelang.
 */
TVM_DLL const Op &tl_gemm();

/*!
 * \brief tilelang intrinsic for sparse matrix multiplication (GEMM with
 * sparsity).
 *
 *  This op is used to represent a sparse GEMM operation in tilelang.
 */
TVM_DLL const Op &tl_gemm_sp();

340
341
342
343
344
345
346
/*!
 * \brief tilelang intrinsic for shuffle elect.
 *
 *  This op is used to represent a shuffle elect operation in tilelang.
 */
TVM_DLL const Op &tl_shuffle_elect();

347
348
} // namespace tl
} // namespace tvm
349

350
#endif //  TVM_TL_OP_BUILTIN_H_