/*!
 * \file target/codegen_cuda.cc
 * \brief CUDA source code generator for TileLang (CodeGenTileLangCUDA).
 */

#include "codegen_cuda.h"
#include <tvm/arith/analyzer.h>
7
#include <tvm/ffi/function.h>
8
#include <tvm/tir/index_map.h>
9
10
11
12
13
14
15
16
#include <tvm/tir/op.h>

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "../op/builtin.h"
17
#include "arith/pattern_match.h"
18
19
20
21
22
#include "target/source/ptx.h"

namespace tvm {
namespace codegen {

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
static std::string GetFP8Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "_2";
  } else if (lanes == 4) {
    vec = "_4";
  } else if (lanes == 8) {
    vec = "_8";
  } else if (lanes == 16) {
    vec = "_16";
  } else {
    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
                  "for FP8";
  }
41
42
  if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() ||
      type.is_float8_e4m3()) {
43
    stream << "fp8_e4" << vec << "_t";
44
45
  } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() ||
             type.is_float8_e5m2()) {
46
47
    stream << "fp8_e5" << vec << "_t";
  } else {
48
    LOG(FATAL) << "Unsupported FP8 type in CUDA codegen but got " << type;
49
50
51
52
  }
  return stream.str();
}

53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
std::string GetFP6Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "x2";
  } else if (lanes == 4) {
    vec = "x4";
  } else if (lanes == 8) {
    vec = "x8";
  } else if (lanes == 16) {
    vec = "x16";
  } else {
    LOG(FATAL)
        << "Only support scalar and vector types of width (2, 4) for FP6";
  }
  stream << "__nv_fp6";
  std::string suffix;
  if (type.code() == DataType::kFloat6_e2m3fn) {
    suffix = "_e2m3";
  } else if (type.code() == DataType::kFloat6_e3m2fn) {
    suffix = "_e3m2";
  } else {
    LOG(FATAL) << "Unsupported FP6 type in CUDA codegen";
  }
  stream << vec << suffix;
  return stream.str();
}

/*!
 * \brief Return the CUDA `__nv_fp4*` type name for an FP4 DataType.
 *
 * Produces `__nv_fp4[xN]_e2m1`. The `xN` infix is omitted for scalars;
 * otherwise N is one of 2, 4, 8 or 16.
 *
 * \param type An FP4 DataType (kFloat4_e2m1fn).
 * \return The type name to print.
 * \note Aborts via LOG(FATAL) for other lane counts or non-FP4 types. The
 *       error message now lists every width this function actually accepts.
 */
std::string GetFP4Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "x2";
  } else if (lanes == 4) {
    vec = "x4";
  } else if (lanes == 8) {
    vec = "x8";
  } else if (lanes == 16) {
    vec = "x16";
  } else {
    LOG(FATAL)
        << "Only support scalar and vector types of width (2, 4, 8, 16) for FP4";
  }
  stream << "__nv_fp4";
  std::string suffix;
  if (type.code() == DataType::kFloat4_e2m1fn) {
    suffix = "_e2m1";
  } else {
    LOG(FATAL) << "Unsupported FP4 type in CUDA codegen";
  }
  stream << vec << suffix;
  return stream.str();
}

113
114
CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
  restrict_keyword_ = "__restrict__";
115
116
117
118
119
  vid_global_barrier_state_ =
      name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
  ICHECK_EQ(vid_global_barrier_state_,
            runtime::symbol::tvm_global_barrier_state);
120
}
121

122
123
124
void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
  os << "extern \"C\" __global__ ";
}
125
126

class LaunchConfigExtractor : public tir::StmtVisitor {
127
128
private:
  void VisitStmt_(const AttrStmtNode *op) final {
129
130
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
131
132
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
133
        threadIdx_x_ext = op->value;
134
135
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
136
        threadIdx_y_ext = op->value;
137
138
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
139
140
141
142
143
144
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

145
public:
146
147
148
149
150
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

151
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f) {
152
153
154
  LaunchConfigExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
155
156
157
158
159
  PrimExpr threadIdx_ext =
      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                        extractor.threadIdx_z_ext);
  if (const IntImmNode *const threadIdx_ext_int =
          threadIdx_ext.as<IntImmNode>()) {
160
    if (threadIdx_ext_int->value == 1) {
161
162
      // unable to extract the number of threads per block, hence directly
      // return
163
164
      return;
    }
165
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ", 1)";
166
167
168
169
170
171
172
  }
}

std::string CodeGenTileLangCUDA::Finish() {
  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }
173
174
175
176
177
178
179
180
  if (enable_fp8_) {
    decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
  }

  if (need_math_constants_h_) {
    decl_stream << "#include <math_constants.h>\n";
  }

181
182
183
184
  if (need_cooperative_groups_) {
    decl_stream << "#include <cooperative_groups.h>\n";
  }

185
  decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
186
187
188
  if (enable_sparse_gemm_) {
    decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
  }
189
190
191
192
  decl_stream << "#include <tl_templates/cuda/copy.h>\n";
  decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
  decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
  decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
193
  decl_stream << "#include <tl_templates/cuda/debug.h>\n";
194
195
196
  decl_stream << "#ifdef ENABLE_BF16\n";
  decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
  decl_stream << "#endif\n";
197
198

  if (need_global_barrier_) {
199
200
    decl_stream << "__device__ unsigned " << vid_global_barrier_state_
                << " = 0;\n";
201
  }
202
  decl_stream << "\n";
203

204
205
206
  return CodeGenC::Finish();
}

207
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
208
209
210
211
  if (op->kind == tir::ForKind::kUnrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
212
213
  std::string extent =
      PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
214
215
216
217
218
  PrintIndent();
  std::string vid = AllocVarID(op->loop_var.get());
  std::string start = PrintExpr(op->min);
  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
219
220
  stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent
         << "; ++" << vid << ") {\n";
221
222
223
224
225
226
227
  int for_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(for_scope);
  PrintIndent();
  stream << "}\n";
}

228
void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) {
229
  ICHECK(!var_idmap_.count(iv->var.get()));
230
231
  var_idmap_[iv->var.get()] =
      CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
232
233
}

234
void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
    os << "void*";
    return;
  }

  if (t.is_void()) {
    os << "void";
    return;
  }

  if (t == tl::cuTensorMapType()) {
    os << "CUtensorMap";
    return;
  }

  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
255
    case 16:
256
      enable_fp16_ = true;
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
      if (t.is_scalar()) {
        os << "half_t";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp16 vector elements.
        //
        // half4 is stored as uint2
        //
        // h4.x is emitted as *(half2*)(&(u2.x)).x
        // h4.y is emitted as *(half2*)(&(u2.x)).y
        // h4.z is emitted as *(half2*)(&(u2.y)).x
        // h4.w is emitted as *(half2*)(&(u2.y)).y
        //
        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
        os << "uint" << lanes / 2;
      } else {
272
        fail = true;
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
      }
      break;
    case 32:
      if (lanes <= 4) {
        os << "float";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
        //
        // float8 is stored as ulonglong4
        //
        // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
        // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for float type with lanes > 4";
        os << "ulonglong" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
299
    }
300
301
302
303
    if (!fail && (t.is_scalar() || t.bits() == 16))
      return;
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
      return;
