builtin.cc 14.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
/*!
 * \file tl/op/builtin.cc
 * \brief Builtin intrinsics.
 *
 */

#include "builtin.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/cuda.h"
14
#include "../target/utils.h"
15
16
17
18

namespace tvm {
namespace tl {

19
TVM_REGISTER_PASS_CONFIG_OPTION(kDebugMergeSharedMemoryAllocations, Bool);
20
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
21
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableSafeMemoryLegalize, Bool);
22
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
23
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableThreadStorageSync, Bool);
24
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
25
26
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
27
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
28
TVM_REGISTER_PASS_CONFIG_OPTION(kForceLetInline, Bool);
29
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
30
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
31
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
32
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
33
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
34
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
35
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
36
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
37
38
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String);
39
TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array<ffi::String>);
40

41
42
DataType cuTensorMapType() { return DataType::UInt(8, 128); }

43
44
45
46
47
48
49
#define TIR_DEFINE_TL_BUILTIN(OpName)                                          \
  const Op &OpName() {                                                         \
    static const Op &op = Op::Get("tl." #OpName);                              \
    return op;                                                                 \
  }                                                                            \
  TVM_REGISTER_OP("tl." #OpName)                                               \
      .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName)
50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// fast math related op
TIR_DEFINE_TL_BUILTIN(__exp).set_num_inputs(1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(__exp10).set_num_inputs(1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(__log).set_num_inputs(1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(__log2).set_num_inputs(1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(__log10).set_num_inputs(1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(__tan).set_num_inputs(1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(__cos).set_num_inputs(1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(__sin).set_num_inputs(1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// high precision with IEEE-compliant
TIR_DEFINE_TL_BUILTIN(ieee_add).set_num_inputs(3).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(ieee_sub).set_num_inputs(3).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(ieee_mul).set_num_inputs(3).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(ieee_fmaf).set_num_inputs(4).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(ieee_frcp).set_num_inputs(2).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(ieee_fsqrt)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kPure));

105
106
107
108
109
110
TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

111
TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
112
    .set_num_inputs(-1)
113
114
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
115

116
TIR_DEFINE_TL_BUILTIN(create_tma_descriptor)
117
    .set_num_inputs(-1)
118
119
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));
120

121
TIR_DEFINE_TL_BUILTIN(create_tma_im2col_descriptor)
122
    .set_num_inputs(-1)
123
124
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));
125

126
TIR_DEFINE_TL_BUILTIN(get_mbarrier)
127
    .set_num_inputs(1)
128
129
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));
130

131
TIR_DEFINE_TL_BUILTIN(tma_load).set_num_inputs(-1).set_attr<TCallEffectKind>(
132
133
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

134
TIR_DEFINE_TL_BUILTIN(tma_load_im2col)
135
136
137
138
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

139
140
TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
141

142
143
144
145
146
TIR_DEFINE_TL_BUILTIN(ptx_fence_barrier_init)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

147
TIR_DEFINE_TL_BUILTIN(mbarrier_wait_parity)
148
    .set_num_inputs(2)
149
150
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
151

152
TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
153
    .set_num_inputs(2)
154
155
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
156

157
158
159
160
161
162
163
164
165
166
TIR_DEFINE_TL_BUILTIN(ptx_wgmma_ss)
    .set_num_inputs(15)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_wgmma_rs)
    .set_num_inputs(15)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

167
168
169
170
171
172
173
174
175
176
TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ss)
    .set_num_inputs(14)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_tcgen05_mma_ts)
    .set_num_inputs(13)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

177
178
179
180
181
182
183
184
185
186
TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_deallocate_tensor_memory)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

187
188
189
190
191
TIR_DEFINE_TL_BUILTIN(ptx_mma_sm70)
    .set_num_inputs(13)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

192
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
193
    .set_num_inputs(4)
194
195
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
196

197
TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
198
    .set_num_inputs(-1)
199
200
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
201

202
203
204
205
206
TIR_DEFINE_TL_BUILTIN(ptx_cp_async_barrier_noinc)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

207
TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
208
    .set_num_inputs(0)
209
210
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
211

212
TIR_DEFINE_TL_BUILTIN(tma_store_arrive)
213
214
215
216
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

217
TIR_DEFINE_TL_BUILTIN(tma_store_wait)
218
219
220
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
221
TIR_DEFINE_TL_BUILTIN(set_max_nreg)
222
    .set_num_inputs(2)
223
224
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
225

226
TIR_DEFINE_TL_BUILTIN(no_set_max_nreg)
227
228
229
230
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
TIR_DEFINE_TL_BUILTIN(warpgroup_arrive)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warpgroup_commit_batch)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warpgroup_wait)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

246
247
248
249
250
TIR_DEFINE_TL_BUILTIN(warpgroup_fence_operand)
    .set_num_inputs(4)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
TIR_DEFINE_TL_BUILTIN(get_lane_idx)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(get_warp_idx_sync)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(get_warp_idx)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

TIR_DEFINE_TL_BUILTIN(get_warp_group_idx)
    .set_num_inputs(-1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

271
272
273
274
TIR_DEFINE_TL_BUILTIN(wait_wgmma)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
275

276
TIR_DEFINE_TL_BUILTIN(pack_b16).set_num_inputs(2).set_attr<TCallEffectKind>(
277
    "TCallEffectKind", Integer(CallEffectKind::kPure));
278
279
280
281
282
283
284
285

TIR_DEFINE_TL_BUILTIN(sync_grid).set_num_inputs(0).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(loop_break)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
286
287
288
289
290
291
292
293
294

TIR_DEFINE_TL_BUILTIN(tl_gemm).set_num_inputs(4).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
TIR_DEFINE_TL_BUILTIN(tvm_mfma).set_num_inputs(12).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tvm_mfma_store)
    .set_num_inputs(6)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tvm_rdna_wmma)
    .set_num_inputs(12)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(tvm_rdna_wmma_store)
    .set_num_inputs(6)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

313
314
315
316
317
TIR_DEFINE_TL_BUILTIN(tl_shuffle_elect)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kPure));

318
TIR_DEFINE_TL_BUILTIN(initialize_wgmma_descriptor)
319
320
321
322
    .set_num_inputs(5)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

323
324
325
326
327
TIR_DEFINE_TL_BUILTIN(initialize_tcgen05_descriptor)
    .set_num_inputs(7)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

328
329
330
331
332
TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

333
TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
334
    .set_num_inputs(3)
335
336
337
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

338
339
340
341
342
343
344
345
346
347
TIR_DEFINE_TL_BUILTIN(device_assert)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

348
349
350
351
352
TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

Tong WU's avatar
Tong WU committed
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
TIR_DEFINE_TL_BUILTIN(warp_reduce_sum)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_max)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_min)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_bitand)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor)
    .set_num_inputs(1)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

378
379
380
381
382
// __ldg(BufferLoad | Buffer, idx?) -> value
// Treat as a pure call that returns the loaded value.
TIR_DEFINE_TL_BUILTIN(__ldg).set_num_inputs(-1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kPure));

383
} // namespace tl
384
} // namespace tvm