codegen_cuda.cc 73.1 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file target/codegen_cuda.cc
 * \brief CUDA source code generator for TileLang.
 */

#include "codegen_cuda.h"
#include <tvm/arith/analyzer.h>
7
#include <tvm/ffi/function.h>
8
#include <tvm/tir/index_map.h>
9
10
11
12
13
14
15
16
17
#include <tvm/tir/op.h>

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "../op/builtin.h"
#include "../op/bulk_copy.h"
18
#include "arith/pattern_match.h"
19
20
21
22
23
#include "target/source/ptx.h"

namespace tvm {
namespace codegen {

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/*!
 * \brief Build the name of the type used in generated CUDA code for an FP8
 *        DataType.
 *
 * The result has the form "fp8_e4<vec>_t" or "fp8_e5<vec>_t", where <vec> is
 * empty for scalars and "_2", "_4", "_8" or "_16" for vectors of that many
 * lanes.
 *
 * \param type The FP8 data type to translate.
 * \return The type name as a string.
 *
 * Raises a fatal error for any other lane count or for a type that is not one
 * of the e4m3 / e5m2 variants checked below.
 */
static std::string GetFP8Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "_2";
  } else if (lanes == 4) {
    vec = "_4";
  } else if (lanes == 8) {
    vec = "_8";
  } else if (lanes == 16) {
    vec = "_16";
  } else {
    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
                  "for FP8";
  }
  if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() ||
      type.is_float8_e4m3()) {
    stream << "fp8_e4" << vec << "_t";
  } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz()) {
    // The original condition tested is_float8_e5m2() twice; the duplicate was
    // redundant and has been dropped.
    stream << "fp8_e5" << vec << "_t";
  } else {
    LOG(FATAL) << "Unsupported FP8 type in CUDA codegen but got " << type;
  }
  return stream.str();
}

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/*!
 * \brief Build the CUDA type name for an FP6 DataType.
 *
 * The result is "__nv_fp6" + <vec> + <suffix>, where <vec> is empty for
 * scalars or "x2"/"x4"/"x8"/"x16" for vectors, and <suffix> is "_e2m3" or
 * "_e3m2" depending on the type code.
 *
 * \param type The FP6 data type to translate.
 * \return The type name as a string.
 *
 * Raises a fatal error for unsupported lane counts or type codes.
 */
std::string GetFP6Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "x2";
  } else if (lanes == 4) {
    vec = "x4";
  } else if (lanes == 8) {
    vec = "x8";
  } else if (lanes == 16) {
    vec = "x16";
  } else {
    // The message now lists every width accepted above (it used to say only
    // "(2, 4)").
    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
                  "for FP6";
  }
  stream << "__nv_fp6";
  std::string suffix;
  if (type.code() == DataType::kFloat6_e2m3fn) {
    suffix = "_e2m3";
  } else if (type.code() == DataType::kFloat6_e3m2fn) {
    suffix = "_e3m2";
  } else {
    LOG(FATAL) << "Unsupported FP6 type in CUDA codegen";
  }
  stream << vec << suffix;
  return stream.str();
}

/*!
 * \brief Build the CUDA type name for an FP4 DataType.
 *
 * The result is "__nv_fp4" + <vec> + "_e2m1", where <vec> is empty for
 * scalars or "x2"/"x4"/"x8"/"x16" for vectors.
 *
 * \param type The FP4 data type to translate.
 * \return The type name as a string.
 *
 * Raises a fatal error for unsupported lane counts or type codes.
 */
std::string GetFP4Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "x2";
  } else if (lanes == 4) {
    vec = "x4";
  } else if (lanes == 8) {
    vec = "x8";
  } else if (lanes == 16) {
    vec = "x16";
  } else {
    // The message now lists every width accepted above (it used to say only
    // "(2, 4)").
    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
                  "for FP4";
  }
  stream << "__nv_fp4";
  std::string suffix;
  if (type.code() == DataType::kFloat4_e2m1fn) {
    suffix = "_e2m1";
  } else {
    LOG(FATAL) << "Unsupported FP4 type in CUDA codegen";
  }
  stream << vec << suffix;
  return stream.str();
}

114
115
/*!
 * \brief Construct the code generator and reserve the identifiers used by the
 *        global barrier.
 */
CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
  restrict_keyword_ = "__restrict__";
  vid_global_barrier_state_ =
      name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
  // The barrier state symbol must keep its canonical name, so the name supply
  // must not have renamed it.
  ICHECK_EQ(vid_global_barrier_state_,
            runtime::symbol::tvm_global_barrier_state);
}
122

123
124
125
/*!
 * \brief Write the qualifiers that precede every generated kernel signature.
 * \param os The stream receiving the prefix.
 */
void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
  static const char *const kKernelPrefix = "extern \"C\" __global__ ";
  os << kKernelPrefix;
}
126
127

class LaunchConfigExtractor : public tir::StmtVisitor {
128
129
private:
  void VisitStmt_(const AttrStmtNode *op) final {
130
131
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
132
133
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
134
        threadIdx_x_ext = op->value;
135
136
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
137
        threadIdx_y_ext = op->value;
138
139
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
140
141
142
143
144
145
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

146
public:
147
148
149
150
151
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

152
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f) {
153
154
155
  LaunchConfigExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
156
157
158
159
160
  PrimExpr threadIdx_ext =
      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                        extractor.threadIdx_z_ext);
  if (const IntImmNode *const threadIdx_ext_int =
          threadIdx_ext.as<IntImmNode>()) {
161
    if (threadIdx_ext_int->value == 1) {
162
163
      // unable to extract the number of threads per block, hence directly
      // return
164
165
      return;
    }
166
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ", 1)";
167
168
169
170
171
172
173
  }
}

std::string CodeGenTileLangCUDA::Finish() {
  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }
174
175
176
177
178
179
180
181
  if (enable_fp8_) {
    decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
  }

  if (need_math_constants_h_) {
    decl_stream << "#include <math_constants.h>\n";
  }

182
183
184
185
  if (need_cooperative_groups_) {
    decl_stream << "#include <cooperative_groups.h>\n";
  }

186
  decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
187
188
189
  if (enable_sparse_gemm_) {
    decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
  }
190
191
192
193
  decl_stream << "#include <tl_templates/cuda/copy.h>\n";
  decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
  decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
  decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
194
  decl_stream << "#include <tl_templates/cuda/debug.h>\n";
195
196
197
  decl_stream << "#ifdef ENABLE_BF16\n";
  decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
  decl_stream << "#endif\n";
198
199

  if (need_global_barrier_) {
200
201
    decl_stream << "__device__ unsigned " << vid_global_barrier_state_
                << " = 0;\n";
202
  }
203
  decl_stream << "\n";
204

205
206
207
  return CodeGenC::Finish();
}

