codegen_cuda.cc 56.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file target/codegen_cuda.cc
 */

#include "codegen_cuda.h"
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
11
#include <tvm/tir/index_map.h>
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <tvm/tir/op.h>

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "target/source/ptx.h"

namespace tvm {
namespace codegen {

26
27
28
/*!
 * \brief Construct the TileLang CUDA code generator.
 *
 * The only work done here is setting restrict_keyword_ to the
 * string `__restrict__`.
 */
CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
  this->restrict_keyword_ = "__restrict__";
}
29

30
31
32
/*!
 * \brief Write the qualifiers that precede a generated function.
 * \param os Stream that receives the prefix text.
 */
void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
  static constexpr char kFuncPrefix[] = "extern \"C\" __global__ ";
  os << kFuncPrefix;
}
33
34

class LaunchConfigExtractor : public tir::StmtVisitor {
35
36
private:
  void VisitStmt_(const AttrStmtNode *op) final {
37
38
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
39
40
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
41
        threadIdx_x_ext = op->value;
42
43
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
44
        threadIdx_y_ext = op->value;
45
46
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
47
48
49
50
51
52
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

53
public:
54
55
56
57
58
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

59
void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
60
61
62
  LaunchConfigExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
63
64
65
66
67
  PrimExpr threadIdx_ext =
      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                        extractor.threadIdx_z_ext);
  if (const IntImmNode *const threadIdx_ext_int =
          threadIdx_ext.as<IntImmNode>()) {
68
    if (threadIdx_ext_int->value == 1) {
69
70
      // unable to extract the number of threads per block, hence directly
      // return
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
      return;
    }
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
  }
}

/*!
 * \brief Write the include preamble to decl_stream and finish generation.
 *
 * `<mma.h>` is included only when need_mma_h_ is set; the TileLang template
 * headers are always included, followed by a blank line.
 *
 * \return Whatever CodeGenC::Finish() returns.
 */
std::string CodeGenTileLangCUDA::Finish() {
  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }

  // Headers emitted unconditionally, in this order.
  static const char *const kTemplateHeaders[] = {
      "tl_templates/cuda/gemm.h",
      "tl_templates/cuda/copy.h",
      "tl_templates/cuda/reduce.h",
      "tl_templates/cuda/ldsm.h",
      "tl_templates/cuda/threadblock_swizzle.h",
      "tl_templates/cuda/debug.h",
  };
  for (const char *header : kTemplateHeaders) {
    decl_stream << "#include <" << header << ">\n";
  }

  decl_stream << "\n";
  return CodeGenC::Finish();
}

91
/*!
 * \brief Emit a C `for` loop for a TIR For node.
 *
 * Unrolled loops are preceded by `#pragma unroll`. The loop runs from
 * `min` up to (but excluding) the simplified `extent + min`.
 */
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
  const bool unrolled = op->kind == tir::ForKind::kUnrolled;
  if (unrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }

  // The exclusive upper bound is printed first, before the loop header.
  arith::Analyzer analyzer;
  const std::string upper_bound =
      PrintExpr(analyzer.Simplify(op->extent + op->min));

  PrintIndent();
  const std::string iter_name = AllocVarID(op->loop_var.get());
  const std::string lower_bound = PrintExpr(op->min);

  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
  stream << ' ' << iter_name << " = " << lower_bound << "; ";
  stream << iter_name << " < " << upper_bound << "; ";
  stream << "++" << iter_name << ") {\n";

  const int body_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(body_scope);

  PrintIndent();
  stream << "}\n";
}

112
void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) {
113
  ICHECK(!var_idmap_.count(iv->var.get()));
114
115
  var_idmap_[iv->var.get()] =
      CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
116
117
}

118
/*!
 * \brief Write the CUDA type name used for TIR data type \p t.
 *
 * Handles, void and the CUtensorMap type are printed directly. Float,
 * bfloat16, float8, bool and (u)int types are mapped per bit width and lane
 * count; several vector types are stored in wider integer vectors (see the
 * inline comments). Any combination that has no mapping ends in LOG(FATAL).
 *
 * Unlike the earlier version, 64-bit integer vectors with a lane count other
 * than 2, 3 or 4 are reported as an error instead of silently producing an
 * incomplete type name.
 *
 * \param t  Data type to print.
 * \param os Stream that receives the type name.
 */
