codegen_hip.cc 41.6 KB
Newer Older
1
2
3
4
5
6
7
/*!
 * \file target/codegen_hip.cc
 * \brief HIP source code generator for TileLang kernels.
 */

#include "codegen_hip.h"
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
8
#include <tvm/tir/index_map.h>
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <tvm/tir/op.h>

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "target/source/ptx.h"

namespace tvm {
namespace codegen {

/*!
 * \brief Replace patterns with replacement strings.
 * \note should use std::format instead when codebase is ported to C++20.
 */
class Replacer {
public:
  /*!
   * \brief Register a literal (non-regex) substitution rule.
   * \param pattern Literal text to search for. Empty patterns are skipped by
   *        rewrite(), because an empty match can never make progress.
   * \param replacement Text substituted for every occurrence of \p pattern.
   */
  void register_rule(const std::string &pattern,
                     const std::string &replacement) {
    _rules.emplace_back(pattern, replacement);
  }

  /*!
   * \brief Apply every registered rule, in registration order, to \p str.
   *
   * Each rule replaces all non-overlapping occurrences of its pattern,
   * scanning left to right. Text inserted by a rule is not re-scanned by that
   * same rule, so a replacement that contains its own pattern cannot loop.
   * Later rules do operate on the output of earlier rules.
   *
   * \param str Input text (taken by value and modified in place).
   * \return The rewritten text.
   */
  std::string rewrite(std::string str) {
    for (const auto &[pattern, replacement] : _rules) {
      // An empty pattern matches at every position; without this guard the
      // loop below would keep inserting the replacement forever.
      if (pattern.empty()) {
        continue;
      }
      size_t pos = str.find(pattern);
      while (pos != std::string::npos) {
        str.replace(pos, pattern.size(), replacement);
        // Resume right after the inserted text so it is never re-matched.
        pos = str.find(pattern, pos + replacement.size());
      }
    }
    return str;
  }

  /*! \brief Remove all registered rules. */
  void empty_rules() { _rules.clear(); }

private:
  /*! \brief (pattern, replacement) pairs, applied in insertion order. */
  std::vector<std::pair<std::string, std::string>> _rules;
};

CodeGenTileLangHIP::CodeGenTileLangHIP() { restrict_keyword_ = "__restrict__"; }

54
55
56
void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream &os) {
  os << "extern \"C\" __global__ ";
}
57
58

class LaunchConfigExtractor : public tir::StmtVisitor {
59
60
private:
  void VisitStmt_(const AttrStmtNode *op) final {
61
62
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
63
64
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
65
        threadIdx_x_ext = op->value;
66
67
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
68
        threadIdx_y_ext = op->value;
69
70
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
71
72
73
74
75
76
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

77
public:
78
79
80
81
82
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

83
void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
84
85
86
  LaunchConfigExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
87
88
89
90
91
  PrimExpr threadIdx_ext =
      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                        extractor.threadIdx_z_ext);
  if (const IntImmNode *const threadIdx_ext_int =
          threadIdx_ext.as<IntImmNode>()) {
92
    if (threadIdx_ext_int->value == 1) {
93
94
      // unable to extract the number of threads per block, hence directly
      // return
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
      return;
    }
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
  }
}

/*!
 * \brief Prepend the required #include lines to the declarations, then
 *        let CodeGenC assemble the final source text.
 */
std::string CodeGenTileLangHIP::Finish() {
  // The HIP runtime header is always emitted, and always first.
  decl_stream << "#include <hip/hip_runtime.h>\n";
  // <mma.h> is only needed once a wmma builtin has set need_mma_h_.
  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }
  // TileLang HIP template headers, emitted in this fixed order.
  static const char *const kTileLangIncludes[] = {
      "#include <tl_templates/hip/gemm.h>\n",
      "#include <tl_templates/hip/copy.h>\n",
      "#include <tl_templates/hip/reduce.h>\n",
      "#include <tl_templates/hip/ldsm.h>\n",
      "#include <tl_templates/hip/threadblock_swizzle.h>\n",
      "#include <tl_templates/hip/debug.h>\n",
  };
  for (const char *include_line : kTileLangIncludes) {
    decl_stream << include_line;
  }
  decl_stream << "\n";
  return CodeGenC::Finish();
}