208
/*!
 * \brief Emit a C `for` loop for a TIR For node.
 *
 * The loop runs from `min` up to (but excluding) the simplified value of
 * `extent + min`, incrementing by one. Unrolled loops are preceded by
 * `#pragma unroll`.
 *
 * \param op The loop node to print.
 */
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
  if (op->kind == tir::ForKind::kUnrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  // Exclusive upper bound of the loop variable.
  std::string extent =
      PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
  PrintIndent();
  std::string vid = AllocVarID(op->loop_var.get());
  std::string start = PrintExpr(op->min);
  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
  stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent
         << "; ++" << vid << ") {\n";
  int for_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(for_scope);
  PrintIndent();
  stream << "}\n";
}

229
void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) {
230
  ICHECK(!var_idmap_.count(iv->var.get()));
231
232
  var_idmap_[iv->var.get()] =
      CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
233
234
}

235
/*!
 * \brief Print the CUDA type name that represents a TIR DataType.
 *
 * Handles, void and the tensor-map type are printed directly. Floating point,
 * bfloat16, FP8/FP6/FP4, bool and integer types are mapped according to their
 * bit width and lane count; several wide vectors are stored in a packed
 * backing type (see the comments on each branch). Any combination that is not
 * handled ends in a fatal error.
 *
 * As a side effect, the enable_fp16_/enable_bf16_/enable_fp8_/enable_fp6_/
 * enable_fp4_/enable_int8_ flags are set when the matching type is printed.
 *
 * \param t The data type to print.
 * \param os The stream receiving the type name.
 */
void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
    os << "void*";
    return;
  }

  if (t.is_void()) {
    os << "void";
    return;
  }

  if (t == tl::cuTensorMapType()) {
    os << "CUtensorMap";
    return;
  }

  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
    case 16:
      enable_fp16_ = true;
      if (t.is_scalar()) {
        os << "half_t";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp16 vector elements.
        //
        // half4 is stored as uint2
        //
        // h4.x is emitted as *(half2*)(&(u2.x)).x
        // h4.y is emitted as *(half2*)(&(u2.x)).y
        // h4.z is emitted as *(half2*)(&(u2.y)).x
        // h4.w is emitted as *(half2*)(&(u2.y)).y
        //
        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
        os << "uint" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 32:
      if (lanes <= 4) {
        os << "float";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
        //
        // float8 is stored as ulonglong4
        //
        // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
        // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for float type with lanes > 4";
        os << "ulonglong" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
    }
    // Scalars and fp16 vectors are already complete at this point.
    if (!fail && (t.is_scalar() || t.bits() == 16))
      return;
    // So are the packed fp32 vectors with 4 < lanes <= 8.
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
      return;
    // Remaining small vectors get the lane count appended (e.g. "float4").
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t.is_bfloat16()) {
    enable_bf16_ = true;
    if (t.is_scalar()) {
      os << "bfloat16_t";
    } else if (lanes <= 8) {
      ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
      os << "uint" << lanes / 2;
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t.is_float8()) {
    enable_fp8_ = true;
    os << GetFP8Type(t);
    return;
  } else if (t.is_float6()) {
    enable_fp6_ = true;
    if (t.lanes() <= 4) {
      os << GetFP6Type(t);
      return;
    }
    // More than 4 lanes is unsupported: fall through to the fatal error below
    // instead of returning with nothing printed.
  } else if (t.is_float4()) {
    enable_fp4_ = true;
    if (t.lanes() <= 4) {
      os << GetFP4Type(t);
      return;
    }
    // More than 4 lanes is unsupported: fall through to the fatal error below
    // instead of returning with nothing printed.
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    // Unsigned types share the signed spelling with a leading "u".
    if (t.is_uint()) {
      os << "u";
    }
    switch (t.bits()) {
    case 1: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int8_t";
        return;
      } else if (t.lanes() == 16) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 32) {
        os << "int";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 4: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 4) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 8) {
        // directly 8 4-bit int in integer.
        os << "int";
        return;
      } else if (t.lanes() == 16) {
        os << "int2";
        return;
      } else if (t.lanes() == 32) {
        os << "int4";
        return;
      } else if (t.lanes() == 64) {
        os << "int8";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 8: {
      if (t.lanes() == 4) {
        // directly 4 8 bit int in integer.
        enable_int8_ = true;

        // We use int for int8x4 instead of char4 because using char4 is
        // likely to produce extra instructions to pack four int8 elements
        // into 32-bit data.
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        enable_int8_ = true;
        os << "int2";
        return;
      } else if (t.lanes() == 16) {
        enable_int8_ = true;
        os << "int4";
        return;
      } else if (!t.is_uint() && t.is_scalar()) {
        os << "signed char";
        break;
      } else {
        os << "char";
        break;
      }
    }
    case 16: {
      if (t.is_scalar()) {
        os << "short";
      } else if (t.lanes() <= 4) {
        os << "short" << lanes;
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int16 vector elements.
        //
        // short4 is stored as int2
        //
        // s4.x is emitted as *(short2*)(&(i2.x)).x
        // s4.y is emitted as *(short2*)(&(i2.x)).y
        // s4.z is emitted as *(short2*)(&(i2.y)).x
        // s4.w is emitted as *(short2*)(&(i2.y)).y
        //
        ICHECK_EQ(t.lanes() % 2, 0)
            << "only support even lane for shorT type with lanes > 4";
        os << "int" << t.lanes() / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 32: {
      if (t.is_scalar()) {
        os << "int";
      } else if (t.lanes() <= 4) {
        os << "int" << t.lanes();
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
        //
        // int8 is stored as longlong4
        //
        // i8.v1 is emitted as *(int2*)(&(l4.x)).x
        // i8.v2 is emitted as *(int2*)(&(l4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for int32 type with lanes > 4";
        os << "longlong" << lanes / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 64: {
      if (t.is_scalar()) {
        os << "int64_t";
      } else if (t.lanes() == 2) {
        os << "longlong2";
      } else if (t.lanes() == 3) {
        os << "longlong3";
      } else if (t.lanes() == 4) {
        os << "longlong4";
      } else {
        // Other lane counts used to return silently with nothing (or only
        // "u") printed; report them as unsupported instead.
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    default:
      fail = true;
      break;
    }
    if (!fail && lanes == 1) {
      return;
    }
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

501
502
503
void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t,
                                           PrimExpr lhs, PrimExpr rhs,
                                           std::ostream &os) { // NOLINT(*)
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  int ssa_scope = BeginScope();
  {
    // Unpack into individual ops.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  EndScope(ssa_scope);
  os << sret;
}

537
538
539
void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
                                           int i,
                                           std::ostream &os) { // NOLINT(*)
540
541
542
543
544
545
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
546
547
548
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
549
550
551
552
553
554
555
556
557
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
    }
  } else if (t.is_float16()) {
558
559
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
560
  } else if (t.is_bfloat16()) {
561
562
    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
581
582
    os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
       << ")))->" << access[i % 2];
583
584
585
586
587
  } else {
    os << vec << "." << access[i];
  }
}

588
589
void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
                                            int i, const std::string &value) {
590
591
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
592
593
594
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
595
596
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
597
598
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
599
600
601
602
603
604
605
606
607
608
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      stream << ac << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
      }
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
    }
  } else if (t.is_float16()) {
609
610
    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
611
  } else if (t.is_bfloat16()) {
612
613
    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
632
633
    stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << " = " << value << ";\n";
634
635
636
637
638
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

639
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
640
641
  auto args = op->args;
  const std::string &sync = args[0].as<StringImmNode>()->value;
642
643
644
645
  if (sync == "warp") {
    // DO nothing.
  } else if (sync == "shared" || sync == "shared.dyn") {
    this->PrintIndent();
646
647
648
649
650
651
652
653
654
655
656
657
658
659
    if (args.size() == 1) {
      this->stream << "__syncthreads();\n";
    } else if (args.size() == 2) {
      auto barrier_id = args[1].as<IntImmNode>()->value;
      this->stream << "tl::__sync_thread_partial<" << barrier_id << ">();\n";
    } else if (args.size() == 3) {
      auto barrier_id = args[1].as<IntImmNode>()->value;
      auto thread_count = args[2].as<IntImmNode>()->value;
      this->stream << "tl::__sync_thread_partial<" << barrier_id << ", "
                   << thread_count << ">();\n";
    } else {
      LOG(FATAL) << "Invalid number of arguments for storage sync: "
                 << args.size();
    }
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
  } else if (sync == "global") {
    if (!need_global_barrier_) {
      need_global_barrier_ = true;
    }
    // global synchronizer
    std::string is_load = PrintExpr(op->args[1]);
    std::string num_blocks = PrintExpr(op->args[2]);
    this->PrintIndent();
    // In theory only threadfence is needed
    // but we observed problems with only threadfence
    this->stream << "__threadfence_system();\n";
    this->PrintIndent();
    this->stream << "if (" << is_load << ") {\n";
    int wb = this->BeginScope();
    this->PrintIndent();
    this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
    this->PrintIndent();
    std::string ptr = name_supply_->FreshName("pf");
    this->stream << "volatile unsigned* " << ptr << " = &"
                 << vid_global_barrier_state_ << ";\n";
    this->PrintIndent();
    this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
    this->PrintIndent();
    this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_
                 << ");\n";
    this->EndScope(wb);
    this->PrintIndent();
    this->stream << "}\n";
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
690
691
692
  }
}