void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
    os << "void*";
    return;
  }

  if (t.is_void()) {
    os << "void";
    return;
  }

  if (t == tl::cuTensorMapType()) {
    os << "CUtensorMap";
    return;
  }

  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
    case 16:
      if (t.is_scalar()) {
        os << "half_t";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp16 vector elements.
        //
        // half4 is stored as uint2
        //
        // h4.x is emitted as *(half2*)(&(u2.x)).x
        // h4.y is emitted as *(half2*)(&(u2.x)).y
        // h4.z is emitted as *(half2*)(&(u2.y)).x
        // h4.w is emitted as *(half2*)(&(u2.y)).y
        //
        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
        os << "uint" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 32:
      if (lanes <= 4) {
        os << "float";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
        //
        // float8 is stored as ulonglong4
        //
        // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
        // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for float type with lanes > 4";
        os << "ulonglong" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
    }
    // Scalars and all fp16 forms are already complete at this point.
    if (!fail && (t.is_scalar() || t.bits() == 16))
      return;
    // fp32 with 5..8 lanes was printed in full as ulonglongN.
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
      return;
    // Remaining short vectors get the lane count appended (float4, double2).
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t.is_bfloat16()) {
    if (t.is_scalar()) {
      os << "bfloat16_t";
    } else if (lanes <= 8) {
      ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
      os << "uint" << lanes / 2;
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t.is_float8()) {
    if (t.is_scalar()) {
      os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
    } else if (lanes == 2) {
      os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
                                  // unsigned short
    } else if (lanes == 4) {
      os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    // Unsigned types share the signed spelling with a leading 'u'.
    if (t.is_uint()) {
      os << "u";
    }
    switch (t.bits()) {
    case 1: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int8_t";
        return;
      } else if (t.lanes() == 16) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 32) {
        os << "int";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 4: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 4) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 8) {
        // directly 8 4-bit int in integer.
        os << "int";
        return;
      } else if (t.lanes() == 16) {
        os << "int2";
        return;
      } else if (t.lanes() == 32) {
        os << "int4";
        return;
      } else if (t.lanes() == 64) {
        os << "int8";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 8: {
      if (t.lanes() == 4) {
        // directly 4 8 bit int in integer.

        // We use int for int8x4 instead of char4 because using char4 is
        // likely to produce extra instructions to pack four int8 elements
        // into 32-bit data.
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int2";
        return;
      } else if (t.lanes() == 16) {
        os << "int4";
        return;
      } else if (!t.is_uint() && t.is_scalar()) {
        os << "signed char";
        break;
      } else {
        os << "char";
        break;
      }
    }
    case 16: {
      if (t.is_scalar()) {
        os << "short";
      } else if (t.lanes() <= 4) {
        os << "short" << lanes;
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int16 vector elements.
        //
        // short4 is stored as int2
        //
        // s4.x is emitted as *(short2*)(&(i2.x)).x
        // s4.y is emitted as *(short2*)(&(i2.x)).y
        // s4.z is emitted as *(short2*)(&(i2.y)).x
        // s4.w is emitted as *(short2*)(&(i2.y)).y
        //
        ICHECK_EQ(t.lanes() % 2, 0)
            << "only support even lane for shorT type with lanes > 4";
        os << "int" << t.lanes() / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 32: {
      if (t.is_scalar()) {
        os << "int";
      } else if (t.lanes() <= 4) {
        os << "int" << t.lanes();
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
        //
        // int8 is stored as longlong4
        //
        // i8.v1 is emitted as *(int2*)(&(l4.x)).x
        // i8.v2 is emitted as *(int2*)(&(l4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for int32 type with lanes > 4";
        os << "longlong" << lanes / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 64: {
      if (t.is_scalar()) {
        os << "int64_t";
      } else if (t.lanes() == 2) {
        os << "longlong2";
      } else if (t.lanes() == 3) {
        os << "longlong3";
      } else if (t.lanes() == 4) {
        os << "longlong4";
      } else {
        // No 64-bit integer vector with this lane count is mapped here;
        // fall through to the fatal error instead of returning a partial
        // type name.
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    default:
      fail = true;
      break;
    }
    // Only the 8-bit case reaches here without failure: scalars are done,
    // 2..4 lane vectors get the lane count appended (char2, uchar4, ...).
    if (!fail && lanes == 1) {
      return;
    }
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

372
373
374
void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t,
                                           PrimExpr lhs, PrimExpr rhs,
                                           std::ostream &os) { // NOLINT(*)
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  int ssa_scope = BeginScope();
  {
    // Unpack into individual ops.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  EndScope(ssa_scope);
  os << sret;
}

408
409
410
void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
                                           int i,
                                           std::ostream &os) { // NOLINT(*)
411
412
413
414
415
416
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
417
418
419
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
420
421
422
423
424
425
426
427
428
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
    }
  } else if (t.is_float16()) {
429
430
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
431
  } else if (t.is_bfloat16()) {
432
433
    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
452
453
    os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
       << ")))->" << access[i % 2];
454
455
456
457
458
  } else {
    os << vec << "." << access[i];
  }
}

459
460
void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
                                            int i, const std::string &value) {
461
462
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
463
464
465
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
466
467
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
468
469
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
470
471
472
473
474
475
476
477
478
479
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      stream << ac << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
      }
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
    }
  } else if (t.is_float16()) {
480
481
    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
482
  } else if (t.is_bfloat16()) {
483
484
    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
           << access[i % 2] << " = " << value << ";\n";
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
503
504
    stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << " = " << value << ";\n";
505
506
507
508
509
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

510
511
void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
  const std::string &sync = op->args[0].as<StringImmNode>()->value;
512
513
514
515
516
517
518
519
  if (sync == "warp") {
    // DO nothing.
  } else if (sync == "shared" || sync == "shared.dyn") {
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
  }
}

520
521
522
523
524
/*!
 * \brief Write the CUDA storage qualifier for an allocation scope.
 *
 * "shared" maps to `__shared__ `, "shared.dyn" maps to
 * `extern __shared__ __align__(1024) `; other scopes print nothing.
 * The "global" scope is rejected.
 */
void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope,
                                            std::ostream &os) { // NOLINT(*)
  ICHECK_NE(scope, "global")
      << "Cannot allocate global memory when targeting CUDA. You must pass "
         "all global arrays as input instead";

  if (scope == "shared.dyn") {
    os << "extern __shared__ __align__(1024) ";
    return;
  }
  if (scope == "shared") {
    os << "__shared__ ";
  }
}