304
305
306
307
308
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t.is_bfloat16()) {
309
    enable_bf16_ = true;
310
311
312
313
314
315
316
317
    if (t.is_scalar()) {
      os << "bfloat16_t";
    } else if (lanes <= 8) {
      ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
      os << "uint" << lanes / 2;
    } else {
      fail = true;
    }
318
319
    if (!fail)
      return;
320
  } else if (t.is_float8()) {
321
322
323
    enable_fp8_ = true;
    os << GetFP8Type(t);
    return;
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
  } else if (t.is_float6()) {
    enable_fp6_ = true;
    if (t.lanes() <= 4) {
      os << GetFP6Type(t);
    } else {
      fail = true;
    }
    return;
  } else if (t.is_float4()) {
    enable_fp4_ = true;
    if (t.lanes() <= 4) {
      os << GetFP4Type(t);
    } else {
      fail = true;
    }
    return;
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      os << "u";
    }
    switch (t.bits()) {
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
    case 1: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int8_t";
        return;
      } else if (t.lanes() == 16) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 32) {
        os << "int";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
371
      }
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
    }
    case 4: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 4) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 8) {
        // directly 8 4-bit int in integer.
        os << "int";
        return;
      } else if (t.lanes() == 16) {
        os << "int2";
        return;
      } else if (t.lanes() == 32) {
        os << "int4";
        return;
      } else if (t.lanes() == 64) {
        os << "int8";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
395
      }
396
397
398
399
    }
    case 8: {
      if (t.lanes() == 4) {
        // directly 4 8 bit int in integer.
400
        enable_int8_ = true;
401
402
403
404
405
406
407

        // We use int for int8x4 instead of char4 because using char4 is
        // likely to produce extra instructions to pack four int8 elements
        // into 32-bit data.
        os << "int";
        return;
      } else if (t.lanes() == 8) {
408
        enable_int8_ = true;
409
410
411
        os << "int2";
        return;
      } else if (t.lanes() == 16) {
412
        enable_int8_ = true;
413
414
415
416
        os << "int4";
        return;
      } else if (!t.is_uint() && t.is_scalar()) {
        os << "signed char";
417
        break;
418
419
      } else {
        os << "char";
420
421
        break;
      }
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
    }
    case 16: {
      if (t.is_scalar()) {
        os << "short";
      } else if (t.lanes() <= 4) {
        os << "short" << lanes;
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int16 vector elements.
        //
        // short4 is stored as int2
        //
        // s4.x is emitted as *(short2*)(&(i2.x)).x
        // s4.y is emitted as *(short2*)(&(i2.x)).y
        // s4.z is emitted as *(short2*)(&(i2.y)).x
        // s4.w is emitted as *(short2*)(&(i2.y)).y
        //
        ICHECK_EQ(t.lanes() % 2, 0)
            << "only support even lane for shorT type with lanes > 4";
        os << "int" << t.lanes() / 2;
      } else {
        fail = true;
      }
      if (!fail) {
445
446
        return;
      }
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
      break;
    }
    case 32: {
      if (t.is_scalar()) {
        os << "int";
      } else if (t.lanes() <= 4) {
        os << "int" << t.lanes();
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
        //
        // int8 is stored as longlong4
        //
        // i8.v1 is emitted as *(int2*)(&(l4.x)).x
        // i8.v2 is emitted as *(int2*)(&(l4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for int32 type with lanes > 4";
        os << "longlong" << lanes / 2;
      } else {
466
        fail = true;
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 64: {
      if (t.is_scalar()) {
        os << "int64_t";
      } else if (t.lanes() == 2) {
        os << "longlong2";
      } else if (t.lanes() == 3) {
        os << "longlong3";
      } else if (t.lanes() == 4) {
        os << "longlong4";
      }
      return;
    }
    default:
      fail = true;
      break;
488
489
490
491
492
493
494
495
496
497
498
499
    }
    if (!fail && lanes == 1) {
      return;
    }
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

500
501
502
void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t,
                                           PrimExpr lhs, PrimExpr rhs,
                                           std::ostream &os) { // NOLINT(*)
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  int ssa_scope = BeginScope();
  {
    // Unpack into individual ops.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  EndScope(ssa_scope);
  os << sret;
}

536
537
538
void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
                                           int i,
                                           std::ostream &os) { // NOLINT(*)
539
540
541
542
543
544
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
545
546
547
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
548
549
550
551
552
553
554
555
556
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
    }
  } else if (t.is_float16()) {
557
558
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
559
  } else if (t.is_bfloat16()) {
560
561
    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
580
581
    os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
       << ")))->" << access[i % 2];
582
583
584
585
586
  } else {
    os << vec << "." << access[i];
  }
}

587
588
void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
                                            int i, const std::string &value) {
589
590
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
591
592
593
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
594
595
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
596
597
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
598
599
600
601
602
603
604
605
606
607
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      stream << ac << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
      }
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
    }
  } else if (t.is_float16()) {
608
609
    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
610
  } else if (t.is_bfloat16()) {
611
612
    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
631
632
    stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << " = " << value << ";\n";
633
634
635
636
637
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

638
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
639
640
  auto args = op->args;
  const std::string &sync = args[0].as<StringImmNode>()->value;
641
642
643
644
  if (sync == "warp") {
    // DO nothing.
  } else if (sync == "shared" || sync == "shared.dyn") {
    this->PrintIndent();
645
646
647
648
649
650
651
652
653
654
655
656
657
658
    if (args.size() == 1) {
      this->stream << "__syncthreads();\n";
    } else if (args.size() == 2) {
      auto barrier_id = args[1].as<IntImmNode>()->value;
      this->stream << "tl::__sync_thread_partial<" << barrier_id << ">();\n";
    } else if (args.size() == 3) {
      auto barrier_id = args[1].as<IntImmNode>()->value;
      auto thread_count = args[2].as<IntImmNode>()->value;
      this->stream << "tl::__sync_thread_partial<" << barrier_id << ", "
                   << thread_count << ">();\n";
    } else {
      LOG(FATAL) << "Invalid number of arguments for storage sync: "
                 << args.size();
    }
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
  } else if (sync == "global") {
    if (!need_global_barrier_) {
      need_global_barrier_ = true;
    }
    // global synchronizer
    std::string is_load = PrintExpr(op->args[1]);
    std::string num_blocks = PrintExpr(op->args[2]);
    this->PrintIndent();
    // In theory only threadfence is needed
    // but we observed problems with only threadfence
    this->stream << "__threadfence_system();\n";
    this->PrintIndent();
    this->stream << "if (" << is_load << ") {\n";
    int wb = this->BeginScope();
    this->PrintIndent();
    this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
    this->PrintIndent();
    std::string ptr = name_supply_->FreshName("pf");
    this->stream << "volatile unsigned* " << ptr << " = &"
                 << vid_global_barrier_state_ << ";\n";
    this->PrintIndent();
    this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
    this->PrintIndent();
    this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_
                 << ");\n";
    this->EndScope(wb);
    this->PrintIndent();
    this->stream << "}\n";
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
689
690
691
  }
}

692
693
694
695
696
void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
                                            std::ostream &os) { // NOLINT(*)
  ICHECK_NE(scope, "global")
      << "Cannot allocate global memory when targeting CUDA. You must pass "
         "all global arrays as input instead";
697
  if (scope == "shared" || scope == "shared.barrier") {
698
699
700
701
702
703
    os << "__shared__ ";
  } else if (scope == "shared.dyn") {
    os << "extern __shared__ __align__(1024) ";
  }
}