693
694
695
696
697
/*!
 * \brief Print the CUDA storage qualifier for a storage scope.
 *
 * "shared" prints `__shared__ ` and "shared.dyn" prints
 * `extern __shared__ __align__(1024) `. Other scopes print nothing, and
 * "global" is rejected.
 *
 * \param scope The storage scope name.
 * \param os The stream receiving the qualifier.
 */
void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
                                            std::ostream &os) { // NOLINT(*)
  ICHECK_NE(scope, "global")
      << "Cannot allocate global memory when targeting CUDA. You must pass "
         "all global arrays as input instead";
  if (scope == "shared") {
    os << "__shared__ ";
  } else if (scope == "shared.dyn") {
    os << "extern __shared__ __align__(1024) ";
  }
}

705
706
707
708
std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from,
                                            DataType target) {
  if (from == target)
    return value;
709
710
711
712
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
713
714
  if (from.is_float16() && (target.is_int() || target.is_uint()) &&
      target.bits() == 8) {
715
716
717
718
719
720
721
722
723
724
    os << "(";
    if (target.is_uint()) {
      os << "u";
    }
    os << "int)";
  }
  os << value << ")";
  return os.str();
}

725
void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
726
727
728
729
730
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
731
732
  if (from_ty.is_scalar())
    return CodeGenC::VisitExpr_(op, os);
733
734
735
736
737
738
739

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
  std::string src = SSAGetID(PrintExpr(op->value), from_ty);

  // Handle bfloat16 special cases with supported ops
  bool used_bf16_op = false;
  if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
    std::ostringstream func_name;
    if (from_ty.is_bfloat16())
      func_name << "bf16";
    else if (from_ty.is_float())
      func_name << "float";
    if (from_ty.lanes() > 1)
      func_name << from_ty.lanes();
    func_name << "2";
    if (target_ty.is_bfloat16())
      func_name << "bf16";
    else if (target_ty.is_float())
      func_name << "float";
    else if (target_ty == DataType::Int(16))
      func_name << "int16";
    if (target_ty.lanes() > 1)
      func_name << target_ty.lanes();

    auto fname = func_name.str();
    if (bf16_supported_ops_.count(fname)) {
      used_bf16_op = true;
      stream << "#ifdef ENABLE_BF16\n";
      PrintIndent();
      stream << "reinterpret_cast<";
      if (target_ty.is_bfloat16())
        stream << "__nv_bfloat16";
      else
        PrintType(target_ty.element_of(), stream);
      if (target_ty.lanes() > 1)
        stream << target_ty.lanes();
      stream << " &>(" << sret << ") = fastertransformer::" << fname
             << "(reinterpret_cast<";
      if (from_ty.is_bfloat16())
        stream << "__nv_bfloat16";
      else
        PrintType(from_ty.element_of(), stream);
      if (from_ty.lanes() > 1)
        stream << from_ty.lanes();
      stream << " const &>(" << src << "));\n";
      stream << "#else\n";
784
785
    }
  }
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800

  // Fallback: elementwise cast
  for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
    std::ostringstream val;
    val << "(";
    PrintType(target_ty.element_of(), val);
    val << ")(";
    PrintVecElemLoad(src, from_ty, i, val);
    val << ")";
    PrintVecElemStore(sret, target_ty, i, val.str());
  }

  if (used_bf16_op) {
    stream << "#endif\n";
  }
801
802
803
  os << sret;
}