532
533
534
535
std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from,
                                            DataType target) {
  if (from == target)
    return value;
536
537
538
539
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
540
541
  if (from.is_float16() && (target.is_int() || target.is_uint()) &&
      target.bits() == 8) {
542
543
544
545
546
547
548
549
550
551
    os << "(";
    if (target.is_uint()) {
      os << "u";
    }
    os << "int)";
  }
  os << value << ")";
  return os.str();
}

552
void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
553
554
555
556
557
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
558
559
  if (from_ty.is_scalar())
    return CodeGenC::VisitExpr_(op, os);
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
  {
    std::string src = SSAGetID(PrintExpr(op->value), from_ty);
    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
      std::ostringstream val;
      val << "(";
      PrintType(target_ty.element_of(), val);
      val << ")(";
      PrintVecElemLoad(src, from_ty, i, val);
      val << ")";
      PrintVecElemStore(sret, target_ty, i, val.str());
    }
  }
  os << sret;
}

582
583
584
585
/*!
 * \brief Emit a call to the external function \p global_symbol.
 *
 * Calls with a non-vector return type are delegated to CodeGenC. A call
 * with a vector return type is expanded into one scalar call per lane:
 *
 *   float4 __ret;
 *   {
 *     float4 __arg0 = ((float4*)A)[0];
 *     float4 __arg1 = ((float4*)B)[0];
 *     __ret.x = intrin_f(__arg0.x, __arg1.x);
 *     ...
 *   }
 *   v = __ret;
 *
 * \param skip_first_arg When true, args[0] is not passed to the callee.
 */
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
                                          const Array<PrimExpr> &args,
                                          bool skip_first_arg,
                                          std::ostream &os) { // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (!ret_dtype.is_vector()) {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
                              os);
    return;
  }

  // Declare the result vector.
  const std::string result = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(ret_dtype, stream);
  stream << ' ' << result << ";\n";

  // Bind every forwarded argument to an SSA id.
  const size_t first_arg = static_cast<size_t>(skip_first_arg);
  std::vector<std::string> arg_ids;
  for (size_t k = first_arg; k < args.size(); ++k) {
    arg_ids.push_back(SSAGetID(PrintExpr(args[k]), args[k].dtype()));
  }

  // Emit a scalar call for each lane.
  for (int lane = 0; lane < ret_dtype.lanes(); ++lane) {
    std::ostringstream call;
    call << global_symbol << "(";
    for (size_t k = 0; k < arg_ids.size(); ++k) {
      if (k != 0) {
        call << ", ";
      }
      PrintVecElemLoad(arg_ids[k], args[first_arg + k].dtype(), lane, call);
    }
    call << ")";
    PrintVecElemStore(result, ret_dtype, lane, call.str());
  }

  os << result;
}

// Print a reference expression to a buffer.
641
642
643
644
/*!
 * \brief Build the expression text that refers to `buffer[index]` as type
 *        \p t.
 *
 * Three forms are produced:
 *  - 4-bit types and 1-bit ints: a dereference of the buffer variable cast
 *    to a pointer to \p t, offset by `index / div_factor`.
 *  - \p t equal to the buffer's element type: plain subscript.
 *  - otherwise: a dereference of `(buffer + index)` cast to a pointer to
 *    \p t.
 *
 * \return The reference expression as text.
 */