704
705
706
707
std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from,
                                            DataType target) {
  if (from == target)
    return value;
708
709
710
711
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
712
713
  if (from.is_float16() && (target.is_int() || target.is_uint()) &&
      target.bits() == 8) {
714
715
716
717
718
719
720
721
722
723
    os << "(";
    if (target.is_uint()) {
      os << "u";
    }
    os << "int)";
  }
  os << value << ")";
  return os.str();
}

724
void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
725
726
727
728
729
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
730
731
  if (from_ty.is_scalar())
    return CodeGenC::VisitExpr_(op, os);
732
733
734
735
736
737
738

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
  std::string src = SSAGetID(PrintExpr(op->value), from_ty);

  // Handle bfloat16 special cases with supported ops
  bool used_bf16_op = false;
  if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
    std::ostringstream func_name;
    if (from_ty.is_bfloat16())
      func_name << "bf16";
    else if (from_ty.is_float())
      func_name << "float";
    if (from_ty.lanes() > 1)
      func_name << from_ty.lanes();
    func_name << "2";
    if (target_ty.is_bfloat16())
      func_name << "bf16";
    else if (target_ty.is_float())
      func_name << "float";
    else if (target_ty == DataType::Int(16))
      func_name << "int16";
    if (target_ty.lanes() > 1)
      func_name << target_ty.lanes();

    auto fname = func_name.str();
    if (bf16_supported_ops_.count(fname)) {
      used_bf16_op = true;
      stream << "#ifdef ENABLE_BF16\n";
      PrintIndent();
      stream << "reinterpret_cast<";
      if (target_ty.is_bfloat16())
        stream << "__nv_bfloat16";
      else
        PrintType(target_ty.element_of(), stream);
      if (target_ty.lanes() > 1)
        stream << target_ty.lanes();
      stream << " &>(" << sret << ") = fastertransformer::" << fname
             << "(reinterpret_cast<";
      if (from_ty.is_bfloat16())
        stream << "__nv_bfloat16";
      else
        PrintType(from_ty.element_of(), stream);
      if (from_ty.lanes() > 1)
        stream << from_ty.lanes();
      stream << " const &>(" << src << "));\n";
      stream << "#else\n";
783
784
    }
  }
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799

  // Fallback: elementwise cast
  for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
    std::ostringstream val;
    val << "(";
    PrintType(target_ty.element_of(), val);
    val << ")(";
    PrintVecElemLoad(src, from_ty, i, val);
    val << ")";
    PrintVecElemStore(sret, target_ty, i, val.str());
  }

  if (used_bf16_op) {
    stream << "#endif\n";
  }
800
801
802
  os << sret;
}

803
804
805
806
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
                                          const Array<PrimExpr> &args,
                                          bool skip_first_arg,
                                          std::ostream &os) { // NOLINT(*)
807
  DataType ret_dtype = GetRuntimeDataType(ret_type);
808
  if (ret_dtype.is_fixed_length_vector()) {
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
    //
    // Emit an unsupported vector call
    //
    // v = intrin_f((float4*)A[0], (float4*)B[0])
    //
    // as
    //
    // float4 __ret;
    // {
    //   float4 __arg0 = ((float4*)A)[0];
    //   float4 __arg1 = ((float4*)B)[0];
    //   __ret.x = intrin_f(__arg0.x, __arg1.x);
    //   __ret.y = intrin_f(__arg0.y, __arg1.y);
    //   __ret.z = intrin_f(__arg0.z, __arg1.z);
    //   __ret.w = intrin_f(__arg0.w, __arg1.w);
    // }
    // v = __ret;
    //
    // Declare the result vector.
    std::string sret = name_supply_->FreshName("_");
    this->PrintIndent();
    this->PrintType(ret_dtype, stream);
    stream << ' ' << sret << ";\n";
    {
      // Load arguments.
      std::vector<std::string> sargs;
      size_t arg_begin = static_cast<size_t>(skip_first_arg);
      for (size_t i = arg_begin; i < args.size(); ++i) {
        std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
        sargs.push_back(std::move(val));
      }

      // Emit a scalar call for each lane.
      for (int i = 0; i < ret_dtype.lanes(); ++i) {
        std::ostringstream scall;
        scall << global_symbol << "(";
        for (size_t j = 0; j < sargs.size(); ++j) {
846
847
          if (j > 0)
            scall << ", ";
848
849
850
851
852
853
854
855
          PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
        }
        scall << ")";
        PrintVecElemStore(sret, ret_dtype, i, scall.str());
      }
    }
    os << sret;
  } else {
856
857
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
                              os);
858
859
860
861
  }
}

// Print a reference expression to a buffer.
862
863
864
865
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
                                              const BufferNode *buffer,
                                              PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
  std::ostringstream os;
  std::string vid = GetVarID(buffer_var);
  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }
  // bool is_vol = IsVolatile(buffer_var);
  // always false for tl cutlass backend.
  bool is_vol = false;

  auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
    std::ostringstream ptr_os;
    ptr_os << "(";
    if (is_vol) {
      ptr_os << "volatile ";
    }
    if (!scope.empty() && IsScopePartOfType()) {
      PrintStorageScope(scope, ptr_os);
    }
    PrintType(pointed_to, ptr_os);
    ptr_os << "*)";
    return ptr_os.str();
  };

  DataType buffer_element_dtype = buffer->dtype;

  std::string buffer_str = vid;
  if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
    std::stringstream temp;
    temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
    buffer_str = temp.str();
  }
898
899
900
901
902
903
904
  if (scope.empty()) {
    scope = GetPtrStorageScope(buffer->data);
  }
  if (scope == "local.var") {
    os << vid;
    return os.str();
  }
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
  std::string index_str = PrintExpr(index);
  if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
    // This is a special case, because CodegenCUDA::PrintType()
    // returns "int" for bool and for 4-bit integers. In most cases,
    // we divide by the number of lanes to determine the index.
    // However, the backing type for scalar int4 and scalar bool is
    // int32.  Therefore, we need to divide by the ratio of their
    // sizes in that case.
    int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();

    os << "*("
       << "(" << ptr_cast(t) << vid << ")"
       << " + " << index_str << " / " << div_factor << ")";
  } else if (t == buffer_element_dtype) {
    os << buffer_str << "[" << index_str << "]";
  } else {
    os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
  }

  return os.str();
}