117
/*!
 * \brief Emit a TIR `For` as `for (T v = min; v < min + extent; ++v) { ... }`.
 *
 * Loops of kind kUnrolled are preceded by `#pragma unroll`.
 */
void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode *op) {
  const bool unrolled = op->kind == tir::ForKind::kUnrolled;
  if (unrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  // The exclusive upper bound is printed before the loop variable is
  // allocated and before the lower bound is printed; this order is kept
  // because PrintExpr / AllocVarID may have side effects on naming and
  // emitted code.
  arith::Analyzer simplifier;
  const std::string upper_bound =
      PrintExpr(simplifier.Simplify(op->extent + op->min));
  PrintIndent();
  const std::string loop_var = AllocVarID(op->loop_var.get());
  const std::string lower_bound = PrintExpr(op->min);
  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
  stream << ' ' << loop_var << " = " << lower_bound << "; " << loop_var
         << " < " << upper_bound << "; ++" << loop_var << ") {\n";
  const int body_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(body_scope);
  PrintIndent();
  stream << "}\n";
}

138
void CodeGenTileLangHIP::BindThreadIndex(const IterVar &iv) {
139
  ICHECK(!var_idmap_.count(iv->var.get()));
140
141
  var_idmap_[iv->var.get()] =
      CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
142
143
}

144
/*!
 * \brief Print the HIP source spelling of \p t to \p os.
 *
 * Vector types without a native spelling are mapped onto a wider packed
 * storage type, e.g. half x4 -> uint2 and float x8 -> ulonglong4.
 * PrintVecElemLoad / PrintVecElemStore address individual lanes of those
 * packed types. Any type that cannot be represented aborts with LOG(FATAL)
 * instead of emitting a partial type name.
 */
void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
    os << "void*";
    return;
  }

  if (t.is_void()) {
    os << "void";
    return;
  }

  if (t == tl::cuTensorMapType()) {
    os << "CUtensorMap";
    return;
  }

  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
    case 16:
      if (t.is_scalar()) {
        os << "half_t";
      } else if (lanes <= 8) {
        // fp16 vectors are stored packed, two lanes per 32-bit word:
        //
        // half4 is stored as uint2
        //
        // h4.x is emitted as *(half2*)(&(u2.x)).x
        // h4.y is emitted as *(half2*)(&(u2.x)).y
        // h4.z is emitted as *(half2*)(&(u2.y)).x
        // h4.w is emitted as *(half2*)(&(u2.y)).y
        //
        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
        os << "uint" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 32:
      if (lanes <= 4) {
        // Base name only; a lane count of 2..4 is appended below.
        os << "float";
      } else if (lanes <= 8) {
        // fp32 vectors with 4 < lanes <= 8 are stored packed, two lanes per
        // 64-bit word:
        //
        // float8 is stored as ulonglong4
        //
        // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
        // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for float type with lanes > 4";
        os << "ulonglong" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
    }
    if (!fail && (t.is_scalar() || t.bits() == 16))
      return;
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
      return;
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t.is_bfloat16()) {
    if (t.is_scalar()) {
      os << "bfloat16_t";
    } else if (lanes <= 8) {
      // Packed two lanes per 32-bit word, like fp16 vectors.
      ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
      os << "uint" << lanes / 2;
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t.is_float8()) {
    if (t.is_scalar()) {
      os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
    } else if (lanes == 2) {
      os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
                                  // unsigned short
    } else if (lanes == 4) {
      os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      os << "u";
    }
    switch (t.bits()) {
    case 1: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int8_t";
        return;
      } else if (t.lanes() == 16) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 32) {
        os << "int";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 4: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 4) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 8) {
        // directly 8 4-bit int in integer.
        os << "int";
        return;
      } else if (t.lanes() == 16) {
        os << "int2";
        return;
      } else if (t.lanes() == 32) {
        os << "int4";
        return;
      } else if (t.lanes() == 64) {
        os << "int8";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 8: {
      if (t.lanes() == 4) {
        // directly 4 8 bit int in integer.

        // We use int for int8x4 instead of char4 because using char4 is
        // likely to produce extra instructions to pack four int8 elements
        // into 32-bit data.
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int2";
        return;
      } else if (t.lanes() == 16) {
        os << "int4";
        return;
      } else if (!t.is_uint() && t.is_scalar()) {
        os << "signed char";
        break;
      } else {
        // Base name; a lane count of 2..4 is appended below.
        os << "char";
        break;
      }
    }
    case 16: {
      if (t.is_scalar()) {
        os << "short";
      } else if (t.lanes() <= 4) {
        os << "short" << lanes;
      } else if (t.lanes() <= 8) {
        // int16 vectors with 4 < lanes <= 8 are stored packed, two lanes per
        // 32-bit word:
        //
        // short4 is stored as int2
        //
        // s4.x is emitted as *(short2*)(&(i2.x)).x
        // s4.y is emitted as *(short2*)(&(i2.x)).y
        // s4.z is emitted as *(short2*)(&(i2.y)).x
        // s4.w is emitted as *(short2*)(&(i2.y)).y
        //
        ICHECK_EQ(t.lanes() % 2, 0)
            << "only support even lane for short type with lanes > 4";
        os << "int" << t.lanes() / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 32: {
      if (t.is_scalar()) {
        os << "int";
      } else if (t.lanes() <= 4) {
        os << "int" << t.lanes();
      } else if (t.lanes() <= 8) {
        // int32 vectors with 4 < lanes <= 8 are stored packed, two lanes per
        // 64-bit word:
        //
        // int8 is stored as longlong4
        //
        // i8.v1 is emitted as *(int2*)(&(l4.x)).x
        // i8.v2 is emitted as *(int2*)(&(l4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for int32 type with lanes > 4";
        os << "longlong" << lanes / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 64: {
      if (t.is_scalar()) {
        os << "int64_t";
      } else if (t.lanes() == 2) {
        os << "longlong2";
      } else if (t.lanes() == 3) {
        os << "longlong3";
      } else if (t.lanes() == 4) {
        os << "longlong4";
      } else {
        // No spelling exists for wider 64-bit integer vectors. Report it
        // below rather than returning with a truncated type name.
        fail = true;
        break;
      }
      return;
    }
    default:
      fail = true;
      break;
    }
    if (!fail && lanes == 1) {
      return;
    }
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

398
399
400
void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string &op, DataType t,
                                          PrimExpr lhs, PrimExpr rhs,
                                          std::ostream &os) { // NOLINT(*)
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  int ssa_scope = BeginScope();
  {
    // Unpack into individual ops.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  EndScope(ssa_scope);
  os << sret;
}

434
435
436
void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t,
                                          int i,
                                          std::ostream &os) { // NOLINT(*)
437
438
439
440
441
442
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
443
444
445
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
446
447
448
449
450
451
452
453
454
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
    }
  } else if (t.is_float16()) {
455
456
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
457
  } else if (t.is_bfloat16()) {
458
459
    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
478
479
    os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
       << ")))->" << access[i % 2];
480
481
482
483
484
  } else {
    os << vec << "." << access[i];
  }
}

485
486
void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t,
                                           int i, const std::string &value) {
487
488
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
489
490
491
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
492
493
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
494
495
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
496
497
498
499
500
501
502
503
504
505
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      stream << ac << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
      }
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
    }
  } else if (t.is_float16()) {
506
507
    stream << "*((half_t*)(&(((half2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << "))) = " << value << ";\n";