804
805
806
807
/*!
 * \brief Emit a call to an external function.
 *
 * When the return type is a fixed-length vector, the call is emitted once per
 * lane and the results are gathered in a fresh vector variable whose name is
 * written to `os`. All other calls are delegated to CodeGenC.
 *
 * \param ret_type The return type of the call.
 * \param global_symbol The name of the function to call.
 * \param args The call arguments.
 * \param skip_first_arg Whether args[0] is excluded from the emitted call.
 * \param os The stream receiving the resulting expression.
 */
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
                                          const Array<PrimExpr> &args,
                                          bool skip_first_arg,
                                          std::ostream &os) { // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (ret_dtype.is_fixed_length_vector()) {
    //
    // Emit an unsupported vector call
    //
    // v = intrin_f((float4*)A[0], (float4*)B[0])
    //
    // as
    //
    // float4 __ret;
    // {
    //   float4 __arg0 = ((float4*)A)[0];
    //   float4 __arg1 = ((float4*)B)[0];
    //   __ret.x = intrin_f(__arg0.x, __arg1.x);
    //   __ret.y = intrin_f(__arg0.y, __arg1.y);
    //   __ret.z = intrin_f(__arg0.z, __arg1.z);
    //   __ret.w = intrin_f(__arg0.w, __arg1.w);
    // }
    // v = __ret;
    //
    // Declare the result vector.
    std::string sret = name_supply_->FreshName("_");
    this->PrintIndent();
    this->PrintType(ret_dtype, stream);
    stream << ' ' << sret << ";\n";
    {
      // Load arguments.
      std::vector<std::string> sargs;
      size_t arg_begin = static_cast<size_t>(skip_first_arg);
      for (size_t i = arg_begin; i < args.size(); ++i) {
        std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
        sargs.push_back(std::move(val));
      }

      // Emit a scalar call for each lane.
      for (int i = 0; i < ret_dtype.lanes(); ++i) {
        std::ostringstream scall;
        scall << global_symbol << "(";
        for (size_t j = 0; j < sargs.size(); ++j) {
          if (j > 0)
            scall << ", ";
          // sargs[j] corresponds to args[arg_begin + j].
          PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
        }
        scall << ")";
        PrintVecElemStore(sret, ret_dtype, i, scall.str());
      }
    }
    os << sret;
  } else {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
                              os);
  }
}

// Print a reference expression to a buffer.
863
864
865
866
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
                                              const BufferNode *buffer,
                                              PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
  std::ostringstream os;
  std::string vid = GetVarID(buffer_var);
  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }
  // bool is_vol = IsVolatile(buffer_var);
  // always false for tl cutlass backend.
  bool is_vol = false;

  auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
    std::ostringstream ptr_os;
    ptr_os << "(";
    if (is_vol) {
      ptr_os << "volatile ";
    }
    if (!scope.empty() && IsScopePartOfType()) {
      PrintStorageScope(scope, ptr_os);
    }
    PrintType(pointed_to, ptr_os);
    ptr_os << "*)";
    return ptr_os.str();
  };

  DataType buffer_element_dtype = buffer->dtype;

  std::string buffer_str = vid;
  if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
    std::stringstream temp;
    temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
    buffer_str = temp.str();
  }
899
900
901
902
903
904
905
  if (scope.empty()) {
    scope = GetPtrStorageScope(buffer->data);
  }
  if (scope == "local.var") {
    os << vid;
    return os.str();
  }
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
  std::string index_str = PrintExpr(index);
  if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
    // This is a special case, because CodegenCUDA::PrintType()
    // returns "int" for bool and for 4-bit integers. In most cases,
    // we divide by the number of lanes to determine the index.
    // However, the backing type for scalar int4 and scalar bool is
    // int32.  Therefore, we need to divide by the ratio of their
    // sizes in that case.
    int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();

    os << "*("
       << "(" << ptr_cast(t) << vid << ")"
       << " + " << index_str << " / " << div_factor << ")";
  } else if (t == buffer_element_dtype) {
    os << buffer_str << "[" << index_str << "]";
  } else {
    os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
  }

  return os.str();
}

928
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
929
930
  auto print_extern_call_stmt = [&](std::string name, size_t start = 0,
                                    size_t end = 0) {
931
932
933
934
    // Cache context into a private ss, otherwise the let node may generate
    // within the function call arguments.
    std::ostringstream ss;

935
936
    for (size_t i = start; i < op->args.size() - end; i++) {
      if (i > start)
937
938
        ss << ", ";
      ss << this->PrintExpr(op->args[i]);
939
    }
940
941
942
943

    this->PrintIndent();
    this->stream << name << "(";
    this->stream << ss.str();
944
945
946
947
948
949
950
951
    this->stream << ");\n";
  };
  if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
952
953
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
954
955
    if (op->args.size() == 5) {
      this->PrintIndent();
956
957
      this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
                   << dst_offset << ", " << src << "+" << src_offset << ");\n";
958
959
960
    } else {
      std::string condition = this->PrintExpr(op->args[5]);
      this->PrintIndent();
961
962
963
      this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
                   << "+" << dst_offset << ", " << src << "+" << src_offset
                   << ", " << condition << ");\n";
964
965
966
967
968
969
970
971
972
973
974
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    print_extern_call_stmt("tl::cp_async_commit");
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
    std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
    print_extern_call_stmt(func_name, 1);
  } else if (op->op.same_as(builtin::create_barriers())) {
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    std::string barrier_name = "_mbarrier";
975
976
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
977
  } else if (op->op.same_as(tl::get_mbarrier())) {
978
979
980
981
982
983
984
985
986
987
988
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
    os << barrier_name + "[" + barrier_id + "]";
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    print_extern_call_stmt("tl::mbarrier_arrive");
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    print_extern_call_stmt("tl::mbarrier_init");
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
989
  } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
990
    print_extern_call_stmt("tl::mbarrier_expect_tx");
991
  } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
992
    print_extern_call_stmt("tl::mbarrier_wait");