/**
 * @brief Emit CUDA/TileLang-specific code for a call expression.
 *
 * This visitor handles CallNode intrinsics and builtins that require emitting
 * CUDA/TL-specific code (inline PTX/ASM sequences, TileLang runtime
 * calls, WMMA/TMA helpers, barriers, cp.async primitives, index-map based
 * stores, reinterpret/packing helpers, and various mma/ldmatrix patterns). The
 * function writes the generated code to the provided output stream and falls
 * back to the C codegen for unrecognized calls.
 *
 * The method recognizes and emits code for (non-exhaustive): cp.async and its
 * commit/wait variants, tma_load/store and im2col variants, PTX
 * ldmatrix/stmatrix helpers, mbarrier APIs, cooperative grid sync, WMMA/legacy
 * MMA intrinsics (fill/load/store/mma/bmma/ptx_mma/ptx_mma_sp), low-level PTX
 * asm helpers (ldg32, cp_async bulk/init/arrive/wait barriers), reinterpret
 * paths for special small-float encodings (e.g., float4 e2m1fn), tl::tl_gemm
 * and related external calls, and other TL runtime calls.
 *
 * Side effects:
 * - Emits to `os` and the internal codegen output stream.
 * - May set internal feature flags (e.g., need_cooperative_groups_,
 * need_mma_h_, need_cast_smem_ptr_to_int_, enable_sparse_gemm_).
 * - May open/close SSA scopes and mutate internal variable mappings.
 * - May call LOG(FATAL) / CHECK / ICHECK on invalid or unsupported argument
 *   patterns.
 *
 * @param op The call node to generate code for; the function inspects op->op
 *           and op->args to determine the appropriate emission.
 * @param os  Output stream to receive expression-level output when the caller
 *            expects an expression result (some paths write directly to the
 *            member stream instead).
 */
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
960
961
  auto print_extern_call_stmt = [&](std::string name, size_t start = 0,
                                    size_t end = 0) {
962
963
964
965
    // Cache context into a private ss, otherwise the let node may generate
    // within the function call arguments.
    std::ostringstream ss;

966
967
    for (size_t i = start; i < op->args.size() - end; i++) {
      if (i > start)
968
969
        ss << ", ";
      ss << this->PrintExpr(op->args[i]);
970
    }
971
972
973
974

    this->PrintIndent();
    this->stream << name << "(";
    this->stream << ss.str();
975
976
    this->stream << ");\n";
  };
977
978
979
980
981
982
983
984
985
986
987
988
989
  auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
    std::ostringstream ss;
    if (barrier_id.as<IntImmNode>()) {
      // incase the barrier_id is an integer, we need to print the barrier_id as
      // an integer
      ss << mbarrier_name_ << "[" << barrier_id << "]";
    } else {
      // otherwise may be a T.get_mbarrier() call or BufferLoad Node
      // we need to print the barrier_id as a string
      ss << this->PrintExpr(barrier_id);
    }
    return ss.str();
  };
990
991
992
993
994
995
  if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
996
997
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
998
999
    if (op->args.size() == 5) {
      this->PrintIndent();
1000
1001
      this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
                   << dst_offset << ", " << src << "+" << src_offset << ");\n";
1002
1003
1004
    } else {
      std::string condition = this->PrintExpr(op->args[5]);
      this->PrintIndent();
1005
1006
1007
      this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
                   << "+" << dst_offset << ", " << src << "+" << src_offset
                   << ", " << condition << ");\n";
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    print_extern_call_stmt("tl::cp_async_commit");
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
    std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
    print_extern_call_stmt(func_name, 1);
  } else if (op->op.same_as(builtin::create_barriers())) {
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
1018
1019
    auto mbarrier_storage_name = mbarrier_name_ + "_mem";
    this->stream << "__shared__ uint64_t " << mbarrier_storage_name << "["
1020
                 << barrier_count << "];\n";
1021
1022
1023
    this->PrintIndent();
    this->stream << "auto " << mbarrier_name_ << " = reinterpret_cast<"
                 << mbarrier_dtype_ << "*>(" << mbarrier_storage_name << ");\n";
1024
  } else if (op->op.same_as(tl::get_mbarrier())) {
1025
    ICHECK_EQ(op->args.size(), 1);
1026
    std::string barrier_id = this->PrintExpr(op->args[0]);
1027
    os << mbarrier_name_ + "[" + barrier_id + "]";
1028
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
    if (op->args.size() == 1) {
      this->PrintIndent();
      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
      this->stream << mbarrier_obj << ".arrive();\n";
    } else if (op->args.size() == 3) {
      this->PrintIndent();
      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
      auto cta_id = this->PrintExpr(op->args[1]);
      auto pred = this->PrintExpr(op->args[2]);
      this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred
                   << ");\n";
    } else {
      LOG(FATAL) << "Invalid parameter  for tl::arrive_barrier "
                 << op->args.size();
    }
1044
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
1045
1046
1047
1048
1049
    ICHECK_EQ(op->args.size(), 2);
    this->PrintIndent();
    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
    auto arrive_count = this->PrintExpr(op->args[1]);
    this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n";
1050
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
    if (op->args.size() == 2) {
      this->PrintIndent();
      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
      auto transaction_bytes = this->PrintExpr(op->args[1]);
      this->stream << mbarrier_obj << ".arrive_and_expect_tx("
                   << transaction_bytes << ");\n";
    } else if (op->args.size() == 4) {
      this->PrintIndent();
      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
      auto transaction_bytes = this->PrintExpr(op->args[1]);
      auto cta_id = this->PrintExpr(op->args[2]);
      auto pred = this->PrintExpr(op->args[3]);
      this->stream << mbarrier_obj << ".arrive_and_expect_tx("
                   << transaction_bytes << ", " << cta_id << ", " << pred
                   << ");\n";
    } else {
      LOG(FATAL) << "Invalid parameter  for tl::arrive_barrier_expect_tx "
                 << op->args.size();
    }
1070
1071
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
1072
  } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
1073
1074
1075
1076
1077
1078
    ICHECK_EQ(op->args.size(), 2);
    this->PrintIndent();
    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
    auto transaction_bytes = this->PrintExpr(op->args[1]);
    this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes
                 << ");\n";
1079
  } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
1080
1081
1082
1083
1084
    ICHECK_EQ(op->args.size(), 2);
    this->PrintIndent();
    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
    auto phase = this->PrintExpr(op->args[1]);
    this->stream << mbarrier_obj << ".wait(" << phase << ");\n";