508
  } else if (t.is_bfloat16()) {
509
510
    stream << "*((bfloat16_t*)(&((half2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << "))) = " << value << ";\n";
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
529
530
    stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << " = " << value << ";\n";
531
532
533
534
535
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

536
537
/*!
 * \brief Emit the barrier for a storage-sync call.
 *
 * "shared" and "shared.dyn" emit `__syncthreads();`. Every other scope,
 * including "warp", emits nothing.
 */
void CodeGenTileLangHIP::PrintStorageSync(const CallNode *op) {
  const std::string &sync = op->args[0].as<StringImmNode>()->value;
  const bool block_barrier = sync == "shared" || sync == "shared.dyn";
  if (!block_barrier) {
    return;
  }
  this->PrintIndent();
  this->stream << "__syncthreads();\n";
}

546
547
548
549
550
/*!
 * \brief Print the storage qualifier for an allocation in \p scope.
 *
 * "global" allocations are rejected because global arrays must be passed in
 * as kernel arguments. Any scope other than "shared" / "shared.dyn" prints
 * no qualifier.
 */
void CodeGenTileLangHIP::PrintStorageScope(const std::string &scope,
                                           std::ostream &os) { // NOLINT(*)
  ICHECK_NE(scope, "global")
      << "Cannot allocate global memory when targeting CUDA. You must pass "
         "all global arrays as input instead";
  if (scope == "shared.dyn") {
    // Dynamic shared memory is declared extern (its allocation is printed as
    // an unsized array).
    os << "extern __shared__ __align__(1024) ";
  } else if (scope == "shared") {
    os << "__shared__ ";
  }
}