993
  } else if (op->op.same_as(tl::sync_thread_partial())) {
994
    print_extern_call_stmt("cutlass::arch::NamedBarrier::sync");
995
996
  } else if (op->op.same_as(tl::no_set_max_nreg())) {
    return;
997
  } else if (op->op.same_as(tl::tma_load())) {
998
    std::ostringstream ss;
999
    ICHECK_GE(op->args.size(), 2);
1000
1001
1002
1003
    auto eviction_policy =
        this->eviction_policy_names_
            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
    ss << "tl::tma_load<tl::CacheHintSm90::" << eviction_policy << ">(";
1004
    auto desc = op->args[0];
1005
    ss << this->PrintExpr(desc) << ", ";
1006
    if (const IntImmNode *imm = op->args[1].as<IntImmNode>()) {
1007
      ss << "_mbarrier[" << imm->value << "], ";
1008
    } else {
1009
      ss << this->PrintExpr(op->args[1]) << ", ";
1010
    }
1011
    for (size_t i = 2; i < op->args.size() - 1; i++) {
1012
      if (i > 2)
1013
1014
        ss << ", ";
      ss << this->PrintExpr(op->args[i]);
1015
    }
1016
1017
1018
    ss << ");\n";
    this->PrintIndent();
    this->stream << ss.str();
1019
  } else if (op->op.same_as(tl::tma_load_im2col())) {
1020
1021
1022
1023
1024
1025
    std::stringstream ss;
    ss << "tl::tma_load_im2col<tl::CacheHintSm90::"
       << this->eviction_policy_names_
              [op->args[op->args.size() - 1].as<IntImmNode>()->value]
       << ">";
    print_extern_call_stmt(ss.str(), 0, 1);
1026
  } else if (op->op.same_as(tl::tma_store())) {
1027
1028
1029
1030
1031
1032
    std::stringstream ss;
    ss << "tl::tma_store<tl::CacheHintSm90::"
       << this->eviction_policy_names_
              [op->args[op->args.size() - 1].as<IntImmNode>()->value]
       << ">";
    print_extern_call_stmt(ss.str(), 0, 1);
1033
  } else if (op->op.same_as(tl::ptx_ldmatirx())) {
1034
1035
1036
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
1037
1038
    if (trans == 1)
      func_name += "_trans";
1039
    print_extern_call_stmt(func_name, 2);
1040
  } else if (op->op.same_as(tl::ptx_stmatirx())) {
1041
1042
1043
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
1044
1045
    if (trans == 1)
      func_name += "_trans";
1046
    print_extern_call_stmt(func_name, 2);
1047
  } else if (op->op.same_as(tl::fence_proxy_async())) {
1048
    print_extern_call_stmt("tl::fence_proxy_async");
1049
  } else if (op->op.same_as(tl::tma_store_arrive())) {
1050
    print_extern_call_stmt("tl::tma_store_arrive");
1051
  } else if (op->op.same_as(tl::tma_store_wait())) {
1052
    print_extern_call_stmt("tl::tma_store_wait<0>");
1053
  } else if (op->op.same_as(tl::set_max_nreg())) {
1054
1055
1056
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
1057
1058
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
1059
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
1060
  } else if (op->op.same_as(tl::wait_wgmma())) {
1061
1062
1063
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
1064
  } else if (op->op.same_as(tl::pack_b16())) {
1065
1066
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
  } else if (op->op.same_as(tl::sync_grid())) {
    this->need_cooperative_groups_ = true;
    this->PrintIndent();
    this->stream << "cooperative_groups::grid_group grid = "
                    "cooperative_groups::this_grid();\n";
    this->PrintIndent();
    this->stream << "grid.sync();\n";
  } else if (op->op.same_as(tl::loop_break())) {
    this->PrintIndent();
    this->stream << "break;\n";
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
1110
    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::ptx_mma())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp64, ...
    // arg 4: B precision: fp16, fp64, ...
    // arg 5: C precision: fp32, fp64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index
    // arg 12: saturate
    // arg 13: (optional) 1-bit operator (xor or and)
    ICHECK(op->args.size() == 13U || op->args.size() == 14U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    bool saturate = Downcast<Bool>(op->args[12])->value;
1165
1166
1167
1168
1169
    std::string bit_op =
        op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
    std::string asm_code = PrintMMAAssembly(
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
        b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206

    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp32, ...
    // arg 4: B precision: fp16, fp32, ...
    // arg 5: C precision: fp16, fp32, ...
    // arg 6: A multiplicand pointer
    // arg 7: A multiplicand index
    // arg 8: B multiplicand pointer
    // arg 9: B multiplicand index
    // arg 10: C accumulator pointer
    // arg 11: C accumulator index
    // arg 12: metadata
    // arg 13: metadata index
    // arg 14: sparse_selector
    // arg 15: saturate
    ICHECK_EQ(op->args.size(), 16U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_offset = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_offset = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_offset = this->PrintExpr(op->args[11]);
    std::string metadata = this->PrintExpr(op->args[12]);
    std::string metadata_offset = this->PrintExpr(op->args[13]);
    std::string sparse_selector = this->PrintExpr(op->args[14]);
    bool saturate = Downcast<Bool>(op->args[15])->value;
    std::string asm_code = PrintMMAAssembly(
1207
1208
1209
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
        b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
        sparse_selector, "", true, saturate);
1210
1211
1212
1213
1214
1215
1216
1217
    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
    // arg 0: whether the matrix is loaded in column major format or not.
    // arg 1: number of matrices to load.
    // arg 2: The data type in the matrix, .b16 is the only accepted data type.
    // arg 3: pointer to local buffer.
    // arg 4: The offset of the element to store in the local buffer.
    // arg 5: pointer to the shared memory buffer to load.
1218
1219
    // arg 6: The offset of the start element of the row to load in shared
    // memory.
1220
1221
1222
1223
1224
1225
1226
1227
    ICHECK_EQ(op->args.size(), 7U);
    bool trans = Downcast<Bool>(op->args[0])->value;
    int num = Downcast<Integer>(op->args[1])->value;
    std::string type = Downcast<StringImm>(op->args[2])->value;
    std::string local_ptr = this->PrintExpr(op->args[3]);
    std::string local_elem_offset = this->PrintExpr(op->args[4]);
    std::string smem_ptr = this->PrintExpr(op->args[5]);
    if (trans && op->dtype.bits() == 8) {
1228
1229
      // Since ldmatrix assumes that a matrix element is 16 bit, it cannot
      // properly transpose an int8 matrix.
1230
1231
1232
1233
      std::string smem_stride = this->PrintExpr(op->args[6]);
      ICHECK(num == 4);
      os << "for (int i = 0; i < 16; ++i) {\n";
      os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
1234
1235
1236
1237
         << "[(i % 8) / 4 * " + smem_stride +
                " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
                "+ (i % 4) * " + smem_stride +
                " + threadIdx.x / 4 +  (i / 8) * 8];\n";
1238
1239
1240
1241
      os << "}\n";
    } else {
      std::string smem_elem_offset = this->PrintExpr(op->args[6]);
      need_cast_smem_ptr_to_int_ = true;
1242
1243
1244
      this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
                                              local_elem_offset, smem_ptr,
                                              smem_elem_offset);
1245
1246
1247
1248
1249
1250
1251
1252
1253
    }
  } else if (op->op.same_as(builtin::mma_store())) {
    int m = Downcast<Integer>(op->args[0])->value;
    int n = Downcast<Integer>(op->args[1])->value;
    std::string dst = this->PrintExpr(op->args[2]);
    std::string src = this->PrintExpr(op->args[3]);
    std::string src_offset = this->PrintExpr(op->args[4]);
    PrimExpr stride = op->args[5];

1254
1255
    ICHECK(m == 16 && n == 16)
        << "Only m == 16 && n == 16 case supported for now";
1256

1257
1258
1259
1260
1261
    // Each thread in a warp holds a certain number of elements of an MMA
    // output. For example, if we compute a 16x16 tile using MMA, each thread
    // holds 8 elements in its registers. So conceptually, a warp memory is
    // organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of
    // memory is specified by the index map below.
1262

1263
1264
    // To store the 32x8 output back to a 16x16 tile in shared or global memory,
    // we invert this map to determine the output location for each 8 element.
1265

1266
1267
    const auto index_map_func = ffi::Function::GetGlobal(
        "tir.index_map.shared_16x16_to_mma_32x8_layout");
1268

1269
1270
1271
    IndexMap index_map;
    if (!index_map_func) {
      Var i, j;
1272

1273
      // The index map is defined as follows:
1274
1275
1276
1277
1278
      index_map = IndexMap(
          {i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2),
                   4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)});
    } else {
      index_map = IndexMap::FromFunc(2, *index_map_func);
1279
1280
1281
1282
1283
1284
1285
    }

    arith::Analyzer analyzer;
    auto inverse_index_map =
        index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer);
    auto indices_16x16 = inverse_index_map->final_indices;