1085
1086
  } else if (op->op.same_as(tl::no_set_max_nreg())) {
    return;
1087
  } else if (op->op.same_as(tl::tma_load())) {
1088
    std::ostringstream ss;
1089
    ICHECK_GE(op->args.size(), 2);
1090
1091
1092
    auto eviction_policy =
        this->eviction_policy_names_
            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
1093
1094
1095
1096
1097
1098
    // Simplify the code by using the default eviction policy
    if (eviction_policy != "EVICT_NORMAL") {
      ss << "tl::tma_load<tl::CacheHintSm90::" << eviction_policy << ">(";
    } else {
      ss << "tl::tma_load(";
    }
1099
    auto desc = op->args[0];
1100
    ss << this->PrintExpr(desc) << ", ";
1101
    ss << print_mbarrier_obj(op->args[1]) << ", ";
1102
    for (size_t i = 2; i < op->args.size() - 1; i++) {
1103
      if (i > 2)
1104
1105
        ss << ", ";
      ss << this->PrintExpr(op->args[i]);
1106
    }
1107
1108
1109
    ss << ");\n";
    this->PrintIndent();
    this->stream << ss.str();
1110
  } else if (op->op.same_as(tl::tma_load_im2col())) {
1111
    std::stringstream ss;
1112
1113
1114
1115
1116
1117
1118
1119
    auto eviction_policy =
        this->eviction_policy_names_
            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
    if (eviction_policy != "EVICT_NORMAL") {
      ss << "tl::tma_load_im2col<tl::CacheHintSm90::" << eviction_policy << ">";
    } else {
      ss << "tl::tma_load_im2col";
    }
1120
    print_extern_call_stmt(ss.str(), 0, 1);
1121
  } else if (op->op.same_as(tl::tma_store())) {
1122
    std::stringstream ss;
1123
1124
1125
1126
1127
1128
1129
1130
    auto eviction_policy =
        this->eviction_policy_names_
            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
    if (eviction_policy != "EVICT_NORMAL") {
      ss << "tl::tma_store<tl::CacheHintSm90::" << eviction_policy << ">";
    } else {
      ss << "tl::tma_store";
    }
1131
    print_extern_call_stmt(ss.str(), 0, 1);
1132
  } else if (op->op.same_as(tl::ptx_ldmatrix())) {
1133
1134
1135
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
1136
1137
    if (trans == 1)
      func_name += "_trans";
1138
    print_extern_call_stmt(func_name, 2);
1139
  } else if (op->op.same_as(tl::ptx_stmatrix())) {
1140
1141
1142
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
1143
1144
    if (trans == 1)
      func_name += "_trans";
1145
    print_extern_call_stmt(func_name, 2);
1146
  } else if (op->op.same_as(tl::fence_proxy_async())) {
1147
    print_extern_call_stmt("tl::fence_proxy_async");
1148
  } else if (op->op.same_as(tl::tma_store_arrive())) {
1149
    print_extern_call_stmt("tl::tma_store_arrive");
1150
  } else if (op->op.same_as(tl::tma_store_wait())) {
1151
    print_extern_call_stmt("tl::tma_store_wait<0>");
1152
  } else if (op->op.same_as(tl::set_max_nreg())) {
1153
1154
1155
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
1156
1157
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
1158
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
1159
  } else if (op->op.same_as(tl::wait_wgmma())) {
1160
1161
1162
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
1163
  } else if (op->op.same_as(tl::pack_b16())) {
1164
1165
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
1166
1167
1168
  } else if (op->op.same_as(tl::sync_grid())) {
    this->need_cooperative_groups_ = true;
    this->PrintIndent();
1169
    this->stream << "cooperative_groups::this_grid().sync();\n";
1170
1171
1172
  } else if (op->op.same_as(tl::loop_break())) {
    this->PrintIndent();
    this->stream << "break;\n";
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
1206
    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::ptx_mma())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp64, ...
    // arg 4: B precision: fp16, fp64, ...
    // arg 5: C precision: fp32, fp64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index
    // arg 12: saturate
    // arg 13: (optional) 1-bit operator (xor or and)
    ICHECK(op->args.size() == 13U || op->args.size() == 14U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    bool saturate = Downcast<Bool>(op->args[12])->value;
1261
1262
1263
1264
1265
    std::string bit_op =
        op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
    std::string asm_code = PrintMMAAssembly(
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
        b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302

    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp32, ...
    // arg 4: B precision: fp16, fp32, ...
    // arg 5: C precision: fp16, fp32, ...
    // arg 6: A multiplicand pointer
    // arg 7: A multiplicand index
    // arg 8: B multiplicand pointer
    // arg 9: B multiplicand index
    // arg 10: C accumulator pointer
    // arg 11: C accumulator index
    // arg 12: metadata
    // arg 13: metadata index
    // arg 14: sparse_selector
    // arg 15: saturate
    ICHECK_EQ(op->args.size(), 16U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_offset = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_offset = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_offset = this->PrintExpr(op->args[11]);
    std::string metadata = this->PrintExpr(op->args[12]);
    std::string metadata_offset = this->PrintExpr(op->args[13]);
    std::string sparse_selector = this->PrintExpr(op->args[14]);
    bool saturate = Downcast<Bool>(op->args[15])->value;
    std::string asm_code = PrintMMAAssembly(
1303
1304
1305
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
        b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
        sparse_selector, "", true, saturate);
1306
1307
1308
1309
1310
1311
1312
1313
    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
    // arg 0: whether the matrix is loaded in column major format or not.
    // arg 1: number of matrices to load.
    // arg 2: The data type in the matrix, .b16 is the only accepted data type.
    // arg 3: pointer to local buffer.
    // arg 4: The offset of the element to store in the local buffer.
    // arg 5: pointer to the shared memory buffer to load.
1314
1315
    // arg 6: The offset of the start element of the row to load in shared
    // memory.
1316
1317
1318
1319
1320
1321
1322
1323
    ICHECK_EQ(op->args.size(), 7U);
    bool trans = Downcast<Bool>(op->args[0])->value;
    int num = Downcast<Integer>(op->args[1])->value;
    std::string type = Downcast<StringImm>(op->args[2])->value;
    std::string local_ptr = this->PrintExpr(op->args[3]);
    std::string local_elem_offset = this->PrintExpr(op->args[4]);
    std::string smem_ptr = this->PrintExpr(op->args[5]);
    if (trans && op->dtype.bits() == 8) {
1324
1325
      // Since ldmatrix assumes that a matrix element is 16 bit, it cannot
      // properly transpose an int8 matrix.
1326
1327
1328
1329
      std::string smem_stride = this->PrintExpr(op->args[6]);
      ICHECK(num == 4);
      os << "for (int i = 0; i < 16; ++i) {\n";
      os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
1330
1331
1332
1333
         << "[(i % 8) / 4 * " + smem_stride +
                " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
                "+ (i % 4) * " + smem_stride +
                " + threadIdx.x / 4 +  (i / 8) * 8];\n";
1334
1335
1336
1337
      os << "}\n";
    } else {
      std::string smem_elem_offset = this->PrintExpr(op->args[6]);
      need_cast_smem_ptr_to_int_ = true;
1338
1339
1340
      this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
                                              local_elem_offset, smem_ptr,
                                              smem_elem_offset);
1341
1342
1343
1344
1345
1346
1347
1348
1349
    }
  } else if (op->op.same_as(builtin::mma_store())) {
    int m = Downcast<Integer>(op->args[0])->value;
    int n = Downcast<Integer>(op->args[1])->value;
    std::string dst = this->PrintExpr(op->args[2]);
    std::string src = this->PrintExpr(op->args[3]);
    std::string src_offset = this->PrintExpr(op->args[4]);
    PrimExpr stride = op->args[5];

1350
1351
    ICHECK(m == 16 && n == 16)
        << "Only m == 16 && n == 16 case supported for now";
1352

1353
1354
1355
1356
1357
    // Each thread in a warp holds a certain number of elements of an MMA
    // output. For example, if we compute a 16x16 tile using MMA, each thread
    // holds 8 elements in its registers. So conceptually, a warp memory is
    // organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of
    // memory is specified by the index map below.
1358

1359
1360
    // To store the 32x8 output back to a 16x16 tile in shared or global memory,
    // we invert this map to determine the output location for each 8 element.
1361

1362
1363
    const auto index_map_func = ffi::Function::GetGlobal(
        "tir.index_map.shared_16x16_to_mma_32x8_layout");
1364

1365
1366
1367
    IndexMap index_map;
    if (!index_map_func) {
      Var i, j;
1368

1369
      // The index map is defined as follows:
1370
1371
1372
1373
1374
      index_map = IndexMap(
          {i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2),
                   4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)});
    } else {
      index_map = IndexMap::FromFunc(2, *index_map_func);
1375
1376
1377
1378
1379
1380
1381
    }

    arith::Analyzer analyzer;
    auto inverse_index_map =
        index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer);
    auto indices_16x16 = inverse_index_map->final_indices;

1382
1383
1384
    // "//" and "%" in the index map are translated to FloorDiv/Mod, but the
    // plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before
    // they reach codegen, so manually replace them to the plain ones here.
