codegen_cuda.cc 109 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file target/codegen_cuda.cc
 */

#include "codegen_cuda.h"
#include <tvm/arith/analyzer.h>
7
#include <tvm/ffi/function.h>
8
#include <tvm/tir/index_map.h>
9
10
11
12
13
14
15
16
#include <tvm/tir/op.h>

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "../op/builtin.h"
17
#include "./ptx.h"
18
#include "arith/pattern_match.h"
19
20
21

namespace tvm {
namespace codegen {
22
using namespace tvm::tl::codegen;
23
using namespace ffi;
24

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*!
 * \brief Functor that decorates a generic math function name (e.g. "exp")
 *        with the prefix/suffix selected by the operand data type.
 *
 * An empty string is returned whenever the type has no mapping here.
 */
struct CUDAMath {
  /*!
   * \param t Data type the math call operates on.
   * \param name Undecorated function name.
   * \return The decorated name, or "" when \p t is not handled.
   */
  std::string operator()(DataType t, std::string name) const {
    // Naming shared by 16-bit float and bfloat16: two special cases, otherwise
    // an "h" prefix.
    const auto half_name = [&name]() -> std::string {
      if (name == "fabs") {
        return "__habs";
      }
      if (name == "round") {
        return "hrint";
      }
      return "h" + name;
    };

    if (t.is_float()) {
      const int width = t.bits();
      if (width == 64) {
        return name;
      }
      if (width == 32) {
        return name + 'f';
      }
      if (width == 16) {
        return half_name();
      }
      return "";
    }
    if (t.is_bfloat16()) {
      return half_name();
    }
    if (t.is_int() || t.is_uint()) {
      const int width = t.bits();
      if (width == 32) {
        return "__" + name;
      }
      if (width == 64) {
        return "__" + name + "ll";
      }
    }
    return "";
  }
};

/*!
 * \brief Variant of CUDAMath that uses the "__<name>f" spelling for 32-bit
 *        floats and defers to CUDAMath for every other type.
 */
struct CUDAFastMath : public CUDAMath {
  /*!
   * \param t Data type the math call operates on.
   * \param name Undecorated function name.
   * \return "__<name>f" for float32, otherwise the CUDAMath result.
   */
  std::string operator()(DataType t, std::string name) const {
    const bool is_fp32 = t.is_float() && t.bits() == 32;
    if (is_fp32) {
      return "__" + name + 'f';
    }
    return CUDAMath::operator()(t, name);
  }
};

/*!
 * \brief Name mapping used for tan: only float types are handled, and the
 *        32-bit case uses the plain "<name>f" spelling.
 */
struct CUDAFastMathTan : public CUDAMath {
  /*!
   * \param t Data type the math call operates on.
   * \param name Undecorated function name.
   * \return The decorated name, or "" for non-float or unsupported widths.
   */
  std::string operator()(DataType t, std::string name) const {
    if (!t.is_float()) {
      return "";
    }
    const int width = t.bits();
    if (width == 64) {
      return name;
    }
    if (width == 32) {
      // `__tanf` seems to produce some values too deviant from numpy tan
      // version. So, let's use just `tanf` instead.
      return name + 'f';
    }
    if (width == 16) {
      return 'h' + name;
    }
    return "";
  }
};

98
99
100
101
102
103
104
105
106
107
108
109
/*!
 * \brief Functor that builds a function name carrying an explicit rounding
 *        mode suffix for 32-bit and 64-bit floats.
 */
struct CUDAIEEEMath {
  /*!
   * \param t Data type the math call operates on.
   * \param name Undecorated function name.
   * \param rounding_mode Suffix appended after an underscore.
   * \return "__<name>_<mode>" for float32, "__d<name>_<mode>" for float64,
   *         and "" for anything else.
   */
  std::string operator()(DataType t, std::string name,
                         std::string rounding_mode) const {
    if (!t.is_float()) {
      return "";
    }
    switch (t.bits()) {
    case 32:
      return "__" + name + "_" + rounding_mode;
    case 64:
      return "__d" + name + "_" + rounding_mode;
    default:
      return "";
    }
  }
};

110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*!
 * \brief Build the FP8 type name printed for \p type.
 *
 * The name has the form "fp8_e4<vec>_t" or "fp8_e5<vec>_t", where <vec> is
 * empty for scalars and "_<lanes>" for vectors of 2, 4, 8, 16 or 32 lanes.
 *
 * \param type An FP8 data type.
 * \return The type name string.
 * \note Logs a fatal error for any other lane count or an FP8 variant that
 *       is neither an e4m3 nor an e5m2 flavour.
 */
static std::string GetFP8Type(DataType type) {
  const int32_t lanes = type.lanes();
  std::string vec;
  if (!type.is_scalar()) {
    switch (lanes) {
    case 2:
    case 4:
    case 8:
    case 16:
    case 32:
      vec = "_" + std::to_string(lanes);
      break;
    default:
      LOG(FATAL)
          << "Only support scalar and vector types of width (2, 4, 8, 16, 32) "
             "for FP8";
    }
  }

  std::stringstream stream;
  if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() ||
      type.is_float8_e4m3()) {
    stream << "fp8_e4" << vec << "_t";
  } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz()) {
    // The original condition tested is_float8_e5m2() twice; the duplicate
    // added nothing and has been dropped.
    stream << "fp8_e5" << vec << "_t";
  } else {
    LOG(FATAL) << "Unsupported FP8 type in CUDA codegen but got " << type;
  }
  return stream.str();
}

143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
/*!
 * \brief Build the FP6 type name printed for \p type.
 *
 * The name is "__nv_fp6" followed by an optional vector marker ("x2", "x4",
 * "x8", "x16") and the format suffix ("_e2m3" or "_e3m2").
 *
 * \param type An FP6 data type.
 * \return The type name string.
 * \note Logs a fatal error for unsupported lane counts or type codes.
 */
std::string GetFP6Type(DataType type) {
  const int32_t lanes = type.lanes();
  std::string vec;
  if (!type.is_scalar()) {
    switch (lanes) {
    case 2:
    case 4:
    case 8:
    case 16:
      vec = "x" + std::to_string(lanes);
      break;
    default:
      // The message now lists every width accepted above (it used to say
      // "(2, 4)" only).
      LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, "
                    "16) for FP6";
    }
  }

  std::string suffix;
  if (type.code() == DataType::kFloat6_e2m3fn) {
    suffix = "_e2m3";
  } else if (type.code() == DataType::kFloat6_e3m2fn) {
    suffix = "_e3m2";
  } else {
    LOG(FATAL) << "Unsupported FP6 type in CUDA codegen";
  }

  std::stringstream stream;
  stream << "__nv_fp6" << vec << suffix;
  return stream.str();
}

/*!
 * \brief Build the FP4 type name printed for \p type.
 *
 * The name is "__nv_fp4" followed by an optional vector marker ("x2", "x4",
 * "x8", "x16") and the format suffix "_e2m1".
 *
 * \param type An FP4 data type.
 * \return The type name string.
 * \note Logs a fatal error for unsupported lane counts or type codes.
 */
std::string GetFP4Type(DataType type) {
  const int32_t lanes = type.lanes();
  std::string vec;
  if (!type.is_scalar()) {
    switch (lanes) {
    case 2:
    case 4:
    case 8:
    case 16:
      vec = "x" + std::to_string(lanes);
      break;
    default:
      // The message now lists every width accepted above (it used to say
      // "(2, 4)" only).
      LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, "
                    "16) for FP4";
    }
  }

  std::string suffix;
  if (type.code() == DataType::kFloat4_e2m1fn) {
    suffix = "_e2m1";
  } else {
    LOG(FATAL) << "Unsupported FP4 type in CUDA codegen";
  }

  std::stringstream stream;
  stream << "__nv_fp4" << vec << suffix;
  return stream.str();
}

203
204
/*!
 * \brief Initialise the restrict keyword and reserve the identifiers used by
 *        the global barrier.
 */
CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
  restrict_keyword_ = "__restrict__";

  // Request the barrier-state symbol from the name supply and verify that it
  // was handed back unchanged.
  vid_global_barrier_state_ =
      name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
  ICHECK_EQ(vid_global_barrier_state_,
            runtime::symbol::tvm_global_barrier_state);

  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
}
211

212
213
214
/*!
 * \brief Write the qualifiers that precede a generated function.
 * \param os Stream receiving the prefix.
 */
void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
  static constexpr const char *kFuncPrefix = "extern \"C\" __global__ ";
  os << kFuncPrefix;
}
215
216

class LaunchConfigExtractor : public tir::StmtVisitor {
217
218
private:
  void VisitStmt_(const AttrStmtNode *op) final {
219
220
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
221
222
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
223
        threadIdx_x_ext = op->value;
224
225
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
226
        threadIdx_y_ext = op->value;
227
228
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
229
230
231
232
233
234
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

235
public:
236
237
238
239
240
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

241
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f) {
242
243
244
  LaunchConfigExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
245
246
247
248
249
  PrimExpr threadIdx_ext =
      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                        extractor.threadIdx_z_ext);
  if (const IntImmNode *const threadIdx_ext_int =
          threadIdx_ext.as<IntImmNode>()) {
250
    if (threadIdx_ext_int->value == 1) {
251
252
      // unable to extract the number of threads per block, hence directly
      // return
253
254
      return;
    }
255
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ", 1)";
256
257
258
259
260
261
262
  }
}

std::string CodeGenTileLangCUDA::Finish() {
  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }
263
264
265
266
267
268
269
270
  if (enable_fp8_) {
    decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
  }

  if (need_math_constants_h_) {
    decl_stream << "#include <math_constants.h>\n";
  }

271
272
273
274
  if (need_cooperative_groups_) {
    decl_stream << "#include <cooperative_groups.h>\n";
  }

275
  decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
276
277
278
  if (enable_sparse_gemm_) {
    decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
  }
279
280
281
282
  decl_stream << "#include <tl_templates/cuda/copy.h>\n";
  decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
  decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
  decl_stream << "#include <tl_templates/cuda/threadblock_swizzle.h>\n";
283
  decl_stream << "#include <tl_templates/cuda/debug.h>\n";
284
285
286
  decl_stream << "#ifdef ENABLE_BF16\n";
  decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
  decl_stream << "#endif\n";
287
288

  if (need_global_barrier_) {
289
290
    decl_stream << "__device__ unsigned " << vid_global_barrier_state_
                << " = 0;\n";
291
  }
292
  decl_stream << "\n";
293

294
295
296
  return CodeGenC::Finish();
}

297
/*!
 * \brief Emit a C-style `for` loop for \p op.
 *
 * The loop runs from `min` up to (but excluding) the simplified value of
 * `min + extent`. Loops of kind kUnrolled are preceded by `#pragma unroll`.
 *
 * \param op The loop node to print.
 */
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
  const bool is_unrolled = (op->kind == tir::ForKind::kUnrolled);
  if (is_unrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }

  // Exclusive upper bound of the loop variable.
  arith::Analyzer analyzer;
  const std::string upper_bound =
      PrintExpr(analyzer.Simplify(op->extent + op->min));

  PrintIndent();
  const std::string loop_var = AllocVarID(op->loop_var.get());
  const std::string lower_bound = PrintExpr(op->min);

  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
  stream << ' ' << loop_var << " = " << lower_bound << "; ";
  stream << loop_var << " < " << upper_bound << "; ";
  stream << "++" << loop_var << ") {\n";

  const int body_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(body_scope);

  PrintIndent();
  stream << "}\n";
}

318
void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) {
319
  ICHECK(!var_idmap_.count(iv->var.get()));
320
321
  var_idmap_[iv->var.get()] =
      CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
322
323
}

324
/*!
 * \brief Print the type name used to represent \p t in the generated code.
 *
 * Scalars map to a named scalar type. Vectors map either to a native vector
 * type (e.g. "float4") or to a wider integer vector that packs the lanes
 * (e.g. fp16x4 is printed as "uint2"). Visiting fp16/bf16/fp8/fp6/fp4 and
 * packed int8 types also sets the matching enable_* flag.
 *
 * Any type without a representation ends in a fatal error. This includes FP6
 * and FP4 vectors with more than four lanes, which previously returned
 * without printing anything.
 *
 * \param t The data type to print.
 * \param os The stream receiving the type name.
 */
void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
    os << "void*";
    return;
  }

  if (t.is_void()) {
    os << "void";
    return;
  }

  if (t == tl::cuTensorMapType()) {
    os << "CUtensorMap";
    return;
  }

  // Set when the type cannot be printed; leads to the fatal error at the end.
  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
    case 16:
      enable_fp16_ = true;
      if (t.is_scalar()) {
        os << "half_t";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp16 vector elements.
        //
        // half4 is stored as uint2
        //
        // h4.x is emitted as *(half2*)(&(u2.x)).x
        // h4.y is emitted as *(half2*)(&(u2.x)).y
        // h4.z is emitted as *(half2*)(&(u2.y)).x
        // h4.w is emitted as *(half2*)(&(u2.y)).y
        //
        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
        os << "uint" << lanes / 2;
      } else if (lanes <= 16) {
        ICHECK_EQ(lanes % 4, 0) << "only support (mod 4 = 0) lanes for half "
                                   "type of more than 8 lanes";
        os << "ulonglong" << lanes / 4;
      } else {
        fail = true;
      }
      break;
    case 32:
      if (lanes <= 4) {
        os << "float";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
        //
        // float8 is stored as ulonglong4
        //
        // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
        // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for float type with lanes > 4";
        os << "ulonglong" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
    }
    // Scalars and every fp16 form are already complete.
    if (!fail && (t.is_scalar() || t.bits() == 16))
      return;
    // Packed fp32 vectors (5..8 lanes) are complete as well.
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
      return;
    // Native 2..4 lane vectors: append the lane count to the scalar name.
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t.is_bfloat16()) {
    enable_bf16_ = true;
    if (t.is_scalar()) {
      os << "bfloat16_t";
    } else if (lanes <= 8) {
      ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
      os << "uint" << lanes / 2;
    } else if (lanes <= 16) {
      ICHECK_EQ(lanes % 4, 0) << "only support (mod 4 = 0) lanes for half type "
                                 "of more than 8 lanes";
      os << "ulonglong" << lanes / 4;
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t.is_float8()) {
    enable_fp8_ = true;
    os << GetFP8Type(t);
    return;
  } else if (t.is_float6()) {
    enable_fp6_ = true;
    if (t.lanes() <= 4) {
      os << GetFP6Type(t);
      return;
    }
    // Wider FP6 vectors have no representation here: fall through to the
    // fatal error instead of silently printing no type name.
  } else if (t.is_float4()) {
    enable_fp4_ = true;
    if (t.lanes() <= 4) {
      os << GetFP4Type(t);
      return;
    }
    // Wider FP4 vectors have no representation here: fall through to the
    // fatal error instead of silently printing no type name.
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    // Unsigned types reuse the signed name with a "u" prefix.
    if (t.is_uint()) {
      os << "u";
    }
    switch (t.bits()) {
    case 1: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int8_t";
        return;
      } else if (t.lanes() == 16) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 32) {
        os << "int";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 4: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 4) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 8) {
        // directly 8 4-bit int in integer.
        os << "int";
        return;
      } else if (t.lanes() == 16) {
        os << "int2";
        return;
      } else if (t.lanes() == 32) {
        os << "int4";
        return;
      } else if (t.lanes() == 64) {
        os << "int8";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 8: {
      if (t.lanes() == 4) {
        // directly 4 8 bit int in integer.
        enable_int8_ = true;

        // We use int for int8x4 instead of char4 because using char4 is
        // likely to produce extra instructions to pack four int8 elements
        // into 32-bit data.
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        enable_int8_ = true;
        os << "int2";
        return;
      } else if (t.lanes() == 16) {
        enable_int8_ = true;
        os << "int4";
        return;
      } else if (t.lanes() == 32) {
        enable_int8_ = true;
        os << "longlong4";
        return;
      } else if (!t.is_uint() && t.is_scalar()) {
        os << "signed char";
        break;
      } else {
        // Scalar uint8 and 2/3-lane vectors; the lane suffix (if any) is
        // appended after the switch.
        os << "char";
        break;
      }
    }
    case 16: {
      if (t.is_scalar()) {
        os << "short";
      } else if (t.lanes() <= 4) {
        os << "short" << lanes;
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int16 vector elements.
        //
        // short4 is stored as int2
        //
        // s4.x is emitted as *(short2*)(&(i2.x)).x
        // s4.y is emitted as *(short2*)(&(i2.x)).y
        // s4.z is emitted as *(short2*)(&(i2.y)).x
        // s4.w is emitted as *(short2*)(&(i2.y)).y
        //
        ICHECK_EQ(t.lanes() % 2, 0)
            << "only support even lane for shorT type with lanes > 4";
        os << "int" << t.lanes() / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 32: {
      if (t.is_scalar()) {
        os << "int";
      } else if (t.lanes() <= 4) {
        os << "int" << t.lanes();
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
        //
        // int8 is stored as longlong4
        //
        // i8.v1 is emitted as *(int2*)(&(l4.x)).x
        // i8.v2 is emitted as *(int2*)(&(l4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for int32 type with lanes > 4";
        os << "longlong" << lanes / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 64: {
      if (t.is_scalar()) {
        os << "int64_t";
      } else if (t.lanes() == 2) {
        os << "longlong2";
      } else if (t.lanes() == 3) {
        os << "longlong3";
      } else if (t.lanes() == 4) {
        os << "longlong4";
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    default:
      fail = true;
      break;
    }
    if (!fail && lanes == 1) {
      return;
    }
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

603
604
605
void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t,
                                           PrimExpr lhs, PrimExpr rhs,
                                           std::ostream &os) { // NOLINT(*)
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  int ssa_scope = BeginScope();
  {
    // Unpack into individual ops.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  EndScope(ssa_scope);
  os << sret;
}

639
640
641
void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
                                           int i,
                                           std::ostream &os) { // NOLINT(*)
642
643
644
645
646
647
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
648
  ICHECK(i >= 0 && i < 256 / t.bits());
649
650
651
652
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
653
    } else if (t.lanes() <= 16) {
654
655
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
656
657
658
659
    } else {
      ICHECK(t.lanes() == 32);
      std::string ac = vec + "." + access[i / 8];
      os << "((" << type_name << ")(" << ac << " >> " << i % 8 * 8 << "))";
660
661
    }
  } else if (t.is_float16()) {
662
663
664
665
666
667
668
    if (t.lanes() <= 8) {
      os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
         << access[i % 2];
    } else {
      os << "(((half2*)(&(" << vec << "." << access[i / 4] << "))) + "
         << (i / 2 % 2) << ")->" << access[i % 2];
    }
669
  } else if (t.is_bfloat16()) {
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
    if (t.lanes() <= 8) {
      os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
         << access[i % 2];
    } else {
      os << "(((nv_bfloat162*)(&(" << vec << "." << access[i / 4] << "))) + "
         << (i / 2 % 2) << ")->" << access[i % 2];
    }
  } else if (t.is_float8()) {
    os << vec;
    // fp8_e5_32_t
    if (t.lanes() >= 32)
      os << "." << access[i / 16];
    // fp8_e5_16_t
    if (t.lanes() >= 16)
      os << "." << access[(i % 16) / 8];
    // fp8_e5_8_t
    if (t.lanes() >= 8)
      os << "." << access[(i % 8) / 4];
    // fp8_e5_4_t or fp8_e5_2_t
    os << "." << access[i % 4];
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
708
709
    os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
       << ")))->" << access[i % 2];
710
711
712
713
714
  } else {
    os << vec << "." << access[i];
  }
}

715
716
void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
                                            int i, const std::string &value) {
717
718
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
719
  ICHECK(i >= 0 && i < 256 / t.bits());
720
721
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
722
723
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
724
    } else if (t.lanes() <= 16) {
725
726
727
728
729
730
731
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      stream << ac << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
      }
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
732
733
734
735
736
737
738
739
740
    } else {
      ICHECK(t.lanes() == 32);
      std::string ac = vec + "." + access[i / 8];
      stream << ac << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << ac << " & ~(0x000000ff << " << i % 8 * 8 << ") |";
      }
      stream << "(" << value << " << " << i % 8 * 8 << ");\n";
741
742
    }
  } else if (t.is_float16()) {
743
744
745
746
747
748
749
750
    if (t.lanes() <= 8) {
      stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
             << access[i % 2] << " = " << value << ";\n";
    } else {
      stream << "(((half2*)(&(" << vec << "." << access[i / 4] << "))) + "
             << (i / 2 % 2) << ")->" << access[i % 2] << " = " << value
             << ";\n";
    }
751
  } else if (t.is_bfloat16()) {
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
    if (t.lanes() <= 8) {
      stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
             << access[i % 2] << " = " << value << ";\n";
    } else {
      stream << "(((nv_bfloat162*)(&(" << vec << "." << access[i / 4]
             << "))) + " << (i / 2 % 2) << ")->" << access[i % 2] << " = "
             << value << ";\n";
    }
  } else if (t.is_float8()) {
    stream << vec;
    // fp8_e5_32_t
    if (t.lanes() >= 32)
      stream << "." << access[i / 16];
    // fp8_e5_16_t
    if (t.lanes() >= 16)
      stream << "." << access[(i % 16) / 8];
    // fp8_e5_8_t
    if (t.lanes() >= 8)
      stream << "." << access[(i % 8) / 4];
    // fp8_e5_4_t or fp8_e5_2_t
    stream << "." << access[i % 4] << " = " << value << ";\n";
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
791
792
    stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << " = " << value << ";\n";
793
794
795
796
797
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

798
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
799
800
  auto args = op->args;
  const std::string &sync = args[0].as<StringImmNode>()->value;
801
802
803
804
  if (sync == "warp") {
    // DO nothing.
  } else if (sync == "shared" || sync == "shared.dyn") {
    this->PrintIndent();
805
806
807
808
809
810
811
812
813
814
815
816
817
818
    if (args.size() == 1) {
      this->stream << "__syncthreads();\n";
    } else if (args.size() == 2) {
      auto barrier_id = args[1].as<IntImmNode>()->value;
      this->stream << "tl::__sync_thread_partial<" << barrier_id << ">();\n";
    } else if (args.size() == 3) {
      auto barrier_id = args[1].as<IntImmNode>()->value;
      auto thread_count = args[2].as<IntImmNode>()->value;
      this->stream << "tl::__sync_thread_partial<" << barrier_id << ", "
                   << thread_count << ">();\n";
    } else {
      LOG(FATAL) << "Invalid number of arguments for storage sync: "
                 << args.size();
    }
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
  } else if (sync == "global") {
    if (!need_global_barrier_) {
      need_global_barrier_ = true;
    }
    // global synchronizer
    std::string is_load = PrintExpr(op->args[1]);
    std::string num_blocks = PrintExpr(op->args[2]);
    this->PrintIndent();
    // In theory only threadfence is needed
    // but we observed problems with only threadfence
    this->stream << "__threadfence_system();\n";
    this->PrintIndent();
    this->stream << "if (" << is_load << ") {\n";
    int wb = this->BeginScope();
    this->PrintIndent();
    this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
    this->PrintIndent();
    std::string ptr = name_supply_->FreshName("pf");
    this->stream << "volatile unsigned* " << ptr << " = &"
                 << vid_global_barrier_state_ << ";\n";
    this->PrintIndent();
    this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
    this->PrintIndent();
    this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_
                 << ");\n";
    this->EndScope(wb);
    this->PrintIndent();
    this->stream << "}\n";
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
849
850
851
  }
}

852
853
854
855
856
/*!
 * \brief Print the storage qualifier for an allocation in \p scope.
 *
 * "shared" and "shared.barrier" print `__shared__ `, "shared.dyn" prints an
 * extern shared declaration aligned to 1024, and other scopes print
 * nothing. The "global" scope is rejected.
 *
 * \param scope Storage scope name.
 * \param os Stream receiving the qualifier.
 */
void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
                                            std::ostream &os) { // NOLINT(*)
  ICHECK_NE(scope, "global")
      << "Cannot allocate global memory when targeting CUDA. You must pass "
         "all global arrays as input instead";

  const bool is_static_shared =
      (scope == "shared") || (scope == "shared.barrier");
  if (is_static_shared) {
    os << "__shared__ ";
    return;
  }
  if (scope == "shared.dyn") {
    os << "extern __shared__ __align__(1024) ";
  }
}

864
865
866
867
std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from,
                                            DataType target) {
  if (from == target)
    return value;
868
869
870
871
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
872
873
  if (from.is_float16() && (target.is_int() || target.is_uint()) &&
      target.bits() == 8) {
874
875
876
877
878
879
    os << "(";
    if (target.is_uint()) {
      os << "u";
    }
    os << "int)";
  }
880
881
882
  if ((from.is_float16() || from.is_bfloat16()) && target.is_float8()) {
    os << "(float)";
  }
883
884
885
886
  os << value << ")";
  return os.str();
}

887
void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
888
889
890
891
892
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
893
894
  if (from_ty.is_scalar())
    return CodeGenC::VisitExpr_(op, os);
895
896
897
898
899
900
901

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
902
903
  std::string src = SSAGetID(PrintExpr(op->value), from_ty);

904
905
906
907
908
  // Handle conversion between float16 and float32
  if (from_ty.is_float16() && target_ty.is_float()) {
    // Use __half22float2 for vectorized conversion (half2 -> float2)
    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
      // half2 -> float2
909
      PrintIndent();
910
911
912
913
914
915
916
917
918
919
920
921
922
      stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n";
      os << sret;
      return;
    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
      // half4 -> float4
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[0] = "
             << "__half22float2(*(half2*)(&(" << src << ")));\n";
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[1] = "
             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
      os << sret;
      return;
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
      // half8 -> float8
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[0] = "
             << "__half22float2(*(half2*)(&(" << src << ")));\n";
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[1] = "
             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[2] = "
             << "__half22float2(*((half2*)(&(" << src << "))+2));\n";
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[3] = "
             << "__half22float2(*((half2*)(&(" << src << "))+3));\n";
      os << sret;
      return;
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
    }
  } else if (from_ty.is_float() && target_ty.is_float16()) {
    // Use __float22half2_rn for vectorized conversion (float2 -> half2)
    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
      // float2 -> half2
      PrintIndent();
      stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&("
             << src << ")));\n";
      os << sret;
      return;
    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
      // float4 -> half4
      PrintIndent();
      stream << "((half2*)(&" << sret << "))[0] = "
             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
      PrintIndent();
      stream << "((half2*)(&" << sret << "))[1] = "
             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
      os << sret;
      return;
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
      // float8 -> half8
      PrintIndent();
      stream << "((half2*)(&" << sret << "))[0] = "
             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
      PrintIndent();
      stream << "((half2*)(&" << sret << "))[1] = "
             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
      PrintIndent();
      stream << "((half2*)(&" << sret << "))[2] = "
             << "__float22half2_rn(*((float2*)(&(" << src << "))+2));\n";
      PrintIndent();
      stream << "((half2*)(&" << sret << "))[3] = "
             << "__float22half2_rn(*((float2*)(&(" << src << "))+3));\n";
      os << sret;
      return;
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
    }
  }

  // Handle conversion between bfloat16 and float32
  if (from_ty.is_bfloat16() && target_ty.is_float()) {
    // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2)
    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
      // bfloat162 -> float2
      PrintIndent();
      stream << sret
             << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
             << src << ")));\n";
      os << sret;
      return;
    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
      // bfloat162x2 -> float4
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[0] = "
             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
             << src << ")));\n";
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[1] = "
             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
             << src << "))+1));\n";
      os << sret;
      return;
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
      // bfloat162x4 -> float8
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[0] = "
             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
             << src << ")));\n";
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[1] = "
             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
             << src << "))+1));\n";
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[2] = "
             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
             << src << "))+2));\n";
      PrintIndent();
      stream << "((float2*)(&" << sret << "))[3] = "
             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
             << src << "))+3));\n";
      os << sret;
      return;
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
    }
  } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
    // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
      // float2 -> bfloat162
      PrintIndent();
      stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret
             << ")) = __float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
      os << sret;
      return;
    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
      // float4 -> bfloat162x2
      PrintIndent();
      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
      PrintIndent();
      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
      os << sret;
      return;
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
      // float8 -> bfloat162x4
      PrintIndent();
      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
      PrintIndent();
      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
      PrintIndent();
      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[2] = "
             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+2));\n";
      PrintIndent();
      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[3] = "
             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+3));\n";
      os << sret;
      return;
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
    }
  }

  // Handle conversion from float32 to float8 (E4M3/E5M2)
  if (from_ty.is_float() &&
      (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) {
    // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion
    // (float2 -> fp8x2)
    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
      // float2 -> fp8x2
      PrintIndent();
      stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret
             << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&("
             << src << ")), __NV_SATFINITE, "
             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
             << ");\n";
      os << sret;
      return;
    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
      // float4 -> fp8x4
      PrintIndent();
      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
             << ")), __NV_SATFINITE, "
             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
             << ");\n";
      PrintIndent();
      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
             << "))+1), __NV_SATFINITE, "
             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
             << ");\n";
1089
1090
      os << sret;
      return;
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
      // float8 -> fp8x8
      PrintIndent();
      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
             << ")), __NV_SATFINITE, "
             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
             << ");\n";
      PrintIndent();
      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
             << "))+1), __NV_SATFINITE, "
             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
             << ");\n";
      PrintIndent();
      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = "
             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
             << "))+2), __NV_SATFINITE, "
             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
             << ");\n";
      PrintIndent();
      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = "
             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
             << "))+3), __NV_SATFINITE, "
             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
             << ");\n";
      os << sret;
      return;
1119
1120
    }
  }
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132

  // Fallback: elementwise cast
  for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
    std::ostringstream val;
    val << "(";
    PrintType(target_ty.element_of(), val);
    val << ")(";
    PrintVecElemLoad(src, from_ty, i, val);
    val << ")";
    PrintVecElemStore(sret, target_ty, i, val.str());
  }

1133
1134
1135
  os << sret;
}

1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) {
  // TODO(wt): Consider vectorized reduction and impl for other dtypes
  DataType t = op->dtype;

  // Standard min/max functions don't support bfloat16 or float16
  if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
    os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
       << ")";
    return;
  }

  // For float32 and float64 scalar, use standard min functions
  if (t.is_float() && t.is_scalar()) {
    if (t.bits() == 32 || t.bits() == 64) {
      os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
      return;
    }
  }

  // For all other scalar types (int, uint), use default implementation
  CodeGenC::VisitExpr_(op, os);
}

void CodeGenTileLangCUDA::VisitExpr_(const MaxNode *op, std::ostream &os) {
  // TODO(wt): Consider vectorized reduction and impl for other dtypes
  DataType t = op->dtype;

  // Standard min/max functions don't support bfloat16 or float16
  if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
    os << "cutlass::fast_max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
       << ")";
    return;
  }

  // For float32 and float64 scalar, use standard max functions
  if (t.is_float() && t.is_scalar()) {
    if (t.bits() == 32 || t.bits() == 64) {
      os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
      return;
    }
  }

  // For all other scalar types (int, uint), use default implementation
  CodeGenC::VisitExpr_(op, os);
}

1182
1183
1184
1185
/**
 * @brief Print a call to an external function, scalarising vector results.
 *
 * When the runtime return type is not a fixed-length vector the call is
 * forwarded to CodeGenC unchanged. Otherwise the call is emitted lane by
 * lane. For example
 *
 *   v = intrin_f((float4*)A[0], (float4*)B[0])
 *
 * becomes
 *
 *   float4 __ret;
 *   {
 *     float4 __arg0 = ((float4*)A)[0];
 *     float4 __arg1 = ((float4*)B)[0];
 *     __ret.x = intrin_f(__arg0.x, __arg1.x);
 *     __ret.y = intrin_f(__arg0.y, __arg1.y);
 *     __ret.z = intrin_f(__arg0.z, __arg1.z);
 *     __ret.w = intrin_f(__arg0.w, __arg1.w);
 *   }
 *   v = __ret;
 *
 * @param ret_type       Return type of the call.
 * @param global_symbol  Name of the function to call.
 * @param args           Call arguments.
 * @param skip_first_arg When true, args[0] is not passed to the callee.
 * @param os             Receives the expression naming the result.
 */
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
                                          const Array<PrimExpr> &args,
                                          bool skip_first_arg,
                                          std::ostream &os) { // NOLINT(*)
  const DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (!ret_dtype.is_fixed_length_vector()) {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
                              os);
    return;
  }

  // Declare the result vector.
  const std::string result = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(ret_dtype, stream);
  stream << ' ' << result << ";\n";

  // Bind every forwarded argument to an SSA id.
  const size_t first = static_cast<size_t>(skip_first_arg);
  std::vector<std::string> arg_ids;
  for (size_t k = first; k < args.size(); ++k) {
    arg_ids.push_back(SSAGetID(PrintExpr(args[k]), args[k].dtype()));
  }

  // Emit a scalar call for each lane.
  const int lanes = ret_dtype.lanes();
  for (int lane = 0; lane < lanes; ++lane) {
    std::ostringstream call;
    call << global_symbol << "(";
    for (size_t k = 0; k < arg_ids.size(); ++k) {
      if (k != 0) {
        call << ", ";
      }
      PrintVecElemLoad(arg_ids[k], args[first + k].dtype(), lane, call);
    }
    call << ")";
    PrintVecElemStore(result, ret_dtype, lane, call.str());
  }

  os << result;
}

// Print a reference expression to a buffer.
1241
1242
1243
1244
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
                                              const BufferNode *buffer,
                                              PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
  std::ostringstream os;
  std::string vid = GetVarID(buffer_var);
  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }
  // bool is_vol = IsVolatile(buffer_var);
  // always false for tl cutlass backend.
  bool is_vol = false;

  auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
    std::ostringstream ptr_os;
    ptr_os << "(";
    if (is_vol) {
      ptr_os << "volatile ";
    }
    if (!scope.empty() && IsScopePartOfType()) {
      PrintStorageScope(scope, ptr_os);
    }
    PrintType(pointed_to, ptr_os);
    ptr_os << "*)";
    return ptr_os.str();
  };

  DataType buffer_element_dtype = buffer->dtype;

  std::string buffer_str = vid;
  if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
    std::stringstream temp;
    temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
    buffer_str = temp.str();
  }
1277
1278
1279
  if (scope.empty()) {
    scope = GetPtrStorageScope(buffer->data);
  }
1280
  if (scope == "local.var" || scope == "local.descriptor") {
1281
1282
1283
    os << vid;
    return os.str();
  }
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
  std::string index_str = PrintExpr(index);
  if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
    // This is a special case, because CodegenCUDA::PrintType()
    // returns "int" for bool and for 4-bit integers. In most cases,
    // we divide by the number of lanes to determine the index.
    // However, the backing type for scalar int4 and scalar bool is
    // int32.  Therefore, we need to divide by the ratio of their
    // sizes in that case.
    int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();

    os << "*("
       << "(" << ptr_cast(t) << vid << ")"
       << " + " << index_str << " / " << div_factor << ")";
  } else if (t == buffer_element_dtype) {
    os << buffer_str << "[" << index_str << "]";
  } else {
    os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
  }

  return os.str();
}

1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
/**
 * @brief Print a vector load from a buffer.
 *
 * Loads from non-"global" scopes, and loads of at most 128 bits, are
 * delegated to CodeGenC. Wider loads from "global" must be exactly 256 bits
 * and are printed as a `tl::ld_global_256` call on the address of the buffer
 * reference.
 *
 * @param t      Vector type being loaded.
 * @param buffer Buffer to load from.
 * @param base   Index of the first element.
 * @return The load expression as text.
 */
std::string CodeGenTileLangCUDA::GetVecLoad(DataType t,
                                            const BufferNode *buffer,
                                            PrimExpr base) {
  const VarNode *buffer_var = buffer->data.get();

  // Prefer the scope recorded at allocation; otherwise derive it from the
  // pointer type.
  std::string scope = alloc_storage_scope_.count(buffer_var)
                          ? alloc_storage_scope_.at(buffer_var)
                          : std::string();
  if (scope.empty()) {
    scope = GetPtrStorageScope(buffer->data);
  }

  if (scope != "global") {
    return this->CodeGenC::GetVecLoad(t, buffer, base);
  }
  if (t.bits() * t.lanes() <= 128) {
    return this->CodeGenC::GetVecLoad(t, buffer, base);
  }

  ICHECK_EQ(t.bits() * t.lanes(), 256)
      << "Unsupported vector load size: " << t.bits() * t.lanes();
  const std::string element_ref = this->GetBufferRef(t, buffer, base);
  return "tl::ld_global_256(&(" + element_ref + "))";
}

/**
 * @brief Emit a vector store statement to a buffer.
 *
 * Stores to non-"global" scopes, and stores of at most 128 bits, are
 * delegated to CodeGenC. Wider stores to "global" must be exactly 256 bits
 * and are emitted as a `tl::st_global_256` call on the address of the buffer
 * reference.
 *
 * @param buffer Buffer to store into.
 * @param t      Vector type being stored.
 * @param base   Index of the first element.
 * @param value  Already printed expression holding the vector to store.
 */
void CodeGenTileLangCUDA::PrintVecStore(const BufferNode *buffer, DataType t,
                                        PrimExpr base,
                                        const std::string &value) {
  const VarNode *buffer_var = buffer->data.get();
  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }
  if (scope.empty()) {
    scope = GetPtrStorageScope(buffer->data);
  }

  if (scope != "global" || t.bits() * t.lanes() <= 128) {
    this->CodeGenC::PrintVecStore(buffer, t, base, value);
    return;
  }
  // This is the store path: report it as such (the message previously said
  // "load", copied from GetVecLoad).
  ICHECK_EQ(t.bits() * t.lanes(), 256)
      << "Unsupported vector store size: " << t.bits() * t.lanes();
  auto buffer_ref = this->GetBufferRef(t, buffer, base);
  this->PrintIndent();
  this->stream << "tl::st_global_256(&(" << buffer_ref << "), " << value
               << ");\n";
}

1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
/**
 * @brief Emit CUDA/TensorLib-specific code for a call expression.
 *
 * This visitor handles CallNode intrinsics and builtins that require emitting
 * CUDA/TL-specific code (inline PTX/ASM sequences, TensorLanguage runtime
 * calls, WMMA/TMA helpers, barriers, cp.async primitives, index-map based
 * stores, reinterpret/packing helpers, and various mma/ldmatrix patterns). The
 * function writes the generated code to the provided output stream and falls
 * back to the C codegen for unrecognized calls.
 *
 * The method recognizes and emits code for (non-exhaustive): cp.async and its
 * commit/wait variants, tma_load/store and im2col variants, PTX
 * ldmatrix/stmatrix helpers, mbarrier APIs, cooperative grid sync, WMMA/legacy
 * MMA intrinsics (fill/load/store/mma/bmma/ptx_mma/ptx_mma_sp), low-level PTX
 * asm helpers (ldg32, cp_async bulk/init/arrive/wait barriers), reinterpret
 * paths for special small-float encodings (e.g., float4 e2m1fn), tl::tl_gemm
 * and related external calls, and other TL runtime calls.
 *
 * Side effects:
 * - Emits to `os` and the internal codegen output stream.
 * - May set internal feature flags (e.g., need_cooperative_groups_,
 * need_mma_h_, need_cast_smem_ptr_to_int_, enable_sparse_gemm_).
 * - May open/close SSA scopes and mutate internal variable mappings.
 * - May call LOG(FATAL) / CHECK / ICHECK on invalid or unsupported argument
 *   patterns.
 *
 * @param op The call node to generate code for; the function inspects op->op
 *           and op->args to determine the appropriate emission.
 * @param os  Output stream to receive expression-level output when the caller
 *            expects an expression result (some paths write directly to the
 *            member stream instead).
 */
1385
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
1386
1387
  auto print_extern_call_stmt = [&](std::string name, size_t start = 0,
                                    size_t end = 0) {
1388
1389
1390
1391
    // Cache context into a private ss, otherwise the let node may generate
    // within the function call arguments.
    std::ostringstream ss;

1392
1393
    for (size_t i = start; i < op->args.size() - end; i++) {
      if (i > start)
1394
1395
        ss << ", ";
      ss << this->PrintExpr(op->args[i]);
1396
    }
1397
1398
1399
1400

    this->PrintIndent();
    this->stream << name << "(";
    this->stream << ss.str();
1401
1402
    this->stream << ");\n";
  };
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
  auto print_mbarrier_obj = [&](PrimExpr barrier_id) {
    std::ostringstream ss;
    if (barrier_id.as<IntImmNode>()) {
      // in case the barrier_id is an integer, we need to print the barrier_id
      // as an integer
      ss << mbarrier_name_ << "[" << barrier_id << "]";
    } else {
      // otherwise may be a T.get_mbarrier() call or BufferLoad Node
      // we need to print the barrier_id as a string
      ss << this->PrintExpr(barrier_id);
    }
    return ss.str();
  };
1416
1417
1418
1419
1420
1421
  if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
1422
1423
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
1424
1425
    if (op->args.size() == 5) {
      this->PrintIndent();
1426
1427
      this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
                   << dst_offset << ", " << src << "+" << src_offset << ");\n";
1428
1429
1430
    } else {
      std::string condition = this->PrintExpr(op->args[5]);
      this->PrintIndent();
1431
1432
1433
      this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
                   << "+" << dst_offset << ", " << src << "+" << src_offset
                   << ", " << condition << ");\n";
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    print_extern_call_stmt("tl::cp_async_commit");
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
    std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
    print_extern_call_stmt(func_name, 1);
  } else if (op->op.same_as(builtin::create_barriers())) {
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
1444
1445
    auto mbarrier_storage_name = mbarrier_name_ + "_mem";
    this->stream << "__shared__ uint64_t " << mbarrier_storage_name << "["
1446
                 << barrier_count << "];\n";
1447
1448
1449
    this->PrintIndent();
    this->stream << "auto " << mbarrier_name_ << " = reinterpret_cast<"
                 << mbarrier_dtype_ << "*>(" << mbarrier_storage_name << ");\n";
1450
  } else if (op->op.same_as(tl::get_mbarrier())) {
1451
    ICHECK_EQ(op->args.size(), 1);
1452
    std::string barrier_id = this->PrintExpr(op->args[0]);
1453
    os << mbarrier_name_ + "[" + barrier_id + "]";
1454
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
    if (op->args.size() == 1) {
      this->PrintIndent();
      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
      this->stream << mbarrier_obj << ".arrive();\n";
    } else if (op->args.size() == 3) {
      this->PrintIndent();
      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
      auto cta_id = this->PrintExpr(op->args[1]);
      auto pred = this->PrintExpr(op->args[2]);
      this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred
                   << ");\n";
    } else {
      LOG(FATAL) << "Invalid parameter  for tl::arrive_barrier "
                 << op->args.size();
    }
1470
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
1471
1472
1473
1474
1475
    ICHECK_EQ(op->args.size(), 2);
    this->PrintIndent();
    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
    auto arrive_count = this->PrintExpr(op->args[1]);
    this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n";
1476
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
    if (op->args.size() == 2) {
      this->PrintIndent();
      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
      auto transaction_bytes = this->PrintExpr(op->args[1]);
      this->stream << mbarrier_obj << ".arrive_and_expect_tx("
                   << transaction_bytes << ");\n";
    } else if (op->args.size() == 4) {
      this->PrintIndent();
      auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
      auto transaction_bytes = this->PrintExpr(op->args[1]);
      auto cta_id = this->PrintExpr(op->args[2]);
      auto pred = this->PrintExpr(op->args[3]);
      this->stream << mbarrier_obj << ".arrive_and_expect_tx("
                   << transaction_bytes << ", " << cta_id << ", " << pred
                   << ");\n";
    } else {
      LOG(FATAL) << "Invalid parameter  for tl::arrive_barrier_expect_tx "
                 << op->args.size();
    }
1496
1497
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
1498
1499
  } else if (op->op.same_as(tl::ptx_fence_barrier_init())) {
    print_extern_call_stmt("tl::fence_barrier_init");
1500
1501
  } else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) {
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive_noinc");
1502
  } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
1503
1504
1505
1506
1507
1508
    ICHECK_EQ(op->args.size(), 2);
    this->PrintIndent();
    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
    auto transaction_bytes = this->PrintExpr(op->args[1]);
    this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes
                 << ");\n";
1509
  } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
1510
1511
1512
1513
1514
    ICHECK_EQ(op->args.size(), 2);
    this->PrintIndent();
    auto mbarrier_obj = print_mbarrier_obj(op->args[0]);
    auto phase = this->PrintExpr(op->args[1]);
    this->stream << mbarrier_obj << ".wait(" << phase << ");\n";
1515
1516
1517
1518
  } else if (op->op.same_as(tl::ptx_init_tensor_memory())) {
    print_extern_call_stmt("tl::tmem_allocate");
  } else if (op->op.same_as(tl::ptx_deallocate_tensor_memory())) {
    print_extern_call_stmt("tl::tmem_deallocate");
1519
1520
  } else if (op->op.same_as(tl::no_set_max_nreg())) {
    return;
1521
  } else if (op->op.same_as(tl::tma_load())) {
1522
    std::ostringstream ss;
1523
    ICHECK_GE(op->args.size(), 2);
1524
1525
1526
    auto eviction_policy =
        this->eviction_policy_names_
            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
1527
1528
1529
1530
1531
1532
    // Simplify the code by using the default eviction policy
    if (eviction_policy != "EVICT_NORMAL") {
      ss << "tl::tma_load<tl::CacheHintSm90::" << eviction_policy << ">(";
    } else {
      ss << "tl::tma_load(";
    }
1533
    auto desc = op->args[0];
1534
    ss << this->PrintExpr(desc) << ", ";
1535
    ss << print_mbarrier_obj(op->args[1]) << ", ";
1536
    for (size_t i = 2; i < op->args.size() - 1; i++) {
1537
      if (i > 2)
1538
1539
        ss << ", ";
      ss << this->PrintExpr(op->args[i]);
1540
    }
1541
1542
1543
    ss << ");\n";
    this->PrintIndent();
    this->stream << ss.str();
1544
  } else if (op->op.same_as(tl::tma_load_im2col())) {
1545
    std::stringstream ss;
1546
1547
1548
1549
1550
1551
1552
1553
    auto eviction_policy =
        this->eviction_policy_names_
            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
    if (eviction_policy != "EVICT_NORMAL") {
      ss << "tl::tma_load_im2col<tl::CacheHintSm90::" << eviction_policy << ">";
    } else {
      ss << "tl::tma_load_im2col";
    }
1554
    print_extern_call_stmt(ss.str(), 0, 1);
1555
  } else if (op->op.same_as(tl::tma_store())) {
1556
    std::stringstream ss;
1557
1558
1559
1560
1561
    auto need_reduce = op->args[op->args.size() - 2].as<IntImmNode>()->value;
    if (need_reduce) {
      print_extern_call_stmt("tl::tma_store_add", 0, 2);
      return;
    }
1562
1563
1564
1565
1566
1567
1568
1569
    auto eviction_policy =
        this->eviction_policy_names_
            [op->args[op->args.size() - 1].as<IntImmNode>()->value];
    if (eviction_policy != "EVICT_NORMAL") {
      ss << "tl::tma_store<tl::CacheHintSm90::" << eviction_policy << ">";
    } else {
      ss << "tl::tma_store";
    }
1570
    print_extern_call_stmt(ss.str(), 0, 2);
1571
  } else if (op->op.same_as(tl::ptx_ldmatrix())) {
1572
1573
1574
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
1575
1576
    if (trans == 1)
      func_name += "_trans";
1577
    print_extern_call_stmt(func_name, 2);
1578
  } else if (op->op.same_as(tl::ptx_stmatrix())) {
1579
1580
1581
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
1582
1583
    if (trans == 1)
      func_name += "_trans";
1584
    print_extern_call_stmt(func_name, 2);
1585
  } else if (op->op.same_as(tl::fence_proxy_async())) {
1586
    print_extern_call_stmt("tl::fence_proxy_async");
1587
  } else if (op->op.same_as(tl::tma_store_arrive())) {
1588
    print_extern_call_stmt("tl::tma_store_arrive");
1589
  } else if (op->op.same_as(tl::tma_store_wait())) {
1590
    print_extern_call_stmt("tl::tma_store_wait<0>");
1591
1592
1593
1594
1595
1596
1597
1598
1599
  } else if (op->op.same_as(tl::warpgroup_arrive())) {
    print_extern_call_stmt("tl::warpgroup_arrive");
  } else if (op->op.same_as(tl::warpgroup_commit_batch())) {
    print_extern_call_stmt("tl::warpgroup_commit_batch");
  } else if (op->op.same_as(tl::warpgroup_wait())) {
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::warpgroup_wait<" << std::to_string(num_mma)
                 << ">();\n";
1600
  } else if (op->op.same_as(tl::set_max_nreg())) {
1601
1602
1603
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
1604
1605
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
1606
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
1607
  } else if (op->op.same_as(tl::wait_wgmma())) {
1608
1609
1610
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
1611
  } else if (op->op.same_as(tl::pack_b16())) {
1612
1613
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
1614
1615
1616
  } else if (op->op.same_as(tl::sync_grid())) {
    this->need_cooperative_groups_ = true;
    this->PrintIndent();
1617
1618
1619
1620
    this->stream << "cooperative_groups::grid_group grid = "
                    "cooperative_groups::this_grid();\n";
    this->PrintIndent();
    this->stream << "grid.sync();\n";
1621
1622
1623
  } else if (op->op.same_as(tl::loop_break())) {
    this->PrintIndent();
    this->stream << "break;\n";
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
1657
    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::ptx_mma())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp64, ...
    // arg 4: B precision: fp16, fp64, ...
    // arg 5: C precision: fp32, fp64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index
    // arg 12: saturate
    // arg 13: (optional) 1-bit operator (xor or and)
    ICHECK(op->args.size() == 13U || op->args.size() == 14U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    bool saturate = Downcast<Bool>(op->args[12])->value;
1712
1713
1714
1715
1716
    std::string bit_op =
        op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
    std::string asm_code = PrintMMAAssembly(
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
        b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
1717
    this->PrintIndent();
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp32, ...
    // arg 4: B precision: fp16, fp32, ...
    // arg 5: C precision: fp16, fp32, ...
    // arg 6: A multiplicand pointer
    // arg 7: A multiplicand index
    // arg 8: B multiplicand pointer
    // arg 9: B multiplicand index
    // arg 10: C accumulator pointer
    // arg 11: C accumulator index
    // arg 12: metadata
    // arg 13: metadata index
    // arg 14: sparse_selector
    // arg 15: saturate
    ICHECK_EQ(op->args.size(), 16U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_offset = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_offset = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_offset = this->PrintExpr(op->args[11]);
    std::string metadata = this->PrintExpr(op->args[12]);
    std::string metadata_offset = this->PrintExpr(op->args[13]);
    std::string sparse_selector = this->PrintExpr(op->args[14]);
    bool saturate = Downcast<Bool>(op->args[15])->value;
1753
    this->PrintIndent();
1754
    std::string asm_code = PrintMMAAssembly(
1755
1756
1757
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
        b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
        sparse_selector, "", true, saturate);
1758
    this->stream << asm_code;
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
  } else if (op->op.same_as(tl::ptx_wgmma_ss())) {
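    // arg 0: shape string, parsed by ParseMMAShape into (M, N, K)
    // arg 1: a_is_k_major (Bool); tnspA is emitted as its negation
    // arg 2: b_is_k_major (Bool); tnspB is emitted as its negation
    // arg 3: A_dtype
    // arg 4: B_dtype
    // arg 5: C_dtype
    // arg 6: A descriptor
    // arg 7: A offset
    // arg 8: B descriptor
    // arg 9: B offset
    // arg 10: C accumulator
    // arg 11: C accumulator offset
    // arg 12: scale_out (Bool)
    // arg 13: scale_in_a (Bool)
    // arg 14: scale_in_b (Bool)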
    ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_ss args is " << op->args;
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    bool a_is_k_major = Downcast<Bool>(op->args[1])->value;
    bool b_is_k_major = Downcast<Bool>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_desc = this->PrintExpr(op->args[6]);
    std::string A_offset = this->PrintExpr(op->args[7]);
    std::string b_desc = this->PrintExpr(op->args[8]);
    std::string B_offset = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_offset = this->PrintExpr(op->args[11]);
    bool scale_out = Downcast<Bool>(op->args[12])->value;
    bool scale_in_a = Downcast<Bool>(op->args[13])->value;
    bool scale_in_b = Downcast<Bool>(op->args[14])->value;

    const bool a_is_shared = true;
    this->PrintIndent();
    std::string asm_code = PrintWGMMAAssembly(
        shape, a_is_k_major, b_is_k_major, A_dtype, B_dtype, C_dtype, a_desc,
        A_offset, b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a,
        scale_in_b, a_is_shared, "", "", "", false);
    auto [m, n, k] = tl::codegen::ptx::ParseMMAShape(shape);
    std::string wgmma_asm_code =
        "tl::wgmma_ss<(AType), (BType), (CType), (M), (N), (K), (tnspA), "
        "(tnspB), (scaleA), (scaleB)>(uint64_t((desc_a) + (A_offset)), "
        "uint64_t((desc_b) + (B_offset)), ((uint32_t*)((C))), (scale_out));\n";
    // replace patterns
    tl::codegen::Replacer replacer;
    replacer.register_rule("(AType)",
                           tl::codegen::ptx::DTypeEnumToString(A_dtype));
    replacer.register_rule("(BType)",
                           tl::codegen::ptx::DTypeEnumToString(B_dtype));
    replacer.register_rule("(CType)",
                           tl::codegen::ptx::DTypeEnumToString(C_dtype));
    replacer.register_rule("(M)", std::to_string(m));
    replacer.register_rule("(N)", std::to_string(n));
    replacer.register_rule("(K)", std::to_string(k));
    replacer.register_rule("(tnspA)", a_is_k_major ? "false" : "true");
    replacer.register_rule("(tnspB)", b_is_k_major ? "false" : "true");
    replacer.register_rule("(scaleA)", scale_in_a ? "1" : "-1");
    replacer.register_rule("(scaleB)", scale_in_b ? "1" : "-1");
    replacer.register_rule("(desc_a)", a_desc);
    replacer.register_rule("(A_offset)", A_offset);
    replacer.register_rule("(desc_b)", b_desc);
    replacer.register_rule("(B_offset)", B_offset);
    replacer.register_rule("(C)", c_ref + " + " + c_offset);
    replacer.register_rule("(scale_out)", scale_out ? "true" : "false");
    wgmma_asm_code = replacer.rewrite(wgmma_asm_code);
    this->stream << wgmma_asm_code;
  } else if (op->op.same_as(tl::ptx_wgmma_rs())) {
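    // arg 0: shape string
    // arg 1: A_layout (Bool)
    // arg 2: B_layout (Bool)
    // arg 3: A_dtype
    // arg 4: B_dtype
    // arg 5: C_dtype
    // arg 6: A multiplicand (printed as a_ref; a_is_shared is false here)
    // arg 7: A offset
    // arg 8: B descriptor
    // arg 9: B offset
    // arg 10: C accumulator
    // arg 11: C accumulator offset
    // arg 12: scale_out (Bool)
    // arg 13: scale_in_a (Bool)
    // arg 14: scale_in_b (Bool)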
    ICHECK_EQ(op->args.size(), 15U) << "ptx_wgmma_rs args is " << op->args;
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    bool A_layout = Downcast<Bool>(op->args[1])->value;
    bool B_layout = Downcast<Bool>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string A_offset = this->PrintExpr(op->args[7]);
    std::string b_desc = this->PrintExpr(op->args[8]);
    std::string B_offset = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_offset = this->PrintExpr(op->args[11]);
    bool scale_out = Downcast<Bool>(op->args[12])->value;
    bool scale_in_a = Downcast<Bool>(op->args[13])->value;
    bool scale_in_b = Downcast<Bool>(op->args[14])->value;

    const bool a_is_shared = false;
    this->PrintIndent();
    std::string asm_code = PrintWGMMAAssembly(
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, A_offset,
        b_desc, B_offset, c_ref, c_offset, scale_out, scale_in_a, scale_in_b,
        a_is_shared, "", "", "", false);
    this->stream << asm_code;
1858
1859
1860
1861
1862
1863
1864
  } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
    // arg 0: whether the matrix is loaded in column major format or not.
    // arg 1: number of matrices to load.
    // arg 2: The data type in the matrix, .b16 is the only accepted data type.
    // arg 3: pointer to local buffer.
    // arg 4: The offset of the element to store in the local buffer.
    // arg 5: pointer to the shared memory buffer to load.
1865
1866
    // arg 6: The offset of the start element of the row to load in shared
    // memory.
1867
1868
1869
1870
1871
1872
1873
1874
    ICHECK_EQ(op->args.size(), 7U);
    bool trans = Downcast<Bool>(op->args[0])->value;
    int num = Downcast<Integer>(op->args[1])->value;
    std::string type = Downcast<StringImm>(op->args[2])->value;
    std::string local_ptr = this->PrintExpr(op->args[3]);
    std::string local_elem_offset = this->PrintExpr(op->args[4]);
    std::string smem_ptr = this->PrintExpr(op->args[5]);
    if (trans && op->dtype.bits() == 8) {
1875
1876
      // Since ldmatrix assumes that a matrix element is 16 bit, it cannot
      // properly transpose an int8 matrix.
1877
1878
1879
1880
      std::string smem_stride = this->PrintExpr(op->args[6]);
      ICHECK(num == 4);
      os << "for (int i = 0; i < 16; ++i) {\n";
      os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
1881
1882
1883
1884
         << "[(i % 8) / 4 * " + smem_stride +
                " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
                "+ (i % 4) * " + smem_stride +
                " + threadIdx.x / 4 +  (i / 8) * 8];\n";
1885
1886
1887
      os << "}\n";
    } else {
      std::string smem_elem_offset = this->PrintExpr(op->args[6]);
1888
1889
1890
1891
1892
1893
      std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
      if (trans == 1)
        func_name += "_trans";
      this->PrintIndent();
      this->stream << func_name << "(" << smem_ptr << " + " << smem_elem_offset
                   << ", " << local_ptr << " + " << local_elem_offset << ");\n";
1894
1895
1896
1897
1898
1899
1900
1901
1902
    }
  } else if (op->op.same_as(builtin::mma_store())) {
    int m = Downcast<Integer>(op->args[0])->value;
    int n = Downcast<Integer>(op->args[1])->value;
    std::string dst = this->PrintExpr(op->args[2]);
    std::string src = this->PrintExpr(op->args[3]);
    std::string src_offset = this->PrintExpr(op->args[4]);
    PrimExpr stride = op->args[5];

1903
1904
    ICHECK(m == 16 && n == 16)
        << "Only m == 16 && n == 16 case supported for now";
1905

1906
1907
1908
1909
1910
    // Each thread in a warp holds a certain number of elements of an MMA
    // output. For example, if we compute a 16x16 tile using MMA, each thread
    // holds 8 elements in its registers. So conceptually, a warp memory is
    // organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of
    // memory is specified by the index map below.
1911

1912
1913
    // To store the 32x8 output back to a 16x16 tile in shared or global memory,
    // we invert this map to determine the output location for each 8 element.
1914

1915
1916
    const auto index_map_func = ffi::Function::GetGlobal(
        "tir.index_map.shared_16x16_to_mma_32x8_layout");
1917

1918
1919
1920
    IndexMap index_map;
    if (!index_map_func) {
      Var i, j;
1921

1922
      // The index map is defined as follows:
1923
1924
1925
1926
1927
      index_map = IndexMap(
          {i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2),
                   4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)});
    } else {
      index_map = IndexMap::FromFunc(2, *index_map_func);
1928
1929
1930
1931
1932
1933
1934
    }

    arith::Analyzer analyzer;
    auto inverse_index_map =
        index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer);
    auto indices_16x16 = inverse_index_map->final_indices;

1935
1936
1937
    // "//" and "%" in the index map are translated to FloorDiv/Mod, but the
    // plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before
    // they reach codegen, so manually replace them to the plain ones here.
1938
    class LowerFloorDivMod : public ExprMutator {
1939
1940
    public:
      PrimExpr VisitExpr_(const FloorDivNode *op) {
1941
1942
        return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
1943
      PrimExpr VisitExpr_(const FloorModNode *op) {
1944
1945
1946
1947
        return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
    };

1948
1949
    auto dst_ind =
        LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
1950
1951
1952
1953
1954
1955
1956
1957
1958

    var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
    var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
    if (op->dtype.bits() == 16) {
      os << "for (int local_id = 0; local_id < 8; local_id+=2) {\n";
      os << "*((uint *)&" << dst << "[" + this->PrintExpr(dst_ind) + "])"
         << " = "
         << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n";
      os << "}\n";
1959
    } else {
1960
      os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
1961
1962
      os << dst << "[" + this->PrintExpr(dst_ind) + "]" << " = " << src << "["
         << src_offset << " + local_id];\n";
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
      os << "}\n";
    }

  } else if (op->op.same_as(builtin::mma_fill())) {
    std::string num_elem = this->PrintExpr(op->args[0]);
    std::string dst = this->PrintExpr(op->args[1]);
    std::string dst_offset = this->PrintExpr(op->args[2]);

    os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
    os << dst << "[" << dst_offset << " + i] = 0.0;";
    os << "}\n";
  } else if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    need_cast_smem_ptr_to_int_ = true;
1981
1982
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
1983
    if (op->args.size() == 5) {
1984
1985
      this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
                                           size);
1986
    } else {
1987
1988
      this->stream << PrintPredicatedCpAsyncAssembly(
          dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5]));
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
    }
  } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
    need_cast_smem_ptr_to_int_ = true;
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    int barrier_id = Downcast<IntImm>(op->args[5])->value;
    CHECK(barrier_id < barrier_count_);
1999
2000
2001
2002
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size,
                                        barrier);
2003
2004
2005
2006
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
2007
2008
    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
                 << ";\");\n\n";
2009
2010
2011
2012
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
2013
2014
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
2015
2016
2017
2018
2019
2020
    std::string thread_count = this->PrintExpr(op->args[1]);
    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
2021
2022
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
2023
2024
2025
2026
2027
    this->stream << PrintArriveBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
2028
2029
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
2030
2031
2032
2033
2034
2035
    std::string byte_count = this->PrintExpr(op->args[1]);
    this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
  } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
2036
2037
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
    this->stream << PrintWaitBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_ldg32())) {
    /*
    asm volatile (
        "{.reg .pred p;\n"
        " setp.ne.b32 p, %2, 0;\n"
        // " @p ld.global.nc.f32 %0, [%1];}\n"
        " @p ld.global.nc.L2::128B.f32 %0, [%1];}\n"
        : "=f"(reg)
        : "l"(addr), "r"((int)guard)
    );
    */

    // get local
    std::string reg = this->PrintExpr(op->args[0]);
    // get guard
    std::string guard = this->PrintExpr(op->args[1]);
2055
    const BufferLoadNode *addr_buffer = op->args[2].as<BufferLoadNode>();
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
    std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
    std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
    std::string local_addr = this->PrintExpr(op->args[3]);
    this->stream << "asm volatile (\n";
    this->stream << "\"{.reg .pred p;\\n\"\n";
    this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n";
    this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
    this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
    // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
    stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
           << ")\n";
2067
2068
    stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
           << ")), \"r\"((int)" << guard << ")\n";
2069
    stream << ");\n";
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
  } else if (op->op.same_as(builtin::reinterpret())) {
    DataType tgt_dtype = op->dtype;
    DataType src_dtype = op->args[0]->dtype;
    PrimExpr value = op->args[0];

    // Handle float4_e2m1fn reinterpret
    if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) {
      return CodeGenC::VisitExpr_(op, os);
    }
    if (src_dtype == tgt_dtype || tgt_dtype.lanes() * tgt_dtype.bits() ==
                                      src_dtype.lanes() * src_dtype.bits()) {
      return CodeGenC::VisitExpr_(op, os);
    }
    CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes())
        << "E2M1 float4 reinterpret expects source and target to have the same "
           "number of lanes. "
        << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;
    CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes())
        << "E2M1 float4 reinterpret expects source and target to have the same "
           "number of bytes. "
        << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype;

    int lanes = tgt_dtype.lanes();

    int ssa_scope = BeginScope();
    if (lanes == 1) {
      // The case of lane=1 is same as the normal reinterpret,
      // except that we allow the src and dst dtype to have different number of
      // bits.
      std::string rhs = SSAGetID(PrintExpr(value), src_dtype);
      os << "(*(";
      this->PrintType(tgt_dtype, os);
      os << " *)(&(" << rhs << ")))";
    } else if (lanes == 2) {
      if (tgt_dtype.is_float4_e2m1fn()) {
        // We view the source as an uint16, and then extract bits of two fp4
        // numbers, and finally reinterpret the result as fp4x2.
        value =
            tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value});
        tir::Var temp_var("temp_var", DataType::UInt(16));
        value =
            tir::Let(temp_var, value,
                     tir::Cast(DataType::UInt(8),
                               (temp_var & IntImm(DataType::UInt(16), 0xF)) |
                                   ((temp_var >> 4) &
                                    IntImm(DataType::UInt(16), 0xF0))));
      } else {
        value = tir::Cast(
            DataType::UInt(16),
            tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value}));
        tir::Var temp_var("temp_var", DataType::UInt(16));
        value =
            tir::Let(temp_var, value,
                     (temp_var & IntImm(DataType::UInt(16), 0xF)) |
                         ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4));
      }
      os << PrintExpr(
          tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
    } else if (lanes == 4) {
      if (tgt_dtype.is_float4_e2m1fn()) {
        // We view the source as an uint32, and then extract bits of four fp4
        // numbers, and finally reinterpret the result as fp4x4.
        value =
            tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value});
        tir::Var temp_var("temp_var", DataType::UInt(32));
        value = tir::Let(
            temp_var, value,
            tir::Cast(
                DataType::UInt(16),
                (temp_var & IntImm(DataType::UInt(32), 0xF)) |
                    ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) |
                    ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) |
                    ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000))));
      } else {
        value = tir::Cast(DataType::UInt(32),
                          tir::Call(DataType::UInt(16),
                                    tir::builtin::reinterpret(), {value}));
        tir::Var temp_var("temp_var", DataType::UInt(32));
        value = tir::Let(
            temp_var, value,
            (temp_var & IntImm(DataType::UInt(32), 0xF)) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) |
                ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12));
      }
      os << PrintExpr(
          tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value}));
    } else {
      LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: "
                 << lanes;
    }
    EndScope(ssa_scope);
  } else if (op->op.same_as(builtin::thread_return())) {
    os << "return";
2164
2165
2166
2167
2168
  } else if (op->op.same_as(tl::tl_gemm())) {
    ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
                                    "A_ptr, B_ptr, C_ptr>, but got "
                                 << op->args.size();
    auto op_instance = Downcast<StringImm>(op->args[0]);
2169
2170
    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
                          op_instance->value, op->args, true, os);
2171
2172
2173
2174
2175
2176
2177
  } else if (op->op.same_as(tl::tl_gemm_sp())) {
    ICHECK(op->args.size() == 5)
        << "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
           "E_ptr>, but got "
        << op->args.size();
    auto op_instance = Downcast<StringImm>(op->args[0]);
    enable_sparse_gemm_ = true;
2178
2179
    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
                          op_instance->value, op->args, true, os);
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
  } else if (op->op.same_as(tl::get_lane_idx())) {
    ICHECK_LE(op->args.size(), 1)
        << "tl.get_lane_idx expects at most one argument <warp_size>.";
    os << "tl::get_lane_idx(";
    if (!op->args.empty()) {
      os << PrintExpr(op->args[0]);
    }
    os << ")";
  } else if (op->op.same_as(tl::get_warp_idx_sync())) {
    ICHECK_LE(op->args.size(), 1)
        << "tl.get_warp_idx_sync expects at most one argument <warp_size>.";
    os << "tl::get_warp_idx_sync(";
    if (!op->args.empty()) {
      os << PrintExpr(op->args[0]);
    }
    os << ")";
  } else if (op->op.same_as(tl::get_warp_idx())) {
    ICHECK_LE(op->args.size(), 1)
        << "tl.get_warp_idx expects at most one argument <warp_size>.";
    os << "tl::get_warp_idx(";
    if (!op->args.empty()) {
      os << PrintExpr(op->args[0]);
    }
    os << ")";
  } else if (op->op.same_as(tl::get_warp_group_idx())) {
    ICHECK_LE(op->args.size(), 2)
        << "tl.get_warp_group_idx expects <warp_size, warps_per_group>.";
    os << "tl::get_warp_group_idx(";
    for (size_t i = 0; i < op->args.size(); ++i) {
      if (i != 0) {
        os << ", ";
      }
      os << PrintExpr(op->args[i]);
    }
    os << ")";
2215
2216
  } else if (op->op.same_as(tl::tl_shuffle_elect())) {
    os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()";
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
  } else if (op->op.same_as(tl::initialize_descriptor())) {
    ICHECK(op->args.size() == 5)
        << "tl_initialize_descriptor expects 5 arguments but got "
        << op->args.size();
    auto descriptor = op->args[0];
    auto start_address = op->args[1];
    auto layout_type = op->args[2];
    auto leading_byte_offset = op->args[3];
    auto stride_byte_offset = op->args[4];
    os << "tl::initialize_descriptor<" << PrintExpr(layout_type) << ", "
       << PrintExpr(leading_byte_offset) << ", "
       << PrintExpr(stride_byte_offset) << ">(" << PrintExpr(descriptor) << ", "
       << PrintExpr(start_address) << ")";
  } else if (op->op.same_as(tl::increase_descriptor_offset())) {
    ICHECK(op->args.size() == 2)
        << "tl_increase_descriptor_offset expects 2 arguments but got "
        << op->args.size();
    auto descriptor = op->args[0];
    auto offset = op->args[1];
    os << "tl::increase_descriptor_offset<int>(" << PrintExpr(descriptor)
       << ", " << PrintExpr(offset) << ")";
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
  } else if (op->op.same_as(tl::__exp())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "exp");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__exp10())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "exp10");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__log())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "log");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__log2())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "log2");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__log10())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "log10");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__tan())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "tan");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__cos())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "cos");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__sin())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "sin");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
  } else if (op->op.same_as(tl::ieee_add())) {
    CUDAIEEEMath math_func;
    std::string rounding_mode = Downcast<StringImm>(op->args[2])->value;
    std::string func_name = math_func(op->dtype, "fadd", rounding_mode);
    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
       << PrintExpr(op->args[1]) << ")";
  } else if (op->op.same_as(tl::ieee_sub())) {
    CUDAIEEEMath math_func;
    std::string rounding_mode = Downcast<StringImm>(op->args[2])->value;
    std::string func_name = math_func(op->dtype, "fsub", rounding_mode);
    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
       << PrintExpr(op->args[1]) << ")";
  } else if (op->op.same_as(tl::ieee_mul())) {
    CUDAIEEEMath math_func;
    std::string rounding_mode = Downcast<StringImm>(op->args[2])->value;
    std::string func_name = math_func(op->dtype, "fmul", rounding_mode);
    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
       << PrintExpr(op->args[1]) << ")";
  } else if (op->op.same_as(tl::ieee_fmaf())) {
    CUDAIEEEMath math_func;
    std::string rounding_mode = Downcast<StringImm>(op->args[3])->value;
    std::string func_name = math_func(op->dtype, "fmaf", rounding_mode);
    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
       << PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2]) << ")";
  } else if (op->op.same_as(tl::ieee_frcp())) {
    CUDAIEEEMath math_func;
    std::string rounding_mode = Downcast<StringImm>(op->args[1])->value;
    std::string func_name = math_func(op->dtype, "frcp", rounding_mode);
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::ieee_fsqrt())) {
    CUDAIEEEMath math_func;
    std::string rounding_mode = Downcast<StringImm>(op->args[1])->value;
    std::string func_name = math_func(op->dtype, "fsqrt", rounding_mode);
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::ieee_frsqrt())) {
    CUDAIEEEMath math_func;
    std::string func_name = math_func(op->dtype, "frsqrt", "rn");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::ieee_fdiv())) {
    CUDAIEEEMath math_func;
    std::string rounding_mode = Downcast<StringImm>(op->args[2])->value;
    std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
    os << func_name << "(" << PrintExpr(op->args[0]) << ", "
       << PrintExpr(op->args[1]) << ")";
2314
2315
2316
2317
2318
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

2319
void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
2320
  if (op->attr_key == tir::attr::fragment_shape) {
2321
2322
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *shape_str = op->value.as<StringImmNode>();
2323
2324
    fragment_shapes[buffer] = shape_str->value;
  } else if (op->attr_key == tir::attr::fragment_layout) {
2325
2326
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *layout_str = op->value.as<StringImmNode>();
2327
2328
    fragment_layouts[buffer] = layout_str->value;
  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
2329
2330
2331
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
2332
2333
2334
2335
2336
2337
2338
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
    return;
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
2339
2340
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
2341
    auto wait_cnt = wait_attrs.second;
2342
2343
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
2344
2345
2346
2347
2348
2349
2350
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    this->VisitStmt(inner->body);
    return;
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
2351
    const StringImmNode *pattern = op->value.as<StringImmNode>();
2352
2353
2354
2355
2356
2357
2358
2359
    ICHECK(pattern);
    this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
    this->VisitStmt(op->body);
    return;
  }
  CodeGenC::VisitStmt_(op);
}

2360
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
2361
2362
2363
2364
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());
  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
2365
  const VarNode *buffer = op->buffer_var.as<VarNode>();
2366
2367
  if (scope.find("wmma.") == 0) {
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
2368
2369
2370
2371
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) ||
             op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
             op->dtype == DataType::Int(1) || op->dtype == DataType::BFloat(16))
2372
2373
2374
          << "Matrix_a and matrix_b only support half or char or unsigned char "
          << "or uint4 or int4 or int1 type for now";
    } else {
2375
2376
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32))
2377
2378
2379
          << "Accumulator only support half, float and int type for now";
    }
    PrintWmmaScope(scope, op->dtype, buffer, stream);
2380
2381
  } else if (scope == "local.descriptor") {
    stream << "tl::GmmaDescriptor " << vid << ";\n";
2382
  } else {
2383
2384
2385
2386
2387
2388
2389
2390
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }

  if (scope == "shared.dyn") {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
2391
    ICHECK_GT(constant_size, 0)
2392
2393
        << "Can only handle constant size stack allocation for now, but get "
        << constant_size << " for " << op->buffer_var->name_hint;
2394
2395
2396
2397
2398
2399
2400
2401
    if (scope.find("wmma.") == 0) {
      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    }
    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
        scope == "shared") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
2402
2403
    if (scope == "shared") {
      stream << ' ' << vid << '[' << constant_size << "];\n";
2404
2405
2406
2407
2408
2409
    } else if (scope == "shared.barrier") {
      auto v_id_mem = vid + "_mem";
      stream << ' ' << v_id_mem << "[" << constant_size << "];\n";
      PrintIndent();
      stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_
             << "*>(" << v_id_mem << ");\n";
2410
2411
2412
    } else if (scope == "local") {
      stream << ' ' << vid << '[' << constant_size << "];\n";
    } else if (scope == "local.var") {
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
      PrimExpr init = tir::make_const(op->dtype, 0);
      auto init_it = op->annotations.find(tl::attr::kLocalVarInit);
      if (init_it != op->annotations.end()) {
        PrimExpr user_init = Downcast<PrimExpr>((*init_it).second);
        if (!user_init.dtype().is_void() && user_init.dtype() != op->dtype) {
          user_init = tir::Cast(op->dtype, user_init);
        }
        init = user_init;
      }
      stream << ' ' << vid << " = " << PrintExpr(init) << ";\n";
2423
    } else if (scope != "local.descriptor") {
2424
2425
      ICHECK(false) << "Unsupported scope: " << scope;
    }
2426
2427
2428
2429
2430
2431
  }

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
// Lowers Evaluate statements.
//
// Constant-integer evaluations are dropped. Three calls are lowered here:
// the global-barrier initialisation and the two device-side asserts. Every
// other statement is delegated to CodeGenC.
void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
  if (is_const_int(op->value))
    return;
  const CallNode *call = op->value.as<CallNode>();
  if (call == nullptr) {
    CodeGenC::VisitStmt_(op);
    return;
  }

  if (call->op.same_as(builtin::tvm_global_barrier_kinit())) {
    PrintIndent();
    stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
    PrintIndent();
    stream << "if (threadIdx.x == 0) {\n";
    PrintIndent();
    stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
    PrintIndent();
    stream << "}\n";
    // The statement is fully lowered at this point. Without this return the
    // same call would additionally be handed to the generic visitor below.
    return;
  }

  if (call->op.same_as(tvm::tl::device_assert())) {
    std::string cond = PrintExpr(call->args[0]);
    this->PrintIndent();
    stream << "device_assert(" << cond << ");\n";
  } else if (call->op.same_as(tvm::tl::device_assert_with_msg())) {
    std::string cond = PrintExpr(call->args[0]);
    std::string msg_expr = PrintExpr(call->args[1]);
    this->PrintIndent();
    stream << "device_assert_with_msg(" << cond << ", " << msg_expr << ");\n";
  } else {
    CodeGenC::VisitStmt_(op);
  }
}

2460
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
2461
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
2462
2463
  CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
                     << " with " << lanes << " lanes is not allowed.";
2464
2465
2466
2467
2468
2469
  os << "(make_";
  PrintType(op->dtype, os);
  os << "(";
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
2470
2471
    if (i != lanes - 1)
      os << ", ";
2472
2473
2474
2475
  }
  os << "))";
}

2476
2477
2478
2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
                                     std::ostream &os) { // NOLINT(*)
  ICHECK_EQ(op->indices.size(), 1)
      << "Load from non-flat memory not supported.";
  ICHECK(!op->predicate.defined())
      << "Predicated buffer load is not supported.";

  DataType value_dtype = op->dtype;
  PrimExpr index = op->indices[0];
  Var buffer_var = op->buffer->data;
  DataType element_dtype = op->buffer->dtype;

  int lanes = op->dtype.lanes();
2489
  // declare type.
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
2544
2545
  if (value_dtype.lanes() == element_dtype.lanes()) {
    std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
    HandleVolatileLoads(ref, op, os);
  } else {
    bool can_vector_load = false;
    arith::PVar<PrimExpr> base;
    if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
      const RampNode *ramp = index.as<RampNode>();
      ICHECK(ramp);
      can_vector_load = true;
      // arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
      // The condition: {k * coeff + base} divisible by the alignment for any k
      // if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes()
      // == 0) {
      //   can_vector_load = true;
      // }
    }

    if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
      // A float4_e2m1fn element has 4 bits, which is an incomplete byte.
      // So we cannot vector load it.
      can_vector_load = false;
    }
    if (can_vector_load) {
      std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
      HandleVolatileLoads(ref, op, os);
    } else {
      std::ostringstream svalue_expr;
      std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
      std::string vid = GetVarID(buffer_var.get());
      DataType elem_type = op->dtype.element_of();
      for (int i = 0; i < lanes; ++i) {
        std::ostringstream value_temp;
        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
          value_temp << "((";
          if (buffer_var.get()->dtype.is_handle()) {
            auto it = alloc_storage_scope_.find(buffer_var.get());
            if (it != alloc_storage_scope_.end()) {
              PrintStorageScope(it->second, value_temp);
            }
          }
          PrintType(elem_type, value_temp);
          value_temp << "*)" << vid << ')';
        } else {
          value_temp << vid;
        }
        value_temp << '[';
        PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
        value_temp << ']';
        PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
      }
      os << svalue_expr.str();
    }
  }
}

2546
2547
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
                                     std::ostream &os) { // NOLINT(*)
2548
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
2549
2550
2551
2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573
2574
2575
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8) {
    if (lanes == 4) {
      // make_int8x4
      const int64_t *p = as_const_int(op->value);
      ICHECK(p);
      int64_t v = *p & 0xFF;
      v = (v << 24) | (v << 16) | (v << 8) | v;
      if (op->dtype.is_uint()) {
        os << "(uint)" << v;
      } else {
        os << "(int)" << v;
      }
      return;
    } else if (lanes == 32) {
      // make_int8x32
      const int64_t *p = as_const_int(op->value);
      ICHECK(p);
      int64_t v = *p & 0xFF;
      v = (v << 24) | (v << 16) | (v << 8) | v;
      if (op->dtype.is_uint()) {
        os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v
           << ")";
      } else {
        os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v
           << ")";
      }
      return;
2576
2577
2578
2579
2580
2581
2582
2583
    }
  }

  if (op->dtype.is_float16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
2584
2585
2586
2587
2588
2589
2590
2591
2592
2593
2594
2595
2596
    if (lanes <= 8) {
      for (int i = 0; i < lanes / 2; ++i) {
        if (i != 0)
          os << ", ";
        os << "__pack_half2(" << v << ", " << v << ")";
      }
    } else {
      for (int i = 0; i < lanes / 4; ++i) {
        if (i != 0)
          os << ", ";
        os << "tl::pack_float16x4(" << v << ", " << v << ", " << v << ", " << v
           << ")";
      }
2597
2598
2599
2600
2601
2602
2603
2604
2605
2606
    }
    os << ')';
    return;
  }

  if (op->dtype.is_bfloat16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
2619
    if (lanes <= 8) {
      for (int i = 0; i < lanes / 2; ++i) {
        if (i != 0)
          os << ", ";
        os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
      }
    } else {
      for (int i = 0; i < lanes / 4; ++i) {
        if (i != 0)
          os << ", ";
        os << "tl::pack_bfloat16x4(" << v << ", " << v << ", " << v << ", " << v
           << ")";
      }
2620
2621
2622
2623
2624
    }
    os << ')';
    return;
  }

2625
2626
  if (op->dtype.is_float() && op->dtype.bits() == 32 &&
      op->dtype.lanes() == 8) {
2627
2628
2629
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
2630
2631
      if (i != 0)
        os << ", ";
2632
2633
2634
2635
2636
2637
2638
2639
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
2640
    const int64_t *p = as_const_int(op->value);
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
    ICHECK(p);
    int64_t v = *p & 0xF;

    if (lanes == 4) {
      v = (v << 12) | (v << 8) | (v << 4) | v;
      if (op->dtype.is_uint()) {
        os << "(uint16_t)" << v;
      } else {
        os << "(int16_t)" << v;
      }
    } else {
2652
2653
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
          (v << 4) | v;
2654
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
        } else {
          os << "(int)" << v;
        }
      } else if (lanes == 16 || lanes == 32) {
        os << "make_";
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < lanes / 8; ++i) {
2665
2666
          if (i != 0)
            os << ", ";
2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
            os << "(int)" << v;
          }
        }
        os << ')';
      } else {
        fail = true;
      }
    }

    if (!fail) {
      return;
    }
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
2689
2690
    if (i != 0)
      os << ", ";
2691
2692
2693
2694
2695
    os << v;
  }
  os << ')';
}

2696
2697
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangCUDA *p) { // NOLINT(*)
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
  // Type code is kBFloat/kFloat16
  // which is indeed CUTLASS supported types currently
  if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
      }
      temp << "std::numeric_limits<";
      p->PrintType(op->dtype, temp);
      temp << ">::infinity()";
    } else if (std::isnan(op->value)) {
      temp << "std::numeric_limits<";
      p->PrintType(op->dtype, temp);
      temp << ">::quiet_NaN()";
    } else {
      p->PrintType(op->dtype, temp);
      temp << '(' << std::hexfloat << op->value << 'f';
      temp << "/*" << std::scientific << op->value << "*/";
      temp << ')';
    }
    p->MarkConst(temp.str());
    os << temp.str();
2721
2722
    return;
  }
2723
2724
2725
  // Type code is kFloat8_e5m2 or kE4M4Float
  if (op->dtype.is_float8() || op->dtype.is_float4()) {
    p->PrintType(op->dtype, os);
2726
2727
2728
    os << '(' << std::hexfloat << op->value << 'f';
    os << "/*" << std::scientific << op->value << "*/";
    os << ')';
2729
2730
    return;
  }
2731
  // Type code is kFloat64/kFloat32 (kFloat16 is handled above)
2732
  switch (op->dtype.bits()) {
2733
2734
2735
2736
2737
2738
  case 64:
  case 32: {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
2739
      }
2740
      temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
2741
      p->need_math_constants_h_ = true;
2742
2743
    } else if (std::isnan(op->value)) {
      temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
2744
      p->need_math_constants_h_ = true;
2745
    } else {
2746
      temp << std::hexfloat << op->value;
2747
2748
      if (op->dtype.bits() == 32)
        temp << 'f';
2749
      temp << "/*" << std::scientific << op->value << "*/";
2750
    }
2751
2752
2753
2754
2755
2756
    p->MarkConst(temp.str());
    os << temp.str();
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
2757
2758
2759
  }
}

2760
2761
void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op,
                                     std::ostream &os) { // NOLINT(*)
2762
2763
2764
  PrintConst(op, os, this);
}

2765
2766
2767
void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t,
                                         const VarNode *variable,
                                         std::ostream &os) {
2768
2769
  std::stringstream type;
  PrintType(t, type);
2770
2771
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
2772
2773
2774
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
  std::string shape_str = fragment_shapes.at(variable);
  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
    type.str(std::string());
    if (t.is_int()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::s4";
      } else if (t.bits() == 1) {
        type << "nvcuda::wmma::experimental::precision::b1";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    } else if (t.is_uint()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::u4";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    }
  }
  if (scope == "wmma.matrix_a") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
2794
2795
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", "
       << type.str() << ", nvcuda::wmma::" << layout_str << ">";
2796
2797
2798
  } else if (scope == "wmma.matrix_b") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
2799
2800
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", "
       << type.str() << ", nvcuda::wmma::" << layout_str << ">";
2801
  } else if (scope == "wmma.accumulator") {
2802
2803
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str
       << ", " << type.str() << ">";
2804
2805
2806
  }
}

2807
2808
int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope,
                                                 const VarNode *variable,
2809
                                                 int32_t size) {
2810
2811
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
2812
2813
2814
2815
2816
2817
2818
2819
  std::string shape_str = fragment_shapes.at(variable);
  std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
  if (dim.first * dim.second != 0)
    return size / dim.first / dim.second;
  else
    return 0;
}

2820
2821
2822
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value,
                                              const BufferLoadNode *op,
                                              std::ostream &os) {
2823
2824
2825
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
2826
2827
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
2828
2829
2830
2831
2832
2833
2834
2835
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

2836
2837
2838
// Appends lane `i` of a vector value that is being assembled from per-lane
// scalar expressions. The function is called once per lane, in order; the
// pieces written for all lanes together form one expression.
//
// \param t     Vector type being assembled (more than one lane).
// \param i     Index of the lane being appended.
// \param value Expression text for this lane's scalar.
// \param os    Stream that accumulates the whole vector expression.
//
// Forms:
//  - 8-bit ints (lane count other than 2 or 3): masked, shifted bytes joined
//    with `|`;
//  - float16 / bfloat16: make_<type>(pack(v0,v1),pack(v2,v3),...);
//  - everything else: make_<type>(v0,v1,...).
void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
                                               const std::string &value,
                                               std::ostream &os) {
  ICHECK_GT(t.lanes(), 1);

  const bool is_byte_int = t.bits() == 8 && (t.is_int() || t.is_uint());
  if (is_byte_int && !(t.lanes() == 2 || t.lanes() == 3)) {
    if (i != 0) {
      os << "|";
    }
    os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
       << "))";
    return;
  }

  // The first lane opens the constructor call.
  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << '(';
  }
  const bool is_last = (i == t.lanes() - 1);

  // 16-bit float types are packed two lanes at a time. The bfloat16 packer
  // name matches the one emitted by the BroadcastNode visitor in this file.
  const char *packer = nullptr;
  if (t.is_float16()) {
    packer = "__pack_half2(";
  } else if (t.is_bfloat16()) {
    packer = "__pack_nv_bfloat162(";
  }

  if (packer != nullptr) {
    if (i % 2 == 0) {
      // Even lane: open the pair.
      os << packer << value;
    } else {
      // Odd lane: close the pair, then continue or close the constructor.
      os << "," << value << ")";
      os << (is_last ? ")" : ",");
    }
    return;
  }

  os << value;
  os << (is_last ? ")" : ",");
}

2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965
2966
2967
2968
// Prints `<prefix><ret type><extra attrs> <function_name>(<params>)` to `os`.
//
// Handle parameters annotated with the "grid_constant" storage scope are
// printed as `__grid_constant__ const <element type> <name>`. Other handle
// parameters get their storage scope (when one is recorded), their type and,
// if the function has the no-alias attribute, a restrict qualifier. The
// element type of every pointer-annotated parameter is registered as its
// handle type.
void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
                                                 const PrimFunc &func,
                                                 std::ostream &os) {
  PrintFuncPrefix(os);
  CodeGenC::PrintType(func->ret_type, os);
  CodeGenC::PrintExtraAttrs(func, os);
  bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
  os << " " << function_name << "(";

  for (size_t i = 0; i < func->params.size(); ++i) {
    tir::Var v = func->params[i];
    std::string vid = AllocVarID(v.get());

    if (i > 0) {
      os << ", ";
    }

    // Non-handle parameters: just type and name.
    if (!v.dtype().is_handle()) {
      CodeGenC::PrintType(GetType(v), os);
      os << ' ' << vid;
      continue;
    }

    const auto *ptr = v->type_annotation.as<PointerTypeNode>();

    // Work around for grid constant parameters.
    if (ptr != nullptr && ptr->storage_scope == "grid_constant") {
      os << "__grid_constant__ const ";
      CodeGenC::PrintType(ptr->element_type, os);
      os << ' ' << vid;
      continue;
    }

    auto it = alloc_storage_scope_.find(v.get());
    if (it != alloc_storage_scope_.end()) {
      PrintStorageScope(it->second, os);
    }

    CodeGenC::PrintType(GetType(v), os);
    if (ptr != nullptr) {
      if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
        RegisterHandleType(v.get(), prim->dtype);
      }
    }

    if (no_alias) {
      PrintRestrict(v, os);
    }
    os << ' ' << vid;
  }
  os << ")";

  // Register handle data type
  // TODO(tvm-team): consider simply keep type info in the
  // type annotation(via a normalizing rewriting).
  for (const auto &param : func->params) {
    const auto *ptr = param->type_annotation.as<PointerTypeNode>();
    if (ptr == nullptr) {
      continue;
    }
    if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
      RegisterHandleType(param.get(), prim->dtype);
    }
  }
}

void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
                                      const PrimFunc &f) {
  // If the function has already been forward-declared, this is a
  // no-op.
  CodeGenC::DeclareFunction(gvar, f);
2969
2970
2971
2972
2973
2974
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
2975
  ICHECK(global_symbol)
2976
2977
2978
2979
2980
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix(stream);
  CodeGenC::PrintType(f->ret_type, stream);
2981
2982
  this->PrintExtraAttrs(f);

2983
2984
2985
2986
2987
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
2988
2989
    if (i != 0)
      stream << ", ";
2990
2991
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
2992
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
          stream << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      CodeGenC::PrintType(GetType(v), stream);
3007
3008
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, stream);
      }
    } else {
      CodeGenC::PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";
  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}

3030
3031
} // namespace codegen
} // namespace tvm