1286
1287
1288
    // "//" and "%" in the index map are translated to FloorDiv/Mod, but the
    // plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before
    // they reach codegen, so manually replace them to the plain ones here.
1289
    class LowerFloorDivMod : public ExprMutator {
1290
1291
    public:
      PrimExpr VisitExpr_(const FloorDivNode *op) {
1292
1293
        return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
1294
      PrimExpr VisitExpr_(const FloorModNode *op) {
1295
1296
1297
1298
        return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
    };

1299
1300
    auto dst_ind =
        LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
1301
1302
1303
1304
1305
1306
1307
1308
1309

    var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
    var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
    if (op->dtype.bits() == 16) {
      os << "for (int local_id = 0; local_id < 8; local_id+=2) {\n";
      os << "*((uint *)&" << dst << "[" + this->PrintExpr(dst_ind) + "])"
         << " = "
         << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n";
      os << "}\n";
1310
    } else {
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
      os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
      os << dst << "[" + this->PrintExpr(dst_ind) + "]"
         << " = " << src << "[" << src_offset << " + local_id];\n";
      os << "}\n";
    }

  } else if (op->op.same_as(builtin::mma_fill())) {
    std::string num_elem = this->PrintExpr(op->args[0]);
    std::string dst = this->PrintExpr(op->args[1]);
    std::string dst_offset = this->PrintExpr(op->args[2]);

    os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
    os << dst << "[" << dst_offset << " + i] = 0.0;";
    os << "}\n";
  } else if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    need_cast_smem_ptr_to_int_ = true;
1332
1333
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
1334
    if (op->args.size() == 5) {
1335
1336
      this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
                                           size);
1337
    } else {
1338
1339
      this->stream << PrintPredicatedCpAsyncAssembly(
          dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5]));
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
    }
  } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
    need_cast_smem_ptr_to_int_ = true;
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    int barrier_id = Downcast<IntImm>(op->args[5])->value;
    CHECK(barrier_id < barrier_count_);
1350
1351
1352
1353
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size,
                                        barrier);
1354
1355
1356
1357
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
1358
1359
    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
                 << ";\");\n\n";
1360
1361
1362
1363
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1364
1365
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1366
1367
1368
1369
1370
    this->stream << PrintCpAsyncBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1371
1372
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1373
1374
1375
1376
1377
1378
    std::string thread_count = this->PrintExpr(op->args[1]);
    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1379
1380
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1381
1382
1383
1384
1385
    this->stream << PrintArriveBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1386
1387
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1388
1389
1390
1391
1392
1393
    std::string byte_count = this->PrintExpr(op->args[1]);
    this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
  } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1394
1395
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1396
1397
1398
1399
1400
1401
1402
1403
    this->stream << PrintWaitBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::create_barriers())) {
    CHECK_EQ(barrier_count_, -1);
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    // pad barrier alignment to avoid runtime alignment errors
    CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
    int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
    if (barrier_count % barrier_alignment_count != 0) {
1404
1405
      barrier_count = ((barrier_count / barrier_alignment_count) + 1) *
                      barrier_alignment_count;
1406
1407
    }
    barrier_count_ = barrier_count;
1408
1409
1410
1411
1412
    this->stream << "__shared__ __align__(" << barrier_alignment_bytes_
                 << ") uint64_t " << barrier_name_ << "[" << barrier_count
                 << "];\n";
    this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { "
                 << barrier_name_ << "[i] = 0; }\n";
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
  } else if (op->op.same_as(builtin::ptx_ldg32())) {
    /*
    asm volatile (
        "{.reg .pred p;\n"
        " setp.ne.b32 p, %2, 0;\n"
        // " @p ld.global.nc.f32 %0, [%1];}\n"t
        " @p ld.global.nc.L2::128B.f32 %0, [%1];}\n"
        : "=f"(reg)
        : "l"(addr), "r"((int)guard)
    );
    */

    // get local
    std::string reg = this->PrintExpr(op->args[0]);
    // get guard
    std::string guard = this->PrintExpr(op->args[1]);
1429
    const BufferLoadNode *addr_buffer = op->args[2].as<BufferLoadNode>();
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
    std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
    std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
    std::string local_addr = this->PrintExpr(op->args[3]);
    this->stream << "asm volatile (\n";
    this->stream << "\"{.reg .pred p;\\n\"\n";
    this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n";
    this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
    this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
    // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
    stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
           << ")\n";
1441
1442
    stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
           << ")), \"r\"((int)" << guard << ")\n";
1443
    stream << ");\n";
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
  } else if (op->op.same_as(builtin::reinterpret())) {
    DataType tgt_dtype = op->dtype;
    DataType src_dtype = op->args[0]->dtype;
    PrimExpr value = op->args[0];

    // Handle float4_e2m1fn reinterpret
    if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) {
      return CodeGenC::VisitExpr_(op, os);
    }
    if (src_dtype == tgt_dtype || tgt_dtype.lanes() * tgt_dtype.bits() ==
                                      src_dtype.lanes() * src_dtype.bits()) {
      return CodeGenC::VisitExpr_(op, os);
    }
    CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes())
        << "E2M1 float4 reinterpret expects source and target to have the same "
           "number of lanes. "
        << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
    CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes())
        << "E2M1 float4 reinterpret expects source and target to have the same "
           "number of bytes. "
        << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;

    int lanes = tgt_dtype.lanes();

    int ssa_scope = BeginScope();
    if (lanes == 1) {
      // The case of lane=1 is same as the normal reinterpret,
      // except that we allow the src and dst dtype to have different number of
      // bits.
      std::string rhs = SSAGetID(PrintExpr(value), src_dtype);
      os << "(*(";
      this->PrintType(tgt_dtype, os);
      os << " *)(&(" << rhs << ")))";
    } else if (lanes == 2) {
      if (tgt_dtype.is_float4_e2m1fn()) {
        // We view the source as an uint16, and then extract bits of two fp4
        // numbers, and finally reinterpret the result as fp4x2.
        value =
            tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value});
        tir::Var temp_var("temp_var", DataType::UInt(16));
        value =
            tir::Let(temp_var, value,
                     tir::Cast(DataType::UInt(8),
                               (temp_var & IntImm(DataType::UInt(16), 0xF)) |
                                   ((temp_var >> 4) &
                                    IntImm(DataType::UInt(16), 0xF0))));
      } else {
        value = tir::Cast(
            DataType::UInt(16),
            tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value}));
        tir::Var temp_var("temp_var", DataType::UInt(16));
        value =
            tir::Let(temp_var, value,
                     (temp_var & IntImm(DataType::UInt(16), 0xF)) |
                         ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4));
      }
      os << PrintExpr(
          tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
    } else if (lanes == 4) {
      if (tgt_dtype.is_float4_e2m1fn()) {
        // We view the source as an uint32, and then extract bits of four fp4
        // numbers, and finally reinterpret the result as fp4x4.
        value =
            tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value});
        tir::Var temp_var("temp_var", DataType::UInt(32));
        value = tir::Let(
            temp_var, value,
            tir::Cast(
                DataType::UInt(16),
                (temp_var & IntImm(DataType::UInt(32), 0xF)) |
                    ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) |
                    ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) |
                    ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000))));
      } else {
        value = tir::Cast(DataType::UInt(32),
                          tir::Call(DataType::UInt(16),
                                    tir::builtin::reinterpret(), {value}));
        tir::Var temp_var("temp_var", DataType::UInt(32));
        value = tir::Let(
            temp_var, value,
            (temp_var & IntImm(DataType::UInt(32), 0xF)) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12));
      }
      os << PrintExpr(
          tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
    } else {
      LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: "
                 << lanes;
    }
    EndScope(ssa_scope);
  } else if (op->op.same_as(builtin::thread_return())) {
    os << "return";
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
  } else if (op->op.same_as(tl::tl_gemm())) {
    ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
                                    "A_ptr, B_ptr, C_ptr>, but got "
                                 << op->args.size();
    auto op_instance = Downcast<StringImm>(op->args[0]);
    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
                          op->args, true, os);
  } else if (op->op.same_as(tl::tl_gemm_sp())) {
    ICHECK(op->args.size() == 5)
        << "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
           "E_ptr>, but got "
        << op->args.size();
    auto op_instance = Downcast<StringImm>(op->args[0]);
    enable_sparse_gemm_ = true;
    this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
                          op->args, true, os);