1385
    class LowerFloorDivMod : public ExprMutator {
1386
1387
    public:
      PrimExpr VisitExpr_(const FloorDivNode *op) {
1388
1389
        return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
1390
      PrimExpr VisitExpr_(const FloorModNode *op) {
1391
1392
1393
1394
        return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
    };

1395
1396
    auto dst_ind =
        LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
1397
1398
1399
1400
1401
1402
1403
1404
1405

    var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
    var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
    if (op->dtype.bits() == 16) {
      os << "for (int local_id = 0; local_id < 8; local_id+=2) {\n";
      os << "*((uint *)&" << dst << "[" + this->PrintExpr(dst_ind) + "])"
         << " = "
         << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n";
      os << "}\n";
1406
    } else {
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
      os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
      os << dst << "[" + this->PrintExpr(dst_ind) + "]"
         << " = " << src << "[" << src_offset << " + local_id];\n";
      os << "}\n";
    }

  } else if (op->op.same_as(builtin::mma_fill())) {
    std::string num_elem = this->PrintExpr(op->args[0]);
    std::string dst = this->PrintExpr(op->args[1]);
    std::string dst_offset = this->PrintExpr(op->args[2]);

    os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
    os << dst << "[" << dst_offset << " + i] = 0.0;";
    os << "}\n";
  } else if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    need_cast_smem_ptr_to_int_ = true;
1428
1429
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
1430
    if (op->args.size() == 5) {
1431
1432
      this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
                                           size);
1433
    } else {
1434
1435
      this->stream << PrintPredicatedCpAsyncAssembly(
          dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5]));
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
    }
  } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
    need_cast_smem_ptr_to_int_ = true;
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    int barrier_id = Downcast<IntImm>(op->args[5])->value;
    CHECK(barrier_id < barrier_count_);
1446
1447
1448
1449
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size,
                                        barrier);
1450
1451
1452
1453
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
1454
1455
    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
                 << ";\");\n\n";
1456
1457
1458
1459
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1460
1461
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1462
1463
1464
1465
1466
1467
    std::string thread_count = this->PrintExpr(op->args[1]);
    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1468
1469
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1470
1471
1472
1473
1474
    this->stream << PrintArriveBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1475
1476
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1477
1478
1479
1480
1481
1482
    std::string byte_count = this->PrintExpr(op->args[1]);
    this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
  } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1483
1484
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
    this->stream << PrintWaitBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_ldg32())) {
    /*
    asm volatile (
        "{.reg .pred p;\n"
        " setp.ne.b32 p, %2, 0;\n"
        // " @p ld.global.nc.f32 %0, [%1];}\n"t
        " @p ld.global.nc.L2::128B.f32 %0, [%1];}\n"
        : "=f"(reg)
        : "l"(addr), "r"((int)guard)
    );
    */

    // get local
    std::string reg = this->PrintExpr(op->args[0]);
    // get guard
    std::string guard = this->PrintExpr(op->args[1]);
1502
    const BufferLoadNode *addr_buffer = op->args[2].as<BufferLoadNode>();
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
    std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
    std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
    std::string local_addr = this->PrintExpr(op->args[3]);
    this->stream << "asm volatile (\n";
    this->stream << "\"{.reg .pred p;\\n\"\n";
    this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n";
    this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
    this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
    // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
    stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
           << ")\n";
1514
1515
    stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
           << ")), \"r\"((int)" << guard << ")\n";
1516
    stream << ");\n";
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
  } else if (op->op.same_as(builtin::reinterpret())) {
    DataType tgt_dtype = op->dtype;
    DataType src_dtype = op->args[0]->dtype;
    PrimExpr value = op->args[0];

    // Handle float4_e2m1fn reinterpret
    if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) {
      return CodeGenC::VisitExpr_(op, os);
    }
    if (src_dtype == tgt_dtype || tgt_dtype.lanes() * tgt_dtype.bits() ==
                                      src_dtype.lanes() * src_dtype.bits()) {
      return CodeGenC::VisitExpr_(op, os);
    }
    CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes())
        << "E2M1 float4 reinterpret expects source and target to have the same "
           "number of lanes. "
        << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
    CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes())
        << "E2M1 float4 reinterpret expects source and target to have the same "
           "number of bytes. "
        << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;

    int lanes = tgt_dtype.lanes();

    int ssa_scope = BeginScope();
    if (lanes == 1) {
      // The case of lane=1 is same as the normal reinterpret,
      // except that we allow the src and dst dtype to have different number of
      // bits.
      std::string rhs = SSAGetID(PrintExpr(value), src_dtype);
      os << "(*(";
      this->PrintType(tgt_dtype, os);
      os << " *)(&(" << rhs << ")))";
    } else if (lanes == 2) {
      if (tgt_dtype.is_float4_e2m1fn()) {
        // We view the source as an uint16, and then extract bits of two fp4
        // numbers, and finally reinterpret the result as fp4x2.
        value =
            tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value});
        tir::Var temp_var("temp_var", DataType::UInt(16));
        value =
            tir::Let(temp_var, value,
                     tir::Cast(DataType::UInt(8),
                               (temp_var & IntImm(DataType::UInt(16), 0xF)) |
                                   ((temp_var >> 4) &
                                    IntImm(DataType::UInt(16), 0xF0))));
      } else {
        value = tir::Cast(
            DataType::UInt(16),
            tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value}));
        tir::Var temp_var("temp_var", DataType::UInt(16));
        value =
            tir::Let(temp_var, value,
                     (temp_var & IntImm(DataType::UInt(16), 0xF)) |
                         ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4));
      }
      os << PrintExpr(
          tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
    } else if (lanes == 4) {
      if (tgt_dtype.is_float4_e2m1fn()) {
        // We view the source as an uint32, and then extract bits of four fp4
        // numbers, and finally reinterpret the result as fp4x4.
        value =
            tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value});
        tir::Var temp_var("temp_var", DataType::UInt(32));
        value = tir::Let(
            temp_var, value,
            tir::Cast(
                DataType::UInt(16),
                (temp_var & IntImm(DataType::UInt(32), 0xF)) |
                    ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) |
                    ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) |
                    ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000))));
      } else {
        value = tir::Cast(DataType::UInt(32),
                          tir::Call(DataType::UInt(16),
                                    tir::builtin::reinterpret(), {value}));
        tir::Var temp_var("temp_var", DataType::UInt(32));
        value = tir::Let(
            temp_var, value,
            (temp_var & IntImm(DataType::UInt(32), 0xF)) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12));
      }
      os << PrintExpr(
          tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
    } else {
      LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: "
                 << lanes;
    }
    EndScope(ssa_scope);
  } else if (op->op.same_as(builtin::thread_return())) {
    os << "return";
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
  } else if (op->op.same_as(tl::tl_gemm())) {
    ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
                                    "A_ptr, B_ptr, C_ptr>, but got "
                                 << op->args.size();
    auto op_instance = Downcast<StringImm>(op->args[0]);
    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
                          op->args, true, os);
  } else if (op->op.same_as(tl::tl_gemm_sp())) {
    ICHECK(op->args.size() == 5)
        << "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
           "E_ptr>, but got "
        << op->args.size();
    auto op_instance = Downcast<StringImm>(op->args[0]);
    enable_sparse_gemm_ = true;
    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
                          op->args, true, os);