std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
                                              const BufferNode *buffer,
                                              PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
  const std::string vid = GetVarID(buffer_var);

  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }

  // bool is_vol = IsVolatile(buffer_var);
  // always false for tl cutlass backend.
  const bool is_vol = false;

  // Produces "(<qualifiers><type>*)" for a pointer to `pointed_to`.
  auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
    std::ostringstream cast;
    cast << "(";
    if (is_vol) {
      cast << "volatile ";
    }
    if (!scope.empty() && IsScopePartOfType()) {
      PrintStorageScope(scope, cast);
    }
    PrintType(pointed_to, cast);
    cast << "*)";
    return cast.str();
  };

  const DataType buffer_element_dtype = buffer->dtype;

  // Cast the buffer variable when its declared handle type differs from the
  // buffer's element type.
  std::string buffer_str = vid;
  if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
    buffer_str = "(" + ptr_cast(buffer_element_dtype) + vid + ")";
  }

  const std::string index_str = PrintExpr(index);

  const bool sub_byte = t.bits() == 4 || (t.bits() == 1 && t.is_int());
  if (sub_byte) {
    // This is a special case, because CodegenCUDA::PrintType()
    // returns "int" for bool and for 4-bit integers. In most cases,
    // we divide by the number of lanes to determine the index.
    // However, the backing type for scalar int4 and scalar bool is
    // int32.  Therefore, we need to divide by the ratio of their
    // sizes in that case.
    const int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();
    return "*((" + ptr_cast(t) + vid + ") + " + index_str + " / " +
           std::to_string(div_factor) + ")";
  }

  if (t == buffer_element_dtype) {
    return buffer_str + "[" + index_str + "]";
  }

  return "*" + ptr_cast(t) + "(" + buffer_str + " + " + index_str + ")";
}

700
void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
701
702
703
704
  auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
    this->PrintIndent();
    this->stream << name << "(";
    for (size_t i = offset; i < op->args.size(); i++) {
705
706
      if (i > offset)
        this->stream << ", ";
707
708
709
710
711
712
713
714
715
716
      this->stream << this->PrintExpr(op->args[i]);
    }
    this->stream << ");\n";
  };
  if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
717
718
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
719
720
    if (op->args.size() == 5) {
      this->PrintIndent();
721
722
      this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
                   << dst_offset << ", " << src << "+" << src_offset << ");\n";
723
724
725
    } else {
      std::string condition = this->PrintExpr(op->args[5]);
      this->PrintIndent();
726
727
728
      this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
                   << "+" << dst_offset << ", " << src << "+" << src_offset
                   << ", " << condition << ");\n";
729
730
731
732
733
734
735
736
737
738
739
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    print_extern_call_stmt("tl::cp_async_commit");
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
    std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
    print_extern_call_stmt(func_name, 1);
  } else if (op->op.same_as(builtin::create_barriers())) {
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    std::string barrier_name = "_mbarrier";
740
741
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
  } else if (op->op.same_as(tl::GetMBarrierOp())) {
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
    os << barrier_name + "[" + barrier_id + "]";
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    print_extern_call_stmt("tl::mbarrier_arrive");
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    print_extern_call_stmt("tl::mbarrier_init");
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
  } else if (op->op.same_as(tl::MBarrierExpectTX())) {
    print_extern_call_stmt("tl::mbarrier_expect_tx");
  } else if (op->op.same_as(tl::MBarrierWaitParity())) {
    print_extern_call_stmt("tl::mbarrier_wait");
  } else if (op->op.same_as(tl::SyncThreadsPartialOp())) {
    print_extern_call_stmt("tl::syncthreads_partial");
  } else if (op->op.same_as(tl::TMALoadOp())) {
    print_extern_call_stmt("tl::tma_load");
  } else if (op->op.same_as(tl::TMALoadIm2ColOp())) {
    print_extern_call_stmt("tl::tma_load_im2col");
  } else if (op->op.same_as(tl::TMAStoreOp())) {
    print_extern_call_stmt("tl::tma_store");
  } else if (op->op.same_as(tl::LDMatrixOp())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
770
771
    if (trans == 1)
      func_name += "_trans";
772
773
774
775
776
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::STMatrixOp())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
777
778
    if (trans == 1)
      func_name += "_trans";