558
559
560
561
std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from,
                                           DataType target) {
  if (from == target)
    return value;
562
563
564
565
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
566
567
  if (from.is_float16() && (target.is_int() || target.is_uint()) &&
      target.bits() == 8) {
568
569
570
571
572
573
574
575
576
577
    os << "(";
    if (target.is_uint()) {
      os << "u";
    }
    os << "int)";
  }
  os << value << ")";
  return os.str();
}

578
void CodeGenTileLangHIP::VisitExpr_(const CastNode *op, std::ostream &os) {
579
580
581
582
583
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
584
585
  if (from_ty.is_scalar())
    return CodeGenC::VisitExpr_(op, os);
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
  {
    std::string src = SSAGetID(PrintExpr(op->value), from_ty);
    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
      std::ostringstream val;
      val << "(";
      PrintType(target_ty.element_of(), val);
      val << ")(";
      PrintVecElemLoad(src, from_ty, i, val);
      val << ")";
      PrintVecElemStore(sret, target_ty, i, val.str());
    }
  }
  os << sret;
}

608
609
610
611
/*!
 * \brief Print a call to an external function.
 *
 * Scalar-returning calls are handled by CodeGenC. A vector-returning call
 * is split into one scalar call per lane, for example
 *
 *   v = intrin_f((float4*)A[0], (float4*)B[0])
 *
 * becomes
 *
 *   float4 __ret;
 *   {
 *     float4 __arg0 = ((float4*)A)[0];
 *     float4 __arg1 = ((float4*)B)[0];
 *     __ret.x = intrin_f(__arg0.x, __arg1.x);
 *     ...
 *     __ret.w = intrin_f(__arg0.w, __arg1.w);
 *   }
 *   v = __ret;
 */
void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol,
                                         const Array<PrimExpr> &args,
                                         bool skip_first_arg,
                                         std::ostream &os) { // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (!ret_dtype.is_vector()) {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
                              os);
    return;
  }

  // Declare the result vector.
  const std::string result = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(ret_dtype, stream);
  stream << ' ' << result << ";\n";
  {
    // Bind every (non-skipped) argument to an SSA identifier.
    const size_t first_arg = skip_first_arg ? 1 : 0;
    std::vector<std::string> arg_ids;
    for (size_t k = first_arg; k < args.size(); ++k) {
      arg_ids.push_back(SSAGetID(PrintExpr(args[k]), args[k].dtype()));
    }

    // One scalar call per lane.
    for (int lane = 0; lane < ret_dtype.lanes(); ++lane) {
      std::ostringstream call;
      call << global_symbol << "(";
      for (size_t k = 0; k < arg_ids.size(); ++k) {
        if (k != 0) {
          call << ", ";
        }
        PrintVecElemLoad(arg_ids[k], args[first_arg + k].dtype(), lane, call);
      }
      call << ")";
      PrintVecElemStore(result, ret_dtype, lane, call.str());
    }
  }
  os << result;
}

// Print a reference expression to a buffer.
667
668
669
670
std::string CodeGenTileLangHIP::GetBufferRef(DataType t,
                                             const BufferNode *buffer,
                                             PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
  std::ostringstream os;
  std::string vid = GetVarID(buffer_var);
  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }
  // bool is_vol = IsVolatile(buffer_var);
  // always false for tl cutlass backend.
  bool is_vol = false;

  auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
    std::ostringstream ptr_os;
    ptr_os << "(";
    if (is_vol) {
      ptr_os << "volatile ";
    }
    if (!scope.empty() && IsScopePartOfType()) {
      PrintStorageScope(scope, ptr_os);
    }
    PrintType(pointed_to, ptr_os);
    ptr_os << "*)";
    return ptr_os.str();
  };

  DataType buffer_element_dtype = buffer->dtype;

  std::string buffer_str = vid;
  if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
    std::stringstream temp;
    temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
    buffer_str = temp.str();
  }

  std::string index_str = PrintExpr(index);
  if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
    // This is a special case, because CodegenCUDA::PrintType()
    // returns "int" for bool and for 4-bit integers. In most cases,
    // we divide by the number of lanes to determine the index.
    // However, the backing type for scalar int4 and scalar bool is
    // int32.  Therefore, we need to divide by the ratio of their
    // sizes in that case.
    int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();

    os << "*("
       << "(" << ptr_cast(t) << vid << ")"
       << " + " << index_str << " / " << div_factor << ")";
  } else if (t == buffer_element_dtype) {
    os << buffer_str << "[" << index_str << "]";
  } else {
    os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
  }

  return os.str();
}