1627
1628
  } else if (op->op.same_as(tl::tl_shuffle_elect())) {
    os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
1629
1630
1631
1632
1633
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

1634
void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
1635
  if (op->attr_key == tir::attr::fragment_shape) {
1636
1637
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *shape_str = op->value.as<StringImmNode>();
1638
1639
    fragment_shapes[buffer] = shape_str->value;
  } else if (op->attr_key == tir::attr::fragment_layout) {
1640
1641
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *layout_str = op->value.as<StringImmNode>();
1642
1643
    fragment_layouts[buffer] = layout_str->value;
  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
1644
1645
1646
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
1647
1648
1649
1650
1651
1652
1653
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
    return;
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
1654
1655
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
1656
    auto wait_cnt = wait_attrs.second;
1657
1658
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
1659
1660
1661
1662
1663
1664
1665
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    this->VisitStmt(inner->body);
    return;
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
1666
    const StringImmNode *pattern = op->value.as<StringImmNode>();
1667
1668
1669
1670
1671
1672
1673
1674
    ICHECK(pattern);
    this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
    this->VisitStmt(op->body);
    return;
  }
  CodeGenC::VisitStmt_(op);
}

1675
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
1676
1677
1678
1679
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
1680
  const VarNode *buffer = op->buffer_var.as<VarNode>();
1681
1682
  if (scope.find("wmma.") == 0) {
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
1683
1684
1685
1686
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) ||
             op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
             op->dtype == DataType::Int(1) || op->dtype == DataType::BFloat(16))
1687
1688
1689
          << "Matrix_a and matrix_b only support half or char or unsigned char "
          << "or uint4 or int4 or int1 type for now";
    } else {
1690
1691
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32))
1692
1693
1694
          << "Accumulator only support half, float and int type for now";
    }
    PrintWmmaScope(scope, op->dtype, buffer, stream);
1695
  } else {
1696
1697
1698
1699
1700
1701
1702
1703
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }

  if (scope == "shared.dyn") {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
1704
    ICHECK_GT(constant_size, 0)
1705
1706
        << "Can only handle constant size stack allocation for now, but get "
        << constant_size << " for " << op->buffer_var->name_hint;
1707
1708
1709
1710
1711
1712
1713
1714
    if (scope.find("wmma.") == 0) {
      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    }
    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
        scope == "shared") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
1715
1716
    if (scope == "shared") {
      stream << ' ' << vid << '[' << constant_size << "];\n";
1717
1718
1719
1720
1721
1722
    } else if (scope == "shared.barrier") {
      auto v_id_mem = vid + "_mem";
      stream << ' ' << v_id_mem << "[" << constant_size << "];\n";
      PrintIndent();
      stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_
             << "*>(" << v_id_mem << ");\n";
1723
1724
1725
1726
1727
1728
1729
1730
    } else if (scope == "local") {
      stream << ' ' << vid << '[' << constant_size << "];\n";
    } else if (scope == "local.var") {
      stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0))
             << ";\n";
    } else {
      ICHECK(false) << "Unsupported scope: " << scope;
    }
1731
1732
1733
1734
1735
1736
  }

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
/*!
 * \brief Emit an Evaluate statement.
 *
 * Evaluating a constant integer produces no code. A
 * tvm_global_barrier_kinit() call expands into a `__shared__ unsigned`
 * counter that thread 0 resets to zero. Every other statement is handled by
 * CodeGenC.
 */
void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
  if (is_const_int(op->value)) {
    return;
  }
  const CallNode *call = op->value.as<CallNode>();
  const bool is_barrier_init =
      call != nullptr && call->op.same_as(builtin::tvm_global_barrier_kinit());
  if (!is_barrier_init) {
    CodeGenC::VisitStmt_(op);
    return;
  }
  const auto &expect = vid_global_barrier_expect_;
  PrintIndent();
  stream << "__shared__ unsigned " << expect << ";\n";
  PrintIndent();
  stream << "if (threadIdx.x == 0) {\n";
  PrintIndent();
  stream << "  " << expect << " = 0;\n";
  PrintIndent();
  stream << "}\n";
}

1755
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
1756
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
1757
1758
  CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
                     << lanes << " lanes is not allowed.";
1759
1760
1761
1762
1763
1764
  os << "(make_";
  PrintType(op->dtype, os);
  os << "(";
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
1765
1766
    if (i != lanes - 1)
      os << ", ";
1767
1768
1769
1770
  }
  os << "))";
}

1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
                                     std::ostream &os) { // NOLINT(*)
  ICHECK_EQ(op->indices.size(), 1)
      << "Load from non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer load is not supported.";

  DataType value_dtype = op->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  DataType element_dtype = op->buffer->dtype;

  int lanes = op->dtype.lanes();
  // delcare type.
  if (value_dtype.lanes() == element_dtype.lanes()) {
    std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
    HandleVolatileLoads(ref, op, os);
  } else {
    bool can_vector_load = false;
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
      const RampNode *ramp = index.as<RampNode>();
      ICHECK(ramp);
      can_vector_load = true;
      // arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
      // The condition: {k * coeff + base} divisible by the alignment for any k
      // if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes()
      // == 0) {
      //   can_vector_load = true;
      // }
    }

    if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
      // A float4_e2m1fn element has 4 bits, which is an incomplete byte.
      // So we cannot vector load it.
      can_vector_load = false;
    }
    if (can_vector_load) {
      std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
      HandleVolatileLoads(ref, op, os);
    } else {
      std::ostringstream svalue_expr;
      std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
      std::string vid = GetVarID(buffer_var.get());
      DataType elem_type = op->dtype.element_of();
      for (int i = 0; i < lanes; ++i) {
        std::ostringstream value_temp;
        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
          value_temp << "((";
          if (buffer_var.get()->dtype.is_handle()) {
            auto it = alloc_storage_scope_.find(buffer_var.get());
            if (it != alloc_storage_scope_.end()) {
              PrintStorageScope(it->second, value_temp);
            }
          }
          PrintType(elem_type, value_temp);
          value_temp << "*)" << vid << ')';
        } else {
          value_temp << vid;
        }
        value_temp << '[';
        PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
        value_temp << ']';
        PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
      }
      os << svalue_expr.str();
    }
  }
}

1841
1842
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
                                     std::ostream &os) { // NOLINT(*)
1843
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
1844
1845
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
      lanes == 4) {
1846
    // make_int8x4
1847
    const int64_t *p = as_const_int(op->value);
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
    ICHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
    if (op->dtype.is_uint()) {
      os << "(uint)" << v;
    } else {
      os << "(int)" << v;
    }
    return;
  }

  if (op->dtype.is_float16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1865
1866
      if (i != 0)
        os << ", ";
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if (op->dtype.is_bfloat16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1879
1880
      if (i != 0)
        os << ", ";
1881
1882
1883
1884
1885
1886
      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

1887
1888
  if (op->dtype.is_float() && op->dtype.bits() == 32 &&
      op->dtype.lanes() == 8) {
1889
1890
1891
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
1892
1893
      if (i != 0)
        os << ", ";
1894
1895
1896
1897
1898
1899
1900
1901
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
1902
    const int64_t *p = as_const_int(op->value);
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
    ICHECK(p);
    int64_t v = *p & 0xF;

    if (lanes == 4) {
      v = (v << 12) | (v << 8) | (v << 4) | v;
      if (op->dtype.is_uint()) {
        os << "(uint16_t)" << v;
      } else {
        os << "(int16_t)" << v;
      }
    } else {
1914
1915
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
          (v << 4) | v;
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
        } else {
          os << "(int)" << v;
        }
      } else if (lanes == 16 || lanes == 32) {
        os << "make_";
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < lanes / 8; ++i) {
1927
1928
          if (i != 0)
            os << ", ";
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
            os << "(int)" << v;
          }
        }
        os << ')';
      } else {
        fail = true;
      }
    }

    if (!fail) {
      return;
    }
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
1951
1952
    if (i != 0)
      os << ", ";