779
780
781
782
783
784
785
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
    print_extern_call_stmt("tl::fence_proxy_async");
  } else if (op->op.same_as(tl::SetMaxNReg())) {
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
786
787
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
788
789
790
791
792
793
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
  } else if (op->op.same_as(tl::WaitWgmma())) {
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
  } else if (op->op.same_as(tl::PackB16Op())) {
794
795
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
829
    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::ptx_mma())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp64, ...
    // arg 4: B precision: fp16, fp64, ...
    // arg 5: C precision: fp32, fp64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index
    // arg 12: saturate
    // arg 13: (optional) 1-bit operator (xor or and)
    ICHECK(op->args.size() == 13U || op->args.size() == 14U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    bool saturate = Downcast<Bool>(op->args[12])->value;
884
885
886
887
888
    std::string bit_op =
        op->args.size() > 13 ? Downcast<StringImm>(op->args[13])->value : "";
    std::string asm_code = PrintMMAAssembly(
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias,
        b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate);
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925

    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_mma_sp())) {
    // arg 0: shape: mXnXkX
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: fp16, fp32, ...
    // arg 4: B precision: fp16, fp32, ...
    // arg 5: C precision: fp16, fp32, ...
    // arg 6: A multiplicand pointer
    // arg 7: A multiplicand index
    // arg 8: B multiplicand pointer
    // arg 9: B multiplicand index
    // arg 10: C accumulator pointer
    // arg 11: C accumulator index
    // arg 12: metadata
    // arg 13: metadata index
    // arg 14: sparse_selector
    // arg 15: saturate
    ICHECK_EQ(op->args.size(), 16U);
    std::string shape = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_offset = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_offset = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_offset = this->PrintExpr(op->args[11]);
    std::string metadata = this->PrintExpr(op->args[12]);
    std::string metadata_offset = this->PrintExpr(op->args[13]);
    std::string sparse_selector = this->PrintExpr(op->args[14]);
    bool saturate = Downcast<Bool>(op->args[15])->value;
    std::string asm_code = PrintMMAAssembly(
926
927
928
        shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset,
        b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset,
        sparse_selector, "", true, saturate);
929
930
931
932
933
934
935
936
    this->stream << asm_code;
  } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
    // arg 0: whether the matrix is loaded in column major format or not.
    // arg 1: number of matrices to load.
    // arg 2: The data type in the matrix, .b16 is the only accepted data type.
    // arg 3: pointer to local buffer.
    // arg 4: The offset of the element to store in the local buffer.
    // arg 5: pointer to the shared memory buffer to load.
937
938
    // arg 6: The offset of the start element of the row to load in shared
    // memory.
939
940
941
942
943
944
945
946
    ICHECK_EQ(op->args.size(), 7U);
    bool trans = Downcast<Bool>(op->args[0])->value;
    int num = Downcast<Integer>(op->args[1])->value;
    std::string type = Downcast<StringImm>(op->args[2])->value;
    std::string local_ptr = this->PrintExpr(op->args[3]);
    std::string local_elem_offset = this->PrintExpr(op->args[4]);
    std::string smem_ptr = this->PrintExpr(op->args[5]);
    if (trans && op->dtype.bits() == 8) {
947
948
      // Since ldmatrix assumes that a matrix element is 16 bit, it cannot
      // properly transpose an int8 matrix.
949
950
951
952
      std::string smem_stride = this->PrintExpr(op->args[6]);
      ICHECK(num == 4);
      os << "for (int i = 0; i < 16; ++i) {\n";
      os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
953
954
955
956
         << "[(i % 8) / 4 * " + smem_stride +
                " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
                "+ (i % 4) * " + smem_stride +
                " + threadIdx.x / 4 +  (i / 8) * 8];\n";
957
958
959
960
      os << "}\n";
    } else {
      std::string smem_elem_offset = this->PrintExpr(op->args[6]);
      need_cast_smem_ptr_to_int_ = true;
961
962
963
      this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr,
                                              local_elem_offset, smem_ptr,
                                              smem_elem_offset);
964
965
966
967
968
969
970
971
972
    }
  } else if (op->op.same_as(builtin::mma_store())) {
    int m = Downcast<Integer>(op->args[0])->value;
    int n = Downcast<Integer>(op->args[1])->value;
    std::string dst = this->PrintExpr(op->args[2]);
    std::string src = this->PrintExpr(op->args[3]);
    std::string src_offset = this->PrintExpr(op->args[4]);
    PrimExpr stride = op->args[5];

973
974
    ICHECK(m == 16 && n == 16)
        << "Only m == 16 && n == 16 case supported for now";
975

976
977
978
979
980
    // Each thread in a warp holds a certain number of elements of an MMA
    // output. For example, if we compute a 16x16 tile using MMA, each thread
    // holds 8 elements in its registers. So conceptually, a warp memory is
    // organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of
    // memory is specified by the index map below.
981

982
983
    // To store the 32x8 output back to a 16x16 tile in shared or global memory,
    // we invert this map to determine the output location for each 8 element.
984

985
    const auto *index_map_func =
986
        runtime::Registry::Get("tir.index_map.shared_16x16_to_mma_32x8_layout");
987

988
989
990
    IndexMap index_map;
    if (!index_map_func) {
      Var i, j;
991

992
      // The index map is defined as follows:
993
994
995
996
997
      index_map = IndexMap(
          {i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2),
                   4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)});
    } else {
      index_map = IndexMap::FromFunc(2, *index_map_func);
998
999
1000
1001
1002
1003
1004
    }

    arith::Analyzer analyzer;
    auto inverse_index_map =
        index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer);
    auto indices_16x16 = inverse_index_map->final_indices;