726
void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
727
728
729
730
  auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
    this->PrintIndent();
    this->stream << name << "(";
    for (size_t i = offset; i < op->args.size(); i++) {
731
732
      if (i > offset)
        this->stream << ", ";
733
734
735
736
737
738
739
740
741
742
      this->stream << this->PrintExpr(op->args[i]);
    }
    this->stream << ");\n";
  };
  if (op->op.same_as(builtin::ptx_cp_async())) {
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
743
744
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
745
746
    if (op->args.size() == 5) {
      this->PrintIndent();
747
748
      this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
                   << dst_offset << ", " << src << "+" << src_offset << ");\n";
749
750
751
    } else {
      std::string condition = this->PrintExpr(op->args[5]);
      this->PrintIndent();
752
753
754
      this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
                   << "+" << dst_offset << ", " << src << "+" << src_offset
                   << ", " << condition << ");\n";
755
756
757
758
759
760
761
762
763
764
765
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
    print_extern_call_stmt("tl::cp_async_commit");
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
    int n = Downcast<IntImm>(op->args[0])->value;
    std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
    print_extern_call_stmt(func_name, 1);
  } else if (op->op.same_as(builtin::create_barriers())) {
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    std::string barrier_name = "_mbarrier";
766
767
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
  } else if (op->op.same_as(tl::GetMBarrierOp())) {
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
    os << barrier_name + "[" + barrier_id + "]";
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
    print_extern_call_stmt("tl::mbarrier_arrive");
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
    print_extern_call_stmt("tl::mbarrier_init");
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
  } else if (op->op.same_as(tl::MBarrierExpectTX())) {
    print_extern_call_stmt("tl::mbarrier_expect_tx");
  } else if (op->op.same_as(tl::MBarrierWaitParity())) {
    print_extern_call_stmt("tl::mbarrier_wait");
  } else if (op->op.same_as(tl::SyncThreadsPartialOp())) {
    print_extern_call_stmt("tl::syncthreads_partial");
  } else if (op->op.same_as(tl::TMALoadOp())) {
    print_extern_call_stmt("tl::tma_load");
  } else if (op->op.same_as(tl::TMALoadIm2ColOp())) {
    print_extern_call_stmt("tl::tma_load_im2col");
  } else if (op->op.same_as(tl::TMAStoreOp())) {
    print_extern_call_stmt("tl::tma_store");
  } else if (op->op.same_as(tl::LDMatrixOp())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num);
796
797
    if (trans == 1)
      func_name += "_trans";
798
799
800
801
802
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::STMatrixOp())) {
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
803
804
    if (trans == 1)
      func_name += "_trans";
805
806
807
808
809
810
811
    print_extern_call_stmt(func_name, 2);
  } else if (op->op.same_as(tl::FenceProxyAsyncOp())) {
    print_extern_call_stmt("tl::fence_proxy_async");
  } else if (op->op.same_as(tl::SetMaxNReg())) {
    this->PrintIndent();
    int nreg = Downcast<IntImm>(op->args[0])->value;
    int is_inc = Downcast<IntImm>(op->args[1])->value;
812
813
    std::string func_name =
        is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc";
814
815
816
817
818
819
    this->stream << func_name << "<" << std::to_string(nreg) << ">();\n";
  } else if (op->op.same_as(tl::WaitWgmma())) {
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
  } else if (op->op.same_as(tl::PackB16Op())) {
820
821
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
855
    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
881
  } else if (op->op.same_as(builtin::tvm_mfma())) {
882
883
884
885
886
887
888
889
890
891
892
893
894
    // arg 0: prefix: {otype}_16x16x16{itype}
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: float16, float32, ...
    // arg 4: B precision: float16, float32, ...
    // arg 5: C precision: float32, float64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index

895
896
    ICHECK(op->args.size() == 12U)
        << "Invalid number of arguments for tvm_mfma";
897
898
899
900
901
902
903
904
905
906
907
908
    std::string prefix = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
909
910
    ICHECK(A_layout == "row" || B_layout == "row")
        << "Matrix core only support row major";
911
912
913
914
915
916
917
918
919
920
921
922
    // map for dtype -> float32x4 -> float4
    std::unordered_map<std::string, std::string> dtype_map = {
        {"int8", "char"},
        {"int32", "int"},
        {"int8x4", "int32_t"},
        {"int32x4", "int32x4"},
        {"float16", "half"},
        {"float32", "float"},
        {"float64", "double"},
        {"float16x4", "float16x4"},
        {"bfloat16x4", "bfloat16x4"},
        {"float32x4", "float32x4"},
923
        {"float32x16", "float32x16"}};
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
    std::string call_mfma_code = R"({
    *((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
                  *((({B_dytpe}*){b_ref}) + {b_bias}),
                  *((({C_dytpe}*){c_ref}) + {c_bias}), 0, 0, 0);
  })";
    std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
    Replacer replacer;
    replacer.register_rule("{mfma_buildin}", mfma_buildin);
    replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
    replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
    replacer.register_rule("{C_dytpe}", dtype_map[C_dtype]);
    replacer.register_rule("{a_ref}", a_ref);
    replacer.register_rule("{a_bias}", a_bias);
    replacer.register_rule("{b_ref}", b_ref);
    replacer.register_rule("{b_bias}", b_bias);
    replacer.register_rule("{c_ref}", c_ref);
    replacer.register_rule("{c_bias}", c_bias);
    os << replacer.rewrite(call_mfma_code);