1953
1954
1955
1956
1957
    os << v;
  }
  os << ')';
}

1958
1959
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangCUDA *p) { // NOLINT(*)
1960
1961
1962
1963
1964
1965
  // Type code is kBFloat
  if (op->dtype.is_bfloat16()) {
    os << "bfloat16_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
1966
1967
1968
1969
1970
1971
  // Type code is kFloat8_e5m2 or kE4M4Float
  if (op->dtype.is_float8() || op->dtype.is_float4()) {
    p->PrintType(op->dtype, os);
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
1972
1973
  // Type code is kFloat
  switch (op->dtype.bits()) {
1974
1975
1976
1977
1978
1979
  case 64:
  case 32: {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
1980
      }
1981
      temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
1982
      p->need_math_constants_h_ = true;
1983
1984
    } else if (std::isnan(op->value)) {
      temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
1985
      p->need_math_constants_h_ = true;
1986
1987
1988
1989
    } else {
      temp << std::scientific << op->value;
      if (op->dtype.bits() == 32)
        temp << 'f';
1990
    }
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
    p->MarkConst(temp.str());
    os << temp.str();
    break;
  }
  case 16: {
    os << "half_t" << '(';
    FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
    PrintConst(const_f32.get(), os, p);
    os << ')';
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
2004
2005
2006
  }
}

2007
2008
void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op,
                                     std::ostream &os) { // NOLINT(*)
2009
2010
2011
  PrintConst(op, os, this);
}

2012
2013
2014
void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t,
                                         const VarNode *variable,
                                         std::ostream &os) {
2015
2016
  std::stringstream type;
  PrintType(t, type);
2017
2018
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
  std::string shape_str = fragment_shapes.at(variable);
  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
    type.str(std::string());
    if (t.is_int()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::s4";
      } else if (t.bits() == 1) {
        type << "nvcuda::wmma::experimental::precision::b1";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    } else if (t.is_uint()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::u4";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    }
  }
  if (scope == "wmma.matrix_a") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
2041
2042
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", "
       << type.str() << ", nvcuda::wmma::" << layout_str << ">";
2043
2044
2045
  } else if (scope == "wmma.matrix_b") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
2046
2047
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", "
       << type.str() << ", nvcuda::wmma::" << layout_str << ">";
2048
  } else if (scope == "wmma.accumulator") {
2049
2050
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str
       << ", " << type.str() << ">";
2051
2052
2053
  }
}

2054
2055
int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope,
                                                 const VarNode *variable,
2056
                                                 int32_t size) {
2057
2058
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
2059
2060
2061
2062
2063
2064
2065
2066
  std::string shape_str = fragment_shapes.at(variable);
  std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
  if (dim.first * dim.second != 0)
    return size / dim.first / dim.second;
  else
    return 0;
}

2067
2068
2069
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value,
                                              const BufferLoadNode *op,
                                              std::ostream &os) {
2070
2071
2072
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
2073
2074
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
2075
2076
2077
2078
2079
2080
2081
2082
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

2083
2084
2085
/*!
 * \brief Append lane \p i of a vector being assembled from scalar loads.
 *
 * Called once per lane, in order, with the same \p os. The emitted forms:
 *  - 8-bit int vectors (other than 2 or 3 lanes): an OR of masked, shifted
 *    bytes `((0x000000ff << 8i) & (value << 8i))`;
 *  - float16 / bfloat16: `make_<type>(` followed by pack calls, each
 *    consuming two lanes;
 *  - everything else: `make_<type>(v0,v1,...)`.
 * Requires a multi-lane \p t.
 */
void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
                                               const std::string &value,
                                               std::ostream &os) {
  ICHECK_GT(t.lanes(), 1);
  const int lanes = t.lanes();
  const bool is_last = (i == lanes - 1);

  const bool packed_bytes = t.bits() == 8 && (t.is_int() || t.is_uint()) &&
                            lanes != 2 && lanes != 3;
  if (packed_bytes) {
    const int shift = i * 8;
    if (i != 0) {
      os << "|";
    }
    os << "((0x000000ff << " << shift << ") & (" << value << " << " << shift
       << "))";
    return;
  }

  // Every remaining form starts with a vector constructor on the first lane.
  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << '(';
  }

  // NOTE(review): the bfloat16 Broadcast path in this file emits
  // __pack_nv_bfloat162; confirm __pack_bfloat162 is also provided by the
  // device templates.
  const char *pack_fn = nullptr;
  if (t.is_float16()) {
    pack_fn = "__pack_half2";
  } else if (t.is_bfloat16()) {
    pack_fn = "__pack_bfloat162";
  }

  if (pack_fn == nullptr) {
    os << value << (is_last ? ")" : ",");
    return;
  }

  // Even lanes open a pack call; odd lanes close it.
  if (i % 2 == 0) {
    os << pack_fn << "(" << value;
  } else {
    os << "," << value << ")" << (is_last ? ")" : ",");
  }
}

2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
                                                 const PrimFunc &func,
                                                 std::ostream &os) {
  PrintFuncPrefix(os);
  CodeGenC::PrintType(func->ret_type, os);
  CodeGenC::PrintExtraAttrs(func, os);
  bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
  os << " " << function_name << "(";
  for (size_t i = 0; i < func->params.size(); ++i) {
    tir::Var v = func->params[i];
    std::string vid = AllocVarID(v.get());

    if (i > 0) {
      os << ", ";
    }

    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (ptr->storage_scope == "grid_constant") {
          os << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, os);
          os << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, os);
      }

      CodeGenC::PrintType(GetType(v), os);
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, os);
      }
    } else {
      CodeGenC::PrintType(GetType(v), os);
    }
    os << ' ' << vid;
  }
  os << ")";

  // Register handle data type
  // TODO(tvm-team): consider simply keep type info in the
  // type annotation(via a normalizing rewriting).
  for (const auto &param : func->params) {
    if (auto *ptr = param->type_annotation.as<PointerTypeNode>()) {
      if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
        RegisterHandleType(param.get(), prim->dtype);
      }
    }
  }
}

void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
                                      const PrimFunc &f) {
  // If the function has already been forward-declared, this is a
  // no-op.
  CodeGenC::DeclareFunction(gvar, f);
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix(stream);
  CodeGenC::PrintType(f->ret_type, stream);
2228
2229
  this->PrintExtraAttrs(f);

2230
2231
2232
2233
2234
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
2235
2236
    if (i != 0)
      stream << ", ";
2237
2238
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
2239
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
          stream << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      CodeGenC::PrintType(GetType(v), stream);
2254
2255
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, stream);
      }
    } else {
      CodeGenC::PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";
  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}

2277
2278
} // namespace codegen
} // namespace tvm