1554
1555
  } else if (op->op.same_as(tl::tl_shuffle_elect())) {
    os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
1556
1557
1558
1559
1560
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

1561
void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
1562
  if (op->attr_key == tir::attr::fragment_shape) {
1563
1564
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *shape_str = op->value.as<StringImmNode>();
1565
1566
    fragment_shapes[buffer] = shape_str->value;
  } else if (op->attr_key == tir::attr::fragment_layout) {
1567
1568
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *layout_str = op->value.as<StringImmNode>();
1569
1570
    fragment_layouts[buffer] = layout_str->value;
  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
1571
1572
1573
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
1574
1575
1576
1577
1578
1579
1580
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
    return;
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
1581
1582
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
1583
    auto wait_cnt = wait_attrs.second;
1584
1585
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
1586
1587
1588
1589
1590
1591
1592
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    this->VisitStmt(inner->body);
    return;
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
1593
    const StringImmNode *pattern = op->value.as<StringImmNode>();
1594
1595
1596
1597
1598
1599
1600
1601
    ICHECK(pattern);
    this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
    this->VisitStmt(op->body);
    return;
  }
  CodeGenC::VisitStmt_(op);
}

1602
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
1603
1604
1605
1606
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
1607
  const VarNode *buffer = op->buffer_var.as<VarNode>();
1608
1609
  if (scope.find("wmma.") == 0) {
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
1610
1611
1612
1613
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) ||
             op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
             op->dtype == DataType::Int(1) || op->dtype == DataType::BFloat(16))
1614
1615
1616
          << "Matrix_a and matrix_b only support half or char or unsigned char "
          << "or uint4 or int4 or int1 type for now";
    } else {
1617
1618
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32))
1619
1620
1621
          << "Accumulator only support half, float and int type for now";
    }
    PrintWmmaScope(scope, op->dtype, buffer, stream);
1622
  } else {
1623
1624
1625
1626
1627
1628
1629
1630
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }

  if (scope == "shared.dyn") {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
1631
    ICHECK_GT(constant_size, 0)
1632
1633
        << "Can only handle constant size stack allocation for now, but get "
        << constant_size << " for " << op->buffer_var->name_hint;
1634
1635
1636
1637
1638
1639
1640
1641
    if (scope.find("wmma.") == 0) {
      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    }
    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
        scope == "shared") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
    if (scope == "shared") {
      stream << ' ' << vid << '[' << constant_size << "];\n";
    } else if (scope == "local") {
      stream << ' ' << vid << '[' << constant_size << "];\n";
    } else if (scope == "local.var") {
      stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0))
             << ";\n";
    } else {
      ICHECK(false) << "Unsupported scope: " << scope;
    }
1652
1653
1654
1655
1656
1657
  }

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
/*!
 * \brief Emit an evaluate statement.
 *
 * Constant integers produce no output. A call to tvm_global_barrier_kinit
 * declares the shared barrier counter and has thread 0 reset it to zero.
 * Everything else is handled by CodeGenC.
 *
 * \param op The evaluate statement to translate.
 */
void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
  if (is_const_int(op->value)) {
    return;
  }

  const CallNode *call = op->value.as<CallNode>();
  const bool is_barrier_init =
      call != nullptr && call->op.same_as(builtin::tvm_global_barrier_kinit());
  if (!is_barrier_init) {
    CodeGenC::VisitStmt_(op);
    return;
  }

  // Declaration of the shared counter.
  PrintIndent();
  stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
  // Guarded reset performed by the first thread only.
  PrintIndent();
  stream << "if (threadIdx.x == 0) {\n";
  PrintIndent();
  stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
  PrintIndent();
  stream << "}\n";
}

1676
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
1677
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
1678
1679
  CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
                     << lanes << " lanes is not allowed.";
1680
1681
1682
1683
1684
1685
  os << "(make_";
  PrintType(op->dtype, os);
  os << "(";
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
1686
1687
    if (i != lanes - 1)
      os << ", ";
1688
1689
1690
1691
  }
  os << "))";
}

1692
1693
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
                                     std::ostream &os) { // NOLINT(*)
1694
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
1695
1696
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
      lanes == 4) {
1697
    // make_int8x4
1698
    const int64_t *p = as_const_int(op->value);
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
    ICHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
    if (op->dtype.is_uint()) {
      os << "(uint)" << v;
    } else {
      os << "(int)" << v;
    }
    return;
  }

  if (op->dtype.is_float16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1716
1717
      if (i != 0)
        os << ", ";
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if (op->dtype.is_bfloat16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1730
1731
      if (i != 0)
        os << ", ";
1732
1733
1734
1735
1736
1737
      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

1738
1739
  if (op->dtype.is_float() && op->dtype.bits() == 32 &&
      op->dtype.lanes() == 8) {
1740
1741
1742
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
1743
1744
      if (i != 0)
        os << ", ";
1745
1746
1747
1748
1749
1750
1751
1752
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
1753
    const int64_t *p = as_const_int(op->value);
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
    ICHECK(p);
    int64_t v = *p & 0xF;

    if (lanes == 4) {
      v = (v << 12) | (v << 8) | (v << 4) | v;
      if (op->dtype.is_uint()) {
        os << "(uint16_t)" << v;
      } else {
        os << "(int16_t)" << v;
      }
    } else {
1765
1766
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
          (v << 4) | v;
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
        } else {
          os << "(int)" << v;
        }
      } else if (lanes == 16 || lanes == 32) {
        os << "make_";
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < lanes / 8; ++i) {
1778
1779
          if (i != 0)
            os << ", ";
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
            os << "(int)" << v;
          }
        }
        os << ')';
      } else {
        fail = true;
      }
    }

    if (!fail) {
      return;
    }
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
1802
1803
    if (i != 0)
      os << ", ";