942
  } else {
943
944
945
946
    CodeGenC::VisitExpr_(op, os);
  }
}

947
void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) {
948
  if (op->attr_key == tir::attr::async_commit_queue_scope) {
949
950
951
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
952
953
954
955
956
957
958
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
    return;
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
959
960
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
961
    auto wait_cnt = wait_attrs.second;
962
963
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
964
965
966
967
968
969
970
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    this->VisitStmt(inner->body);
    return;
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
971
    const StringImmNode *pattern = op->value.as<StringImmNode>();
972
973
974
975
976
977
978
979
    ICHECK(pattern);
    this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
    this->VisitStmt(op->body);
    return;
  }
  CodeGenC::VisitStmt_(op);
}

980
void CodeGenTileLangHIP::VisitStmt_(const AllocateNode *op) {
981
982
983
984
985
986
987
988
989
990
991
992
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());

  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
  PrintStorageScope(scope, stream);
  PrintType(op->dtype, stream);

  if (scope == "shared.dyn") {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
993
994
    ICHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation for now";
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007

    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
        scope == "shared") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
    stream << ' ' << vid << '[' << constant_size << "];\n";
  }

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

1008
void CodeGenTileLangHIP::VisitExpr_(const RampNode *op, std::ostream &os) {
1009
1010
1011
1012
1013
1014
1015
1016
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
  CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
  os << "(make_";
  PrintType(op->dtype, os);
  os << "(";
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
1017
1018
    if (i != lanes - 1)
      os << ", ";
1019
1020
1021
1022
  }
  os << "))";
}

1023
1024
void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op,
                                    std::ostream &os) { // NOLINT(*)
1025
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
1026
1027
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
      lanes == 4) {
1028
    // make_int8x4
1029
    const int64_t *p = as_const_int(op->value);
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
    ICHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
    if (op->dtype.is_uint()) {
      os << "(uint)" << v;
    } else {
      os << "(int)" << v;
    }
    return;
  }

  if (op->dtype.is_float16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1047
1048
      if (i != 0)
        os << ", ";
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if (op->dtype.is_bfloat16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1061
1062
      if (i != 0)
        os << ", ";
1063
1064
1065
1066
1067
1068
      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

1069
1070
  if (op->dtype.is_float() && op->dtype.bits() == 32 &&
      op->dtype.lanes() == 8) {
1071
1072
1073
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
1074
1075
      if (i != 0)
        os << ", ";
1076
1077
1078
1079
1080
1081
1082
1083
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
1084
    const int64_t *p = as_const_int(op->value);
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
    ICHECK(p);
    int64_t v = *p & 0xF;

    if (lanes == 4) {
      v = (v << 12) | (v << 8) | (v << 4) | v;
      if (op->dtype.is_uint()) {
        os << "(uint16_t)" << v;
      } else {
        os << "(int16_t)" << v;
      }
    } else {
1096
1097
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
          (v << 4) | v;
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
        } else {
          os << "(int)" << v;
        }
      } else if (lanes == 16 || lanes == 32) {
        os << "make_";
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < lanes / 8; ++i) {
1109
1110
          if (i != 0)
            os << ", ";
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
            os << "(int)" << v;
          }
        }
        os << ')';
      } else {
        fail = true;
      }
    }

    if (!fail) {
      return;
    }
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
1133
1134
    if (i != 0)
      os << ", ";