1005
1006
1007
    // "//" and "%" in the index map are translated to FloorDiv/Mod, but the
    // plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before
    // they reach codegen, so manually replace them to the plain ones here.
1008
    class LowerFloorDivMod : public ExprMutator {
1009
1010
    public:
      PrimExpr VisitExpr_(const FloorDivNode *op) {
1011
1012
        return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
1013
      PrimExpr VisitExpr_(const FloorModNode *op) {
1014
1015
1016
1017
        return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
      }
    };

1018
1019
    auto dst_ind =
        LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
1020
1021
1022
1023
1024
1025
1026
1027
1028

    var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
    var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
    if (op->dtype.bits() == 16) {
      os << "for (int local_id = 0; local_id < 8; local_id+=2) {\n";
      os << "*((uint *)&" << dst << "[" + this->PrintExpr(dst_ind) + "])"
         << " = "
         << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n";
      os << "}\n";
1029
    } else {
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
      os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
      os << dst << "[" + this->PrintExpr(dst_ind) + "]"
         << " = " << src << "[" << src_offset << " + local_id];\n";
      os << "}\n";
    }

  } else if (op->op.same_as(builtin::mma_fill())) {
    std::string num_elem = this->PrintExpr(op->args[0]);
    std::string dst = this->PrintExpr(op->args[1]);
    std::string dst_offset = this->PrintExpr(op->args[2]);

    os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
    os << dst << "[" << dst_offset << " + i] = 0.0;";
    os << "}\n";
  } else if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    need_cast_smem_ptr_to_int_ = true;
1051
1052
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
1053
    if (op->args.size() == 5) {
1054
1055
      this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
                                           size);
1056
    } else {
1057
1058
      this->stream << PrintPredicatedCpAsyncAssembly(
          dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5]));
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
    }
  } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
    need_cast_smem_ptr_to_int_ = true;
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
    int barrier_id = Downcast<IntImm>(op->args[5])->value;
    CHECK(barrier_id < barrier_count_);
1069
1070
1071
1072
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
    this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size,
                                        barrier);
1073
1074
1075
1076
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
1077
1078
    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n
                 << ";\");\n\n";
1079
1080
1081
1082
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1083
1084
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1085
1086
1087
1088
1089
    this->stream << PrintCpAsyncBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1090
1091
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1092
1093
1094
1095
1096
1097
    std::string thread_count = this->PrintExpr(op->args[1]);
    this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1098
1099
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1100
1101
1102
1103
1104
    this->stream << PrintArriveBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1105
1106
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1107
1108
1109
1110
1111
1112
    std::string byte_count = this->PrintExpr(op->args[1]);
    this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
  } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
    need_cast_smem_ptr_to_int_ = true;
    int barrier_id = Downcast<IntImm>(op->args[0])->value;
    CHECK(barrier_id < barrier_count_);
1113
1114
    std::string barrier =
        barrier_name_ + "[" + std::to_string(barrier_id) + "]";