1804
1805
1806
1807
1808
    os << v;
  }
  os << ')';
}

1809
1810
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangCUDA *p) { // NOLINT(*)
1811
1812
1813
1814
1815
1816
  // Type code is kBFloat
  if (op->dtype.is_bfloat16()) {
    os << "bfloat16_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
1817
1818
1819
1820
1821
1822
  // Type code is kFloat8_e5m2 or kE4M4Float
  if (op->dtype.is_float8() || op->dtype.is_float4()) {
    p->PrintType(op->dtype, os);
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
1823
1824
  // Type code is kFloat
  switch (op->dtype.bits()) {
1825
1826
1827
1828
1829
1830
  case 64:
  case 32: {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
1831
      }
1832
      temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
1833
      p->need_math_constants_h_ = true;
1834
1835
    } else if (std::isnan(op->value)) {
      temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
1836
      p->need_math_constants_h_ = true;
1837
1838
1839
1840
    } else {
      temp << std::scientific << op->value;
      if (op->dtype.bits() == 32)
        temp << 'f';
1841
    }
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
    p->MarkConst(temp.str());
    os << temp.str();
    break;
  }
  case 16: {
    os << "half_t" << '(';
    FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
    PrintConst(const_f32.get(), os, p);
    os << ')';
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
1855
1856
1857
  }
}

1858
1859
void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op,
                                     std::ostream &os) { // NOLINT(*)
1860
1861
1862
  PrintConst(op, os, this);
}

1863
1864
1865
void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t,
                                         const VarNode *variable,
                                         std::ostream &os) {
1866
1867
  std::stringstream type;
  PrintType(t, type);
1868
1869
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
  std::string shape_str = fragment_shapes.at(variable);
  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
    type.str(std::string());
    if (t.is_int()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::s4";
      } else if (t.bits() == 1) {
        type << "nvcuda::wmma::experimental::precision::b1";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    } else if (t.is_uint()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::u4";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    }
  }
  if (scope == "wmma.matrix_a") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
1892
1893
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", "
       << type.str() << ", nvcuda::wmma::" << layout_str << ">";
1894
1895
1896
  } else if (scope == "wmma.matrix_b") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
1897
1898
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", "
       << type.str() << ", nvcuda::wmma::" << layout_str << ">";
1899
  } else if (scope == "wmma.accumulator") {
1900
1901
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str
       << ", " << type.str() << ">";
1902
1903
1904
  }
}

1905
1906
int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope,
                                                 const VarNode *variable,
1907
                                                 int32_t size) {
1908
1909
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
1910
1911
1912
1913
1914
1915
1916
1917
  std::string shape_str = fragment_shapes.at(variable);
  std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
  if (dim.first * dim.second != 0)
    return size / dim.first / dim.second;
  else
    return 0;
}

1918
1919
1920
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value,
                                              const BufferLoadNode *op,
                                              std::ostream &os) {
1921
1922
1923
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
1924
1925
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
1926
1927
1928
1929
1930
1931
1932
1933
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

1934
1935
1936
/*!
 * \brief Emit the i-th piece of an expression that builds a vector from
 *        per-lane scalar values.
 *
 * The caller is expected to invoke this once per lane, in order, on the same
 * stream; together the pieces form one expression:
 *  - 8-bit integers (except 2 or 3 lanes): masked and shifted bytes joined
 *    with `|`.
 *  - float16 / bfloat16: `make_<type>(...)` whose arguments are pack helper
 *    calls, each taking two consecutive lanes.
 *  - everything else: `make_<type>(v0,v1,...)`.
 *
 * \param t The vector type being constructed; must have more than one lane.
 * \param i The index of the lane being emitted.
 * \param value The printed scalar expression for lane i.
 * \param os The stream receiving the text.
 */
void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
                                               const std::string &value,
                                               std::ostream &os) {
  ICHECK_GT(t.lanes(), 1);
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (!(t.lanes() == 2 || t.lanes() == 3)) {
      if (i != 0) {
        os << "|";
      }
      os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
         << "))";
      return;
    }
  }

  if (t.is_float16() || t.is_bfloat16()) {
    // The bfloat16 helper name matches the one emitted by the Broadcast
    // visitor (__pack_nv_bfloat162); this path used to emit the differently
    // spelled __pack_bfloat162.
    const char *pack_fn =
        t.is_float16() ? "__pack_half2" : "__pack_nv_bfloat162";
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      // Even lane: open the pack call with its first argument.
      os << pack_fn << "(" << value;
    } else {
      // Odd lane: second argument, then close the pack call.
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << "(";
  }
  os << value;
  if (i != t.lanes() - 1) {
    os << ",";
  } else {
    os << ")";
  }
}

2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
                                                 const PrimFunc &func,
                                                 std::ostream &os) {
  PrintFuncPrefix(os);
  CodeGenC::PrintType(func->ret_type, os);
  CodeGenC::PrintExtraAttrs(func, os);
  bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
  os << " " << function_name << "(";
  for (size_t i = 0; i < func->params.size(); ++i) {
    tir::Var v = func->params[i];
    std::string vid = AllocVarID(v.get());

    if (i > 0) {
      os << ", ";
    }

    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (ptr->storage_scope == "grid_constant") {
          os << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, os);
          os << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, os);
      }

      CodeGenC::PrintType(GetType(v), os);
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, os);
      }
    } else {
      CodeGenC::PrintType(GetType(v), os);
    }
    os << ' ' << vid;
  }
  os << ")";

  // Register handle data type
  // TODO(tvm-team): consider simply keep type info in the
  // type annotation(via a normalizing rewriting).
  for (const auto &param : func->params) {
    if (auto *ptr = param->type_annotation.as<PointerTypeNode>()) {
      if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
        RegisterHandleType(param.get(), prim->dtype);
      }
    }
  }
}

void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
                                      const PrimFunc &f) {
  // If the function has already been forward-declared, this is a
  // no-op.
  CodeGenC::DeclareFunction(gvar, f);
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix(stream);
  CodeGenC::PrintType(f->ret_type, stream);
2079
2080
  this->PrintExtraAttrs(f);

2081
2082
2083
2084
2085
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
2086
2087
    if (i != 0)
      stream << ", ";
2088
2089
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
2090
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
          stream << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      CodeGenC::PrintType(GetType(v), stream);
2105
2106
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, stream);
      }
    } else {
      CodeGenC::PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";
  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}

2128
2129
} // namespace codegen
} // namespace tvm