1135
1136
1137
1138
1139
    os << v;
  }
  os << ')';
}

1140
1141
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangHIP *p) { // NOLINT(*)
1142
1143
1144
1145
1146
1147
1148
1149
  // Type code is kBFloat
  if (op->dtype.is_bfloat16()) {
    os << "bfloat16_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
  // Type code is kFloat
  switch (op->dtype.bits()) {
1150
1151
1152
1153
1154
1155
  case 64:
  case 32: {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
1156
      }
1157
1158
1159
1160
1161
1162
1163
      temp << ((op->dtype.bits() == 32) ? "HIPRT_INF_F" : "HIPRT_INF");
    } else if (std::isnan(op->value)) {
      temp << ((op->dtype.bits() == 32) ? "HIPRT_NAN_F" : "HIPRT_NAN");
    } else {
      temp << std::scientific << op->value;
      if (op->dtype.bits() == 32)
        temp << 'f';
1164
    }
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
    p->MarkConst(temp.str());
    os << temp.str();
    break;
  }
  case 16: {
    os << "half_t" << '(';
    FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
    PrintConst(const_f32.get(), os, p);
    os << ')';
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
1178
1179
1180
  }
}

1181
1182
void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode *op,
                                    std::ostream &os) { // NOLINT(*)
1183
1184
1185
  PrintConst(op, os, this);
}

1186
1187
1188
void CodeGenTileLangHIP::HandleVolatileLoads(const std::string &value,
                                             const BufferLoadNode *op,
                                             std::ostream &os) {
1189
1190
1191
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
1192
1193
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
1194
1195
1196
1197
1198
1199
1200
1201
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

1202
1203
1204
/*!
 * \brief Emit the fragment of a vector constructor for element `i` of type `t`.
 *
 * Called once per lane, in order. The fragments together form one expression:
 *  - 8-bit int/uint vectors (except 2 or 3 lanes): bytes are OR-ed together,
 *    each masked into its byte slot (`i * 8` bits).
 *  - float16 / bfloat16: `make_<vectype>(` followed by pairs packed with
 *    __pack_half2 / __pack_bfloat162.
 *  - otherwise: `make_<vectype>(v0,v1,...)`.
 */
void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i,
                                              const std::string &value,
                                              std::ostream &os) {
  ICHECK_GT(t.lanes(), 1);
  const int lane_count = t.lanes();
  const bool is_last = (i == lane_count - 1);

  const bool is_byte_int = t.bits() == 8 && (t.is_int() || t.is_uint());
  if (is_byte_int && lane_count != 2 && lane_count != 3) {
    const int shift = i * 8;
    if (i != 0) {
      os << "|";
    }
    os << "((0x000000ff << " << shift << ") & (" << value << " << " << shift
       << "))";
    return;
  }

  if (t.is_float16() || t.is_bfloat16()) {
    const char *pack_fn =
        t.is_float16() ? "__pack_half2(" : "__pack_bfloat162(";
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 != 0) {
      // Odd lane closes the current pair, then either continues or closes
      // the outer make_ call.
      os << "," << value << ")" << (is_last ? ")" : ",");
    } else {
      os << pack_fn << value;
    }
    return;
  }

  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << "(";
  }
  os << value << (is_last ? ")" : ",");
}

1269
void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix(stream);
  CodeGenC::PrintType(f->ret_type, stream);
  this->PrintExtraAttrs(f, stream);
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
1288
1289
    if (i != 0)
      stream << ", ";
1290
1291
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
1292
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
          stream << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      CodeGenC::PrintType(GetType(v), stream);
1307
1308
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, stream);
      }
    } else {
      CodeGenC::PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";
  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}

1330
1331
} // namespace codegen
} // namespace tvm