1115
1116
1117
1118
1119
1120
1121
1122
    this->stream << PrintWaitBarrierAsm(barrier);
  } else if (op->op.same_as(builtin::create_barriers())) {
    CHECK_EQ(barrier_count_, -1);
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    // pad barrier alignment to avoid runtime alignment errors
    CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
    int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
    if (barrier_count % barrier_alignment_count != 0) {
1123
1124
      barrier_count = ((barrier_count / barrier_alignment_count) + 1) *
                      barrier_alignment_count;
1125
1126
    }
    barrier_count_ = barrier_count;
1127
1128
1129
1130
1131
    this->stream << "__shared__ __align__(" << barrier_alignment_bytes_
                 << ") uint64_t " << barrier_name_ << "[" << barrier_count
                 << "];\n";
    this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { "
                 << barrier_name_ << "[i] = 0; }\n";
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
  } else if (op->op.same_as(builtin::ptx_ldg32())) {
    /*
    asm volatile (
        "{.reg .pred p;\n"
        " setp.ne.b32 p, %2, 0;\n"
        // " @p ld.global.nc.f32 %0, [%1];}\n"
        " @p ld.global.nc.L2::128B.f32 %0, [%1];}\n"
        : "=f"(reg)
        : "l"(addr), "r"((int)guard)
    );
    */

    // get local
    std::string reg = this->PrintExpr(op->args[0]);
    // get guard
    std::string guard = this->PrintExpr(op->args[1]);
1148
    const BufferLoadNode *addr_buffer = op->args[2].as<BufferLoadNode>();
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
    std::string global_addr = this->PrintExpr(addr_buffer->indices[0]);
    std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data);
    std::string local_addr = this->PrintExpr(op->args[3]);
    this->stream << "asm volatile (\n";
    this->stream << "\"{.reg .pred p;\\n\"\n";
    this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n";
    this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
    this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
    // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
    stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
           << ")\n";
1160
1161
    stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr
           << ")), \"r\"((int)" << guard << ")\n";
1162
1163
1164
1165
1166
1167
    stream << ");\n";
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

1168
void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
1169
  if (op->attr_key == tir::attr::fragment_shape) {
1170
1171
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *shape_str = op->value.as<StringImmNode>();
1172
1173
    fragment_shapes[buffer] = shape_str->value;
  } else if (op->attr_key == tir::attr::fragment_layout) {
1174
1175
    const VarNode *buffer = op->node.as<VarNode>();
    const StringImmNode *layout_str = op->value.as<StringImmNode>();
1176
1177
    fragment_layouts[buffer] = layout_str->value;
  } else if (op->attr_key == tir::attr::async_commit_queue_scope) {
1178
1179
1180
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
1181
1182
1183
1184
1185
1186
1187
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
    return;
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
1188
1189
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
1190
    auto wait_cnt = wait_attrs.second;
1191
1192
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
1193
1194
1195
1196
1197
1198
1199
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    this->VisitStmt(inner->body);
    return;
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
1200
    const StringImmNode *pattern = op->value.as<StringImmNode>();
1201
1202
1203
1204
1205
1206
1207
1208
    ICHECK(pattern);
    this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
    this->VisitStmt(op->body);
    return;
  }
  CodeGenC::VisitStmt_(op);
}

1209
void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
1210
1211
1212
1213
1214
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());

  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
1215
  const VarNode *buffer = op->buffer_var.as<VarNode>();
1216
1217
  if (scope.find("wmma.") == 0) {
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
1218
1219
1220
1221
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) ||
             op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
             op->dtype == DataType::Int(1) || op->dtype == DataType::BFloat(16))
1222
1223
1224
          << "Matrix_a and matrix_b only support half or char or unsigned char "
          << "or uint4 or int4 or int1 type for now";
    } else {
1225
1226
      ICHECK(op->dtype == DataType::Float(16) ||
             op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32))
1227
1228
1229
          << "Accumulator only support half, float and int type for now";
    }
    PrintWmmaScope(scope, op->dtype, buffer, stream);
1230
  } else {
1231
1232
1233
1234
1235
1236
1237
1238
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }

  if (scope == "shared.dyn") {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
1239
1240
    ICHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation for now";
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
    if (scope.find("wmma.") == 0) {
      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    }
    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
        scope == "shared") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
    stream << ' ' << vid << '[' << constant_size << "];\n";
  }

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

1256
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
1257
1258
1259
1260
1261
1262
1263
1264
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
  CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
  os << "(make_";
  PrintType(op->dtype, os);
  os << "(";
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
1265
1266
    if (i != lanes - 1)
      os << ", ";
1267
1268
1269
1270
  }
  os << "))";
}

1271
1272
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
                                     std::ostream &os) { // NOLINT(*)
1273
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
1274
1275
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
      lanes == 4) {
1276
    // make_int8x4
1277
    const int64_t *p = as_const_int(op->value);
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
    ICHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
    if (op->dtype.is_uint()) {
      os << "(uint)" << v;
    } else {
      os << "(int)" << v;
    }
    return;
  }

  if (op->dtype.is_float16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1295
1296
      if (i != 0)
        os << ", ";
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if (op->dtype.is_bfloat16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1309
1310
      if (i != 0)
        os << ", ";
1311
1312
1313
1314
1315
1316
      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

1317
1318
  if (op->dtype.is_float() && op->dtype.bits() == 32 &&
      op->dtype.lanes() == 8) {
1319
1320
1321
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
1322
1323
      if (i != 0)
        os << ", ";
1324
1325
1326
1327
1328
1329
1330
1331
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
1332
    const int64_t *p = as_const_int(op->value);
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
    ICHECK(p);
    int64_t v = *p & 0xF;

    if (lanes == 4) {
      v = (v << 12) | (v << 8) | (v << 4) | v;
      if (op->dtype.is_uint()) {
        os << "(uint16_t)" << v;
      } else {
        os << "(int16_t)" << v;
      }
    } else {
1344
1345
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
          (v << 4) | v;
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
        } else {
          os << "(int)" << v;
        }
      } else if (lanes == 16 || lanes == 32) {
        os << "make_";
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < lanes / 8; ++i) {
1357
1358
          if (i != 0)
            os << ", ";
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
            os << "(int)" << v;
          }
        }
        os << ')';
      } else {
        fail = true;
      }
    }

    if (!fail) {
      return;
    }
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
1381
1382
    if (i != 0)
      os << ", ";
1383
1384
1385
1386
1387
    os << v;
  }
  os << ')';
}

1388
1389
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangCUDA *p) { // NOLINT(*)
1390
1391
1392
1393
1394
1395
1396
1397
  // Type code is kBFloat
  if (op->dtype.is_bfloat16()) {
    os << "bfloat16_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
  // Type code is kFloat
  switch (op->dtype.bits()) {
1398
1399
1400
1401
1402
1403
  case 64:
  case 32: {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
1404
      }
1405
1406
1407
1408
1409
1410
1411
      temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
    } else if (std::isnan(op->value)) {
      temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
    } else {
      temp << std::scientific << op->value;
      if (op->dtype.bits() == 32)
        temp << 'f';
1412
    }
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
    p->MarkConst(temp.str());
    os << temp.str();
    break;
  }
  case 16: {
    os << "half_t" << '(';
    FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
    PrintConst(const_f32.get(), os, p);
    os << ')';
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
1426
1427
1428
  }
}

1429
1430
void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op,
                                     std::ostream &os) { // NOLINT(*)
1431
1432
1433
  PrintConst(op, os, this);
}

1434
1435
1436
void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t,
                                         const VarNode *variable,
                                         std::ostream &os) {
1437
1438
  std::stringstream type;
  PrintType(t, type);
1439
1440
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
  std::string shape_str = fragment_shapes.at(variable);
  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
    type.str(std::string());
    if (t.is_int()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::s4";
      } else if (t.bits() == 1) {
        type << "nvcuda::wmma::experimental::precision::b1";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    } else if (t.is_uint()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::u4";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    }
  }
  if (scope == "wmma.matrix_a") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
1463
1464
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", "
       << type.str() << ", nvcuda::wmma::" << layout_str << ">";
1465
1466
1467
  } else if (scope == "wmma.matrix_b") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
1468
1469
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", "
       << type.str() << ", nvcuda::wmma::" << layout_str << ">";
1470
  } else if (scope == "wmma.accumulator") {
1471
1472
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str
       << ", " << type.str() << ">";
1473
1474
1475
  }
}

1476
1477
int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope,
                                                 const VarNode *variable,
1478
                                                 int32_t size) {
1479
1480
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
1481
1482
1483
1484
1485
1486
1487
1488
  std::string shape_str = fragment_shapes.at(variable);
  std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
  if (dim.first * dim.second != 0)
    return size / dim.first / dim.second;
  else
    return 0;
}

1489
1490
1491
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value,
                                              const BufferLoadNode *op,
                                              std::ostream &os) {
1492
1493
1494
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
1495
1496
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
1497
1498
1499
1500
1501
1502
1503
1504
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

1505
1506
1507
void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
                                               const std::string &value,
                                               std::ostream &os) {
1508
1509
1510
1511
1512
1513
  ICHECK_GT(t.lanes(), 1);
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (!(t.lanes() == 2 || t.lanes() == 3)) {
      if (i != 0) {
        os << "|";
      }
1514
1515
      os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
         << "))";
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
      return;
    }
  }

  if (t.is_float16()) {
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      os << "__pack_half2(" << value;
    } else {
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  if (t.is_bfloat16()) {
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      os << "__pack_bfloat162(" << value;
    } else {
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << "(";
  }
  os << value;
  if (i != t.lanes() - 1) {
    os << ",";
  } else {
    os << ")";
  }
  return;
}

1572
void CodeGenTileLangCUDA::AddFunction(const PrimFunc &f) {
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix(stream);
  CodeGenC::PrintType(f->ret_type, stream);
  this->PrintExtraAttrs(f, stream);
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
1591
1592
    if (i != 0)
      stream << ", ";
1593
1594
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
1595
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
          stream << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      CodeGenC::PrintType(GetType(v), stream);
1610
1611
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, stream);
      }
    } else {
      CodeGenC::PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";
  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}

1633
1634
} // namespace codegen
} // namespace tvm