codegen_hip.cc 48.6 KB
Newer Older
1
2
3
4
5
6
/*!
 * \file target/codegen_hip.cc
 */

#include "codegen_hip.h"
#include <tvm/arith/analyzer.h>
7
#include <tvm/ffi/function.h>
8
#include <tvm/tir/index_map.h>
9
10
11
12
13
14
15
16
17
18
19
20
21
#include <tvm/tir/op.h>

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "../op/builtin.h"
#include "target/source/ptx.h"

namespace tvm {
namespace codegen {

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/*!
 * \brief Name of the HIP-side FP8 type that corresponds to \p type.
 *
 * The element part depends on the exponent width of the FP8 encoding
 * ("fp8_e4", "fp8_e5" or "fp8_e8"); vector types append "_<lanes>" and the
 * name always ends in "_t" (e.g. a 4-lane e4m3fn vector -> "fp8_e4_4_t").
 *
 * \param type A float8 DataType, either scalar or with 2, 4, 8 or 16 lanes.
 * \return The type name to print in generated HIP code.
 * \note Other lane counts and non-FP8 codes are reported via LOG(FATAL).
 */
static std::string GetFP8Type(DataType type) {
  const int32_t lanes = type.lanes();

  // Width suffix for vector types; scalars carry no suffix.
  std::string width_suffix;
  if (!type.is_scalar()) {
    switch (lanes) {
    case 2:
    case 4:
    case 8:
    case 16:
      width_suffix = "_" + std::to_string(lanes);
      break;
    default:
      LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
                    "for FP8";
    }
  }

  // Every e4m3 flavour shares one storage type, as do the e5m2 flavours.
  const auto code = type.code();
  const bool is_e4 = code == DataType::kFloat8_e4m3fn ||
                     code == DataType::kFloat8_e4m3fnuz ||
                     code == DataType::kFloat8_e4m3 ||
                     code == DataType::kFloat8_e4m3b11fnuz;
  const bool is_e5 = code == DataType::kFloat8_e5m2 ||
                     code == DataType::kFloat8_e5m2fnuz;
  const bool is_e8 = code == DataType::kFloat8_e8m0fnu;

  std::string element;
  if (is_e4) {
    element = "fp8_e4";
  } else if (is_e5) {
    element = "fp8_e5";
  } else if (is_e8) {
    element = "fp8_e8";
  } else {
    LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;
    return "";
  }
  return element + width_suffix + "_t";
}

60
61
62
63
64
/*!
 * \brief Replace patterns with replacement strings.
 *
 * Rules are applied in registration order. Each rule replaces every
 * non-overlapping occurrence of its pattern, scanning left to right. Text
 * inserted by a rule is not rescanned by that same rule, but later rules do
 * see it. Rules with an empty pattern are ignored.
 *
 * \note should use std::format instead when codebase is ported to C++20.
 */
class Replacer {
public:
  /*!
   * \brief Register a rewrite rule.
   * \param pattern Text to search for (an empty pattern is ignored by rewrite).
   * \param replacement Text substituted for each occurrence of \p pattern.
   */
  void register_rule(const std::string &pattern,
                     const std::string &replacement) {
    _rules.emplace_back(pattern, replacement);
  }

  /*!
   * \brief Apply all registered rules to \p str.
   * \param str Input text (taken by value and edited in place).
   * \return The rewritten text.
   */
  std::string rewrite(std::string str) {
    for (const auto &[pattern, replacement] : _rules) {
      // An empty pattern matches at every position, so the scan below would
      // never reach npos; skip such rules instead of looping forever.
      if (pattern.empty()) {
        continue;
      }
      // Resume searching after the inserted replacement so that a replacement
      // containing its own pattern is not expanded again.
      for (size_t pos = str.find(pattern); pos != std::string::npos;
           pos = str.find(pattern, pos + replacement.size())) {
        str.replace(pos, pattern.size(), replacement);
      }
    }
    return str;
  }

  /*! \brief Remove all registered rules. */
  void empty_rules() { _rules.clear(); }

private:
  std::vector<std::pair<std::string, std::string>> _rules;
};

/*!
 * \brief Construct the HIP code generator.
 *
 * Sets restrict_keyword_ to the HIP/clang spelling "__restrict__".
 */
CodeGenTileLangHIP::CodeGenTileLangHIP() {
  restrict_keyword_ = "__restrict__";
}

91
92
93
/*!
 * \brief Print the qualifiers that precede every kernel signature:
 *        C linkage followed by the __global__ kernel attribute.
 * \param os Stream receiving the function signature.
 */
void CodeGenTileLangHIP::PrintFuncPrefix(std::ostream &os) {
  const char *linkage = "extern \"C\"";
  const char *kernel_attr = " __global__ ";
  os << linkage << kernel_attr;
}
94
95

class LaunchConfigExtractor : public tir::StmtVisitor {
96
97
private:
  void VisitStmt_(const AttrStmtNode *op) final {
98
99
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
100
101
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
102
        threadIdx_x_ext = op->value;
103
104
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
105
        threadIdx_y_ext = op->value;
106
107
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
108
109
110
111
112
113
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

114
public:
115
116
117
118
119
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

120
void CodeGenTileLangHIP::PrintExtraAttrs(const PrimFunc &f, std::ostream &os) {
121
122
123
  LaunchConfigExtractor extractor;
  extractor(f->body);
  arith::Analyzer analyzer;
124
125
126
127
128
  PrimExpr threadIdx_ext =
      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
                        extractor.threadIdx_z_ext);
  if (const IntImmNode *const threadIdx_ext_int =
          threadIdx_ext.as<IntImmNode>()) {
129
    if (threadIdx_ext_int->value == 1) {
130
131
      // unable to extract the number of threads per block, hence directly
      // return
132
133
134
135
136
137
138
139
      return;
    }
    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
  }
}

std::string CodeGenTileLangHIP::Finish() {
  // hip must need a header file.
Lukinon's avatar
Lukinon committed
140
  decl_stream << "#define HIP_ENABLE_WARP_SYNC_BUILTINS\n";
141
142
143
144
  decl_stream << "#include <hip/hip_runtime.h>\n";
  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }
145
146
147
148
149

  if (enable_fp8_) {
    decl_stream << "#include <tl_templates/hip/hip_fp8.h>\n";
  }

Lukinon's avatar
Lukinon committed
150
151
152
153
154
155
  decl_stream << "#include <tl_templates/dcu_hip/gemm.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/copy.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/reduce.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/ldsm.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/threadblock_swizzle.h>\n";
  decl_stream << "#include <tl_templates/dcu_hip/debug.h>\n";
156
157
158
159
  decl_stream << "\n";
  return CodeGenC::Finish();
}

160
/*!
 * \brief Emit a C-style counted loop for a TIR For node.
 *
 * The loop variable runs from `min` up to, but excluding, the simplified
 * `min + extent`. Loops of kind kUnrolled are preceded by `#pragma unroll`.
 */
void CodeGenTileLangHIP::VisitStmt_(const tir::ForNode *op) {
  const bool unrolled = op->kind == tir::ForKind::kUnrolled;
  if (unrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  const std::string upper_bound =
      PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
  PrintIndent();
  const std::string loop_var = AllocVarID(op->loop_var.get());
  const std::string lower_bound = PrintExpr(op->min);

  // Header: for (<type> v = lower; v < upper; ++v) {
  stream << "for (";
  PrintType(op->loop_var.dtype(), stream);
  stream << ' ' << loop_var << " = " << lower_bound << "; " << loop_var
         << " < " << upper_bound << "; ++" << loop_var << ") {\n";

  const int body_scope = BeginScope();
  PrintStmt(op->body);
  this->EndScope(body_scope);
  PrintIndent();
  stream << "}\n";
}

181
void CodeGenTileLangHIP::BindThreadIndex(const IterVar &iv) {
182
  ICHECK(!var_idmap_.count(iv->var.get()));
183
184
  var_idmap_[iv->var.get()] =
      CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
185
186
}

187
/*!
 * \brief Print the HIP spelling of \p t to \p os.
 *
 * Vector types without a native HIP counterpart are packed into wider
 * integer vectors, e.g. 8 x fp16 -> uint4, 8 x fp32 -> ulonglong4,
 * 8 x int32 -> longlong4. PrintVecElemLoad/PrintVecElemStore address
 * individual lanes inside these packed forms. Unsupported combinations end
 * in LOG(FATAL).
 *
 * Side effect: printing a float8 type sets enable_fp8_, so that Finish()
 * includes the FP8 header.
 */
void CodeGenTileLangHIP::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK(t.is_scalar()) << "do not yet support vector types";
    os << "void*";
    return;
  }

  if (t.is_void()) {
    os << "void";
    return;
  }

  if (t == tl::cuTensorMapType()) {
    os << "CUtensorMap";
    return;
  }

  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
    case 16:
      if (t.is_scalar()) {
        os << "half_t";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp16 vector elements.
        //
        // half4 is stored as uint2
        //
        // h4.x is emitted as *(half2*)(&(u2.x)).x
        // h4.y is emitted as *(half2*)(&(u2.x)).y
        // h4.z is emitted as *(half2*)(&(u2.y)).x
        // h4.w is emitted as *(half2*)(&(u2.y)).y
        //
        ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
        os << "uint" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 32:
      if (lanes <= 4) {
        os << "float";
      } else if (lanes <= 8) {
        // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
        //
        // float8 is stored as ulonglong4
        //
        // f8.v1 is emitted as *(float2*)(&(ul4.x)).x
        // f8.v2 is emitted as *(float2*)(&(ul4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for float type with lanes > 4";
        os << "ulonglong" << lanes / 2;
      } else {
        fail = true;
      }
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
    }
    if (!fail && (t.is_scalar() || t.bits() == 16))
      return;
    if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32))
      return;
    // Native float2..float4 / double2..double4.
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t.is_bfloat16()) {
    if (t.is_scalar()) {
      os << "bfloat16_t";
    } else if (lanes <= 8) {
      // Pairs of bf16 lanes are packed into each 32-bit member.
      ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
      os << "uint" << lanes / 2;
    } else {
      fail = true;
    }
    if (!fail)
      return;
  } else if (t.is_float8()) {
    enable_fp8_ = true;
    os << GetFP8Type(t);
    return;
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    // Unsigned types are spelled with a leading "u" (e.g. "uint", "ushort").
    if (t.is_uint()) {
      os << "u";
    }
    switch (t.bits()) {
    case 1: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int8_t";
        return;
      } else if (t.lanes() == 16) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 32) {
        os << "int";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 4: {
      if (t.is_scalar()) {
        os << "int";
        return;
      } else if (t.lanes() == 4) {
        os << "int16_t";
        return;
      } else if (t.lanes() == 8) {
        // directly 8 4-bit int in integer.
        os << "int";
        return;
      } else if (t.lanes() == 16) {
        os << "int2";
        return;
      } else if (t.lanes() == 32) {
        os << "int4";
        return;
      } else if (t.lanes() == 64) {
        os << "int8";
        return;
      } else {
        LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
      }
    }
    case 8: {
      if (t.lanes() == 4) {
        // directly 4 8 bit int in integer.

        // We use int for int8x4 instead of char4 because using char4 is
        // likely to produce extra instructions to pack four int8 elements
        // into 32-bit data.
        os << "int";
        return;
      } else if (t.lanes() == 8) {
        os << "int2";
        return;
      } else if (t.lanes() == 16) {
        os << "int4";
        return;
      } else if (!t.is_uint() && t.is_scalar()) {
        os << "signed char";
        break;
      } else {
        os << "char";
        break;
      }
    }
    case 16: {
      if (t.is_scalar()) {
        os << "short";
      } else if (t.lanes() <= 4) {
        os << "short" << lanes;
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int16 vector elements.
        //
        // short4 is stored as int2
        //
        // s4.x is emitted as *(short2*)(&(i2.x)).x
        // s4.y is emitted as *(short2*)(&(i2.x)).y
        // s4.z is emitted as *(short2*)(&(i2.y)).x
        // s4.w is emitted as *(short2*)(&(i2.y)).y
        //
        ICHECK_EQ(t.lanes() % 2, 0)
            << "only support even lane for shorT type with lanes > 4";
        os << "int" << t.lanes() / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 32: {
      if (t.is_scalar()) {
        os << "int";
      } else if (t.lanes() <= 4) {
        os << "int" << t.lanes();
      } else if (t.lanes() <= 8) {
        // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
        //
        // int8 is stored as longlong4
        //
        // i8.v1 is emitted as *(int2*)(&(l4.x)).x
        // i8.v2 is emitted as *(int2*)(&(l4.x)).y
        //
        ICHECK_EQ(lanes % 2, 0)
            << "only support even lane for int32 type with lanes > 4";
        os << "longlong" << lanes / 2;
      } else {
        fail = true;
      }
      if (!fail) {
        return;
      }
      break;
    }
    case 64: {
      if (t.is_scalar()) {
        os << "int64_t";
      } else if (t.lanes() == 2) {
        os << "longlong2";
      } else if (t.lanes() == 3) {
        os << "longlong3";
      } else if (t.lanes() == 4) {
        os << "longlong4";
      } else {
        // No HIP spelling for other lane counts. Previously this returned
        // after printing nothing (or only "u"), producing an invalid type
        // name; report it instead.
        fail = true;
        break;
      }
      return;
    }
    default:
      fail = true;
      break;
    }
    if (!fail && lanes == 1) {
      return;
    }
    // Native char2..char4 (8-bit ints that fell through above).
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

432
433
434
void CodeGenTileLangHIP::PrintVecBinaryOp(const std::string &op, DataType t,
                                          PrimExpr lhs, PrimExpr rhs,
                                          std::ostream &os) { // NOLINT(*)
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
  // Declare the result.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  int ssa_scope = BeginScope();
  {
    // Unpack into individual ops.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  EndScope(ssa_scope);
  os << sret;
}

468
469
470
void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t,
                                          int i,
                                          std::ostream &os) { // NOLINT(*)
471
472
473
474
475
476
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
477
478
479
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
480
481
482
483
484
485
486
487
488
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    std::string type_name = t.is_int() ? "char" : "unsigned char";
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
    }
  } else if (t.is_float16()) {
489
490
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
       << access[i % 2];
491
  } else if (t.is_bfloat16()) {
492
    os << "((bfloat16x2*)(&(" << vec << "." << access[i / 2] << ")))->"
493
       << access[i % 2];
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
512
513
    os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
       << ")))->" << access[i % 2];
514
515
516
517
518
  } else {
    os << vec << "." << access[i];
  }
}

519
520
void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t,
                                           int i, const std::string &value) {
521
522
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
523
524
525
  ICHECK(i >= 0 && i < (t.bits() == 8                        ? 16
                        : (t.bits() == 16 || t.bits() == 32) ? 8
                                                             : 4));
526
527
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
528
529
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
530
531
532
533
534
535
536
537
538
539
    } else {
      std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
      stream << ac << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
      }
      stream << "(" << value << " << " << i % 4 * 8 << ");\n";
    }
  } else if (t.is_float16()) {
540
541
    stream << "*((half_t*)(&(((half2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << "))) = " << value << ";\n";
542
  } else if (t.is_bfloat16()) {
543
544
    stream << "((bfloat16_t*)(&(" << vec << "." << access[i / 2] << ")))["
           << (i % 2) << "] = " << value << ";\n";
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
  } else if (t.lanes() > 4 && t.lanes() <= 8) {
    std::string type_name;
    if (t.bits() == 16) {
      if (t.is_int()) {
        type_name = "short";
      } else if (t.is_uint()) {
        type_name = "ushort";
      }
    } else if (t.bits() == 32) {
      if (t.is_int()) {
        type_name = "int";
      } else if (t.is_uint()) {
        type_name = "uint";
      } else if (t.is_float()) {
        type_name = "float";
      }
    }
    ICHECK(!type_name.empty());
563
564
    stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
           << ")))->" << access[i % 2] << " = " << value << ";\n";
565
566
567
568
569
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

570
571
void CodeGenTileLangHIP::PrintStorageSync(const CallNode *op) {
  const std::string &sync = op->args[0].as<StringImmNode>()->value;
572
573
574
575
576
577
578
579
  if (sync == "warp") {
    // DO nothing.
  } else if (sync == "shared" || sync == "shared.dyn") {
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
  }
}

580
581
582
583
584
/*!
 * \brief Print the declaration qualifier for a storage scope.
 *
 * "shared" maps to `__shared__`. "shared.dyn" maps to an extern, 1024-byte
 * aligned `__shared__` array. Other non-global scopes print nothing.
 * Allocating global memory is rejected.
 */
void CodeGenTileLangHIP::PrintStorageScope(const std::string &scope,
                                           std::ostream &os) { // NOLINT(*)
  ICHECK_NE(scope, "global")
      << "Cannot allocate global memory when targeting CUDA. You must pass "
         "all global arrays as input instead";
  const char *qualifier = nullptr;
  if (scope == "shared") {
    qualifier = "__shared__ ";
  } else if (scope == "shared.dyn") {
    qualifier = "extern __shared__ __align__(1024) ";
  }
  if (qualifier != nullptr) {
    os << qualifier;
  }
}

592
593
594
595
std::string CodeGenTileLangHIP::CastFromTo(std::string value, DataType from,
                                           DataType target) {
  if (from == target)
    return value;
596
597
598
599
  std::ostringstream os;
  os << "((";
  this->PrintType(target, os);
  os << ")";
600
601
  if (from.is_float16() && (target.is_int() || target.is_uint()) &&
      target.bits() == 8) {
602
603
604
605
606
607
608
609
610
611
    os << "(";
    if (target.is_uint()) {
      os << "u";
    }
    os << "int)";
  }
  os << value << ")";
  return os.str();
}

612
void CodeGenTileLangHIP::VisitExpr_(const CastNode *op, std::ostream &os) {
613
614
615
616
617
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  ICHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
618
619
  if (from_ty.is_scalar())
    return CodeGenC::VisitExpr_(op, os);
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
  std::string sret = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
  {
    std::string src = SSAGetID(PrintExpr(op->value), from_ty);
    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
      std::ostringstream val;
      val << "(";
      PrintType(target_ty.element_of(), val);
      val << ")(";
      PrintVecElemLoad(src, from_ty, i, val);
      val << ")";
      PrintVecElemStore(sret, target_ty, i, val.str());
    }
  }
  os << sret;
}

642
643
644
645
/*!
 * \brief Print a call to an external function.
 *
 * Scalar-returning calls go to the base class. A vector-returning call is
 * scalarized: the call
 *
 *   v = intrin_f((float4*)A[0], (float4*)B[0])
 *
 * is emitted as
 *
 *   float4 __ret;
 *   {
 *     float4 __arg0 = ((float4*)A)[0];
 *     float4 __arg1 = ((float4*)B)[0];
 *     __ret.x = intrin_f(__arg0.x, __arg1.x);
 *     ...
 *     __ret.w = intrin_f(__arg0.w, __arg1.w);
 *   }
 *   v = __ret;
 *
 * \param skip_first_arg When true, args[0] is not passed to the callee.
 */
void CodeGenTileLangHIP::PrintCallExtern(Type ret_type, String global_symbol,
                                         const Array<PrimExpr> &args,
                                         bool skip_first_arg,
                                         std::ostream &os) { // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (!ret_dtype.is_vector()) {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg,
                              os);
    return;
  }

  // Declare the result vector.
  const std::string result = name_supply_->FreshName("_");
  this->PrintIndent();
  this->PrintType(ret_dtype, stream);
  stream << ' ' << result << ";\n";

  // Bind every forwarded argument to an SSA name.
  const size_t first_arg = skip_first_arg ? 1 : 0;
  std::vector<std::string> arg_ids;
  for (size_t k = first_arg; k < args.size(); ++k) {
    arg_ids.push_back(SSAGetID(PrintExpr(args[k]), args[k].dtype()));
  }

  // One scalar call per result lane.
  for (int lane = 0; lane < ret_dtype.lanes(); ++lane) {
    std::ostringstream call;
    call << global_symbol << "(";
    for (size_t j = 0; j < arg_ids.size(); ++j) {
      if (j != 0) {
        call << ", ";
      }
      PrintVecElemLoad(arg_ids[j], args[first_arg + j].dtype(), lane, call);
    }
    call << ")";
    PrintVecElemStore(result, ret_dtype, lane, call.str());
  }
  os << result;
}

// Print a reference expression to a buffer.
701
702
703
704
std::string CodeGenTileLangHIP::GetBufferRef(DataType t,
                                             const BufferNode *buffer,
                                             PrimExpr index) {
  const VarNode *buffer_var = buffer->data.get();
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
  std::ostringstream os;
  std::string vid = GetVarID(buffer_var);
  std::string scope;
  if (alloc_storage_scope_.count(buffer_var)) {
    scope = alloc_storage_scope_.at(buffer_var);
  }
  // bool is_vol = IsVolatile(buffer_var);
  // always false for tl cutlass backend.
  bool is_vol = false;

  auto ptr_cast = [this, is_vol, scope](DataType pointed_to) {
    std::ostringstream ptr_os;
    ptr_os << "(";
    if (is_vol) {
      ptr_os << "volatile ";
    }
    if (!scope.empty() && IsScopePartOfType()) {
      PrintStorageScope(scope, ptr_os);
    }
    PrintType(pointed_to, ptr_os);
    ptr_os << "*)";
    return ptr_os.str();
  };

  DataType buffer_element_dtype = buffer->dtype;

  std::string buffer_str = vid;
  if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) {
    std::stringstream temp;
    temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")";
    buffer_str = temp.str();
  }

  std::string index_str = PrintExpr(index);
  if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
    // This is a special case, because CodegenCUDA::PrintType()
    // returns "int" for bool and for 4-bit integers. In most cases,
    // we divide by the number of lanes to determine the index.
    // However, the backing type for scalar int4 and scalar bool is
    // int32.  Therefore, we need to divide by the ratio of their
    // sizes in that case.
    int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes();

    os << "*("
       << "(" << ptr_cast(t) << vid << ")"
       << " + " << index_str << " / " << div_factor << ")";
  } else if (t == buffer_element_dtype) {
    os << buffer_str << "[" << index_str << "]";
  } else {
    os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
  }

  return os.str();
}

760
void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
761
  auto print_extern_call_stmt = [&](std::string name, size_t offset = 0) {
762
    printf("[DEBUG VisitExpr_] Branch: print_extern_call_stmt -> %s\n", name.c_str());
763
764
765
    this->PrintIndent();
    this->stream << name << "(";
    for (size_t i = offset; i < op->args.size(); i++) {
766
767
      if (i > offset)
        this->stream << ", ";
768
769
770
771
      this->stream << this->PrintExpr(op->args[i]);
    }
    this->stream << ");\n";
  };
772

773
  if (op->op.same_as(builtin::ptx_cp_async())) {
774
    printf("[DEBUG VisitExpr_] Branch: ptx_cp_async\n");
775
776
777
778
779
    std::string dst = this->PrintExpr(op->args[0]);
    std::string dst_offset = this->PrintExpr(op->args[1]);
    std::string src = this->PrintExpr(op->args[2]);
    std::string src_offset = this->PrintExpr(op->args[3]);
    std::string size = this->PrintExpr(op->args[4]);
780
781
    // use size of argument list to indicate whether or not to use predicated
    // cp.async
782
783
    if (op->args.size() == 5) {
      this->PrintIndent();
784
785
      this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
                   << dst_offset << ", " << src << "+" << src_offset << ");\n";
786
787
788
    } else {
      std::string condition = this->PrintExpr(op->args[5]);
      this->PrintIndent();
789
790
791
      this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
                   << "+" << dst_offset << ", " << src << "+" << src_offset
                   << ", " << condition << ");\n";
792
793
    }
  } else if (op->op.same_as(builtin::ptx_commit_group())) {
794
    printf("[DEBUG VisitExpr_] Branch: ptx_commit_group\n");
795
796
    print_extern_call_stmt("tl::cp_async_commit");
  } else if (op->op.same_as(builtin::ptx_wait_group())) {
797
    printf("[DEBUG VisitExpr_] Branch: ptx_wait_group\n");
798
799
800
    int n = Downcast<IntImm>(op->args[0])->value;
    std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">";
    print_extern_call_stmt(func_name, 1);
801
  } else if (op->op.same_as(builtin::create_barriers())) {
802
    printf("[DEBUG VisitExpr_] Branch: create_barriers\n");
803
804
805
806
807
808
    this->PrintIndent();
    int barrier_count = Downcast<IntImm>(op->args[0])->value;
    std::string barrier_name = "_mbarrier";
    this->stream << "__shared__ uint64_t " << barrier_name << "["
                 << barrier_count << "];\n";
  } else if (op->op.same_as(tl::get_mbarrier())) {
809
    printf("[DEBUG VisitExpr_] Branch: get_mbarrier\n");
810
811
812
813
    std::string barrier_name = "_mbarrier";
    std::string barrier_id = this->PrintExpr(op->args[0]);
    os << barrier_name + "[" + barrier_id + "]";
  } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
814
    printf("[DEBUG VisitExpr_] Branch: ptx_arrive_barrier\n");
815
816
    print_extern_call_stmt("tl::mbarrier_arrive");
  } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) {
817
    printf("[DEBUG VisitExpr_] Branch: ptx_init_barrier_thread_count\n");
818
819
    print_extern_call_stmt("tl::mbarrier_init");
  } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) {
820
    printf("[DEBUG VisitExpr_] Branch: ptx_arrive_barrier_expect_tx\n");
821
822
    print_extern_call_stmt("tl::mbarrier_arrive_expect_tx");
  } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
823
    printf("[DEBUG VisitExpr_] Branch: ptx_cp_async_barrier\n");
824
825
    print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
  } else if (op->op.same_as(tl::mbarrier_expect_tx())) {
826
    printf("[DEBUG VisitExpr_] Branch: mbarrier_expect_tx\n");
827
828
    print_extern_call_stmt("tl::mbarrier_expect_tx");
  } else if (op->op.same_as(tl::mbarrier_wait_parity())) {
829
    printf("[DEBUG VisitExpr_] Branch: mbarrier_wait_parity\n");
830
    print_extern_call_stmt("tl::mbarrier_wait");
831
  } else if (op->op.same_as(tl::ptx_stmatrix())) {
832
    printf("[DEBUG VisitExpr_] Branch: ptx_stmatrix\n");
833
834
835
    int trans = Downcast<IntImm>(op->args[0])->value;
    int num = Downcast<IntImm>(op->args[1])->value;
    std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num);
836
837
    if (trans == 1)
      func_name += "_trans";
838
    print_extern_call_stmt(func_name, 2);
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
  }else if(op->op.same_as(tl::ds_read_vector())){
    //ds_read_b64 %1, %2 offset:%3
    // ds_read_m32x16_b16 %0, %1 offset:%2
    printf("[DEBUG VisitExpr_] Branch: ds_read_vector\n");
    std::string dst = this->PrintExpr(op->args[0]);
    std::string lds_base_ptr = this->PrintExpr(op->args[1]);
    std::string m = this->PrintExpr(op->args[2]);
    std::string n = this->PrintExpr(op->args[3]);
    std::string offset = this->PrintExpr(op->args[4]);
    this->PrintIndent();
    this->stream << "tl::ds_read_vector<" << m << ", " << n  <<", " << offset << ">"
                 << "(*reinterpret_cast<float4_*>(" << dst << "), "
                 << "reinterpret_cast<uintptr_t>(" << lds_base_ptr << "));\n";
  }else if (op->op.same_as(tl::wait_wgmma())) {
    printf("[DEBUG VisitExpr_] Branch: wait_wgmma\n");
854
855
856
    this->PrintIndent();
    int num_mma = Downcast<IntImm>(op->args[0])->value;
    this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n";
857
  } else if (op->op.same_as(tl::pack_b16())) {
858
    printf("[DEBUG VisitExpr_] Branch: pack_b16\n");
859
860
    os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
       << this->PrintExpr(op->args[1]) << ")";
861
  } else if (op->op.same_as(tl::__ldg())) {
862
    printf("[DEBUG VisitExpr_] Branch: __ldg\n");
863
864
865
866
867
868
869
870
871
    // HIP fallback: regular load
    const BufferLoadNode *bl = op->args[0].as<BufferLoadNode>();
    ICHECK(bl) << "T.__ldg expects a BufferLoad as the first argument.";
    ICHECK_EQ(bl->indices.size(), 1)
        << "T.__ldg currently supports flattened 1D buffer accesses.";
    const BufferNode *buffer = bl->buffer.get();
    PrimExpr base = bl->indices[0];
    auto buffer_ref = this->GetBufferRef(op->dtype, buffer, base);
    os << buffer_ref;
872
  } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
873
    printf("[DEBUG VisitExpr_] Branch: tvm_fill_fragment\n");
874
875
876
877
878
879
880
881
882
883
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
884
    printf("[DEBUG VisitExpr_] Branch: tvm_load_matrix_sync\n");
885
886
887
888
889
890
891
892
893
894
895
896
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
897
    printf("[DEBUG VisitExpr_] Branch: tvm_store_matrix_sync\n");
898
899
900
901
902
903
904
905
906
907
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
908
    if (const StringImmNode *str = op->args[7].as<StringImmNode>()) {
909
910
911
912
913
914
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
915
    printf("[DEBUG VisitExpr_] Branch: tvm_mma_sync\n");
916
917
918
919
920
921
922
923
924
925
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
926
    printf("[DEBUG VisitExpr_] Branch: tvm_bmma_sync\n");
927
928
929
930
931
932
933
934
935
    need_mma_h_ = true;
    ICHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
936
  } else if (op->op.same_as(tl::tvm_mfma())) {
937
    printf("[DEBUG VisitExpr_] Branch: tvm_mfma\n");
938
    // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
939
940
941
942
943
944
945
946
947
948
949
950
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: float16, float32, ...
    // arg 4: B precision: float16, float32, ...
    // arg 5: C precision: float32, float64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index

951
952
    ICHECK(op->args.size() == 12U)
        << "Invalid number of arguments for tvm_mfma";
953
954
955
956
957
958
959
960
961
962
963
964
    std::string prefix = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
965
966
    ICHECK(A_layout == "row" || B_layout == "row")
        << "Matrix core only support row major";
967
968
969
970
971
    // map for dtype -> float32x4 -> float4
    std::unordered_map<std::string, std::string> dtype_map = {
        {"int8", "char"},
        {"int32", "int"},
        {"int8x4", "int32_t"},
972
        {"int8x8", "int64_t"},
973
974
975
976
977
        {"int32x4", "int32x4"},
        {"float16", "half"},
        {"float32", "float"},
        {"float64", "double"},
        {"float16x4", "float16x4"},
978
        {"bfloat16x4", "bfloat16x4_vec"},
979
        {"float32x4", "float32x4"},
980
981
        {"float8_e4m3fnuzx4", "fp8_e4_4_t"},
        {"float8_e4m3fnuzx8", "long"},
982
        {"float32x16", "float32x16"}};
983
    std::string call_mfma_code = R"({
alex_xiao's avatar
alex_xiao committed
984
985
986
987
      *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
                    *((({B_dtype}*){b_ref}) + {b_bias}),
                    *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
    })";
988
989
    std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
    Replacer replacer;
990

991
    replacer.register_rule("{mfma_buildin}", mfma_buildin);
992
993
994
    replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
    replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
    replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
995
996
997
998
999
1000
1001
    replacer.register_rule("{a_ref}", a_ref);
    replacer.register_rule("{a_bias}", a_bias);
    replacer.register_rule("{b_ref}", b_ref);
    replacer.register_rule("{b_bias}", b_bias);
    replacer.register_rule("{c_ref}", c_ref);
    replacer.register_rule("{c_bias}", c_bias);
    os << replacer.rewrite(call_mfma_code);
Lukinon's avatar
Lukinon committed
1002
  } else if (op->op.same_as(tl::tvm_mmac())) {
1003
    printf("[DEBUG VisitExpr_] Branch: tvm_mmac\n");
Lukinon's avatar
Lukinon committed
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
    // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
    // arg 1: A layout: row/col
    // arg 2: B layout: row/col
    // arg 3: A precision: float16, float32, ...
    // arg 4: B precision: float16, float32, ...
    // arg 5: C precision: float32, float64, ...
    // arg 6: A multiplicand
    // arg 7: A multiplicand index
    // arg 8: B multiplicand
    // arg 9: B multiplicand index
    // arg 10: C accumulator
    // arg 11: C accumulator index

    ICHECK(op->args.size() == 12U)
qisan's avatar
qisan committed
1018
        << "Invalid number of arguments for tvm_mmac";
Lukinon's avatar
Lukinon committed
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
    std::string prefix = Downcast<StringImm>(op->args[0])->value;
    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
    std::string a_ref = this->PrintExpr(op->args[6]);
    std::string a_bias = this->PrintExpr(op->args[7]);
    std::string b_ref = this->PrintExpr(op->args[8]);
    std::string b_bias = this->PrintExpr(op->args[9]);
    std::string c_ref = this->PrintExpr(op->args[10]);
    std::string c_bias = this->PrintExpr(op->args[11]);
    ICHECK(A_layout == "row" || B_layout == "row")
        << "Matrix core only support row major";
    // map for dtype -> float32x4 -> float4
    std::unordered_map<std::string, std::string> dtype_map = {
        {"int8", "char"},
        {"int32", "int"},
        {"int8x4", "int32_t"},
        {"int8x8", "int64_t"},
        {"int32x4", "int32x4"},
        {"float16", "half"},
        {"float32", "float"},
        {"float64", "double"},
        {"float16x4", "float16x4"},
        {"bfloat16x4", "bfloat16x4"},
        {"float32x4", "float32x4"},
        {"float8_e4m3fnuzx4", "fp8_e4_4_t"},
        {"float8_e4m3fnuzx8", "long"},
        {"float32x16", "float32x16"}};
    std::string call_mmac_code = R"({
    *((({C_dtype}*){c_ref}) + {c_bias}) = {mmac_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
                  *((({B_dtype}*){b_ref}) + {b_bias}),
                  *((({C_dtype}*){c_ref}) + {c_bias}));
  })";
    std::string mmac_buildin = "__builtin_amdgcn_mmac_" + prefix;
    Replacer replacer;

    replacer.register_rule("{mmac_buildin}", mmac_buildin);
    replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
    replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
    replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
    replacer.register_rule("{a_ref}", a_ref);
    replacer.register_rule("{a_bias}", a_bias);
    replacer.register_rule("{b_ref}", b_ref);
    replacer.register_rule("{b_bias}", b_bias);
    replacer.register_rule("{c_ref}", c_ref);
    replacer.register_rule("{c_bias}", c_bias);
    os << replacer.rewrite(call_mmac_code);
1068
  } else if (op->op.same_as(builtin::thread_return())) {
1069
    printf("[DEBUG VisitExpr_] Branch: thread_return\n");
1070
1071
    os << "return";
  } else if (op->op.same_as(tl::tl_gemm())) {
1072
    printf("[DEBUG VisitExpr_] Branch: tl_gemm\n");
1073
1074
1075
1076
    ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
                                    "A_ptr, B_ptr, C_ptr>, but got "
                                 << op->args.size();
    auto op_instance = Downcast<StringImm>(op->args[0]);
1077
1078
    this->PrintCallExtern(GetType(tvm::ffi::GetRef<PrimExpr>(op)),
                          op_instance->value, op->args, true, os);
1079
  } else if (op->op.same_as(tl::tl_gemm_sp())) {
1080
    printf("[DEBUG VisitExpr_] Branch: tl_gemm_sp\n");
1081
    LOG(FATAL) << "tl_gemm_sp is not supported on HIP";
alex_xiao's avatar
alex_xiao committed
1082
  } else if (op->op.same_as(tl::loop_break())) {
1083
    printf("[DEBUG VisitExpr_] Branch: loop_break\n");
alex_xiao's avatar
alex_xiao committed
1084
1085
1086
    this->PrintIndent();
    this->stream << "break;\n";
  } else if (op->op.same_as(tl::no_set_max_nreg())) {
1087
    printf("[DEBUG VisitExpr_] Branch: no_set_max_nreg\n");
alex_xiao's avatar
alex_xiao committed
1088
1089
1090
    // HIP doesn't need explicit register management like CUDA
    // This is a no-op for HIP
    return;
1091
  } else {
1092
    printf("[DEBUG VisitExpr_] Branch: CodeGenC::VisitExpr_ (fallback)\n");
1093
1094
1095
1096
    CodeGenC::VisitExpr_(op, os);
  }
}

1097
void CodeGenTileLangHIP::VisitStmt_(const AttrStmtNode *op) {
1098
  if (op->attr_key == tir::attr::async_commit_queue_scope) {
1099
1100
1101
    const IntImmNode *queue_id = op->value.as<IntImmNode>();
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
1102
1103
1104
1105
1106
1107
1108
    this->VisitStmt(op->body);
    auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
    this->VisitExpr(commit_group, this->stream);
    return;
  } else if (op->attr_key == tir::attr::async_wait_queue_scope) {
    auto wait_attrs = GetAsyncWaitAttributes(op);
    auto queue_id = wait_attrs.first.as<IntImmNode>();
1109
1110
    ICHECK(queue_id && queue_id->value == 0)
        << "For CUDA, the index of an async queue must be 0.";
1111
    auto wait_cnt = wait_attrs.second;
1112
1113
    auto wait_group =
        Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
1114
1115
1116
1117
1118
1119
1120
    this->VisitExpr(wait_group, this->stream);
    auto inner = op->body.as<AttrStmtNode>();
    ICHECK(inner);
    this->VisitStmt(inner->body);
    return;
  } else if (op->attr_key == "threadblock_swizzle_pattern") {
    this->PrintIndent();
1121
    const StringImmNode *pattern = op->value.as<StringImmNode>();
1122
1123
1124
1125
1126
1127
1128
1129
    ICHECK(pattern);
    this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
    this->VisitStmt(op->body);
    return;
  }
  CodeGenC::VisitStmt_(op);
}

1130
void CodeGenTileLangHIP::VisitStmt_(const AllocateNode *op) {
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
  ICHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());

  this->PrintIndent();
  std::string scope = GetPtrStorageScope(op->buffer_var);
  PrintStorageScope(scope, stream);
  PrintType(op->dtype, stream);

  if (scope == "shared.dyn") {
    stream << ' ' << vid << "[];\n";
  } else {
    size_t constant_size = op->ConstantAllocationSize();
1143
1144
    ICHECK_GT(constant_size, 0)
        << "Can only handle constant size stack allocation for now";
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157

    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
         op->dtype == DataType::Int(1)) &&
        scope == "shared") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
    stream << ' ' << vid << '[' << constant_size << "];\n";
  }

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

1158
void CodeGenTileLangHIP::VisitExpr_(const RampNode *op, std::ostream &os) {
1159
1160
1161
1162
1163
1164
1165
1166
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
  CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
  os << "(make_";
  PrintType(op->dtype, os);
  os << "(";
  for (int i = 0; i < lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
1167
1168
    if (i != lanes - 1)
      os << ", ";
1169
1170
1171
1172
  }
  os << "))";
}

1173
1174
void CodeGenTileLangHIP::VisitExpr_(const BroadcastNode *op,
                                    std::ostream &os) { // NOLINT(*)
1175
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
1176
1177
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
      lanes == 4) {
1178
    // make_int8x4
1179
    const int64_t *p = as_const_int(op->value);
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
    ICHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
    if (op->dtype.is_uint()) {
      os << "(uint)" << v;
    } else {
      os << "(int)" << v;
    }
    return;
  }

  if (op->dtype.is_float16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1197
1198
      if (i != 0)
        os << ", ";
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if (op->dtype.is_bfloat16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < lanes / 2; ++i) {
1211
1212
      if (i != 0)
        os << ", ";
1213
      os << "__pack_bfloat162(" << v << ", " << v << ")";
1214
1215
1216
1217
1218
    }
    os << ')';
    return;
  }

1219
1220
  if (op->dtype.is_float() && op->dtype.bits() == 32 &&
      op->dtype.lanes() == 8) {
1221
1222
1223
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
1224
1225
      if (i != 0)
        os << ", ";
1226
1227
1228
1229
1230
1231
1232
1233
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
1234
    const int64_t *p = as_const_int(op->value);
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
    ICHECK(p);
    int64_t v = *p & 0xF;

    if (lanes == 4) {
      v = (v << 12) | (v << 8) | (v << 4) | v;
      if (op->dtype.is_uint()) {
        os << "(uint16_t)" << v;
      } else {
        os << "(int16_t)" << v;
      }
    } else {
1246
1247
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) |
          (v << 4) | v;
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
        } else {
          os << "(int)" << v;
        }
      } else if (lanes == 16 || lanes == 32) {
        os << "make_";
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < lanes / 8; ++i) {
1259
1260
          if (i != 0)
            os << ", ";
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
            os << "(int)" << v;
          }
        }
        os << ')';
      } else {
        fail = true;
      }
    }

    if (!fail) {
      return;
    }
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
1283
1284
    if (i != 0)
      os << ", ";
1285
1286
1287
1288
1289
    os << v;
  }
  os << ')';
}

1290
1291
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangHIP *p) { // NOLINT(*)
1292
1293
1294
1295
1296
  // Type code is kBFloat
  if (op->dtype.is_bfloat16()) {
    os << "bfloat16_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
alex_xiao's avatar
alex_xiao committed
1297
1298
  } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
             op->dtype.is_float8_e4m3fn()) {
1299
1300
1301
    os << "fp8_e4_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
1302
1303
1304
  }
  // Type code is kFloat
  switch (op->dtype.bits()) {
1305
1306
1307
1308
1309
1310
  case 64:
  case 32: {
    std::ostringstream temp;
    if (std::isinf(op->value)) {
      if (op->value < 0) {
        temp << "-";
1311
      }
1312
      temp << ((op->dtype.bits() == 32) ? "HUGE_VALF" : "HUGE_VAL");
1313
    } else if (std::isnan(op->value)) {
1314
      temp << ((op->dtype.bits() == 32) ? "NAN" : "NAN");
1315
1316
1317
1318
    } else {
      temp << std::scientific << op->value;
      if (op->dtype.bits() == 32)
        temp << 'f';
1319
    }
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
    p->MarkConst(temp.str());
    os << temp.str();
    break;
  }
  case 16: {
    os << "half_t" << '(';
    FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
    PrintConst(const_f32.get(), os, p);
    os << ')';
    break;
  }
  default:
    LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
1333
1334
1335
  }
}

1336
1337
void CodeGenTileLangHIP::VisitExpr_(const FloatImmNode *op,
                                    std::ostream &os) { // NOLINT(*)
1338
1339
1340
  PrintConst(op, os, this);
}

1341
1342
1343
void CodeGenTileLangHIP::HandleVolatileLoads(const std::string &value,
                                             const BufferLoadNode *op,
                                             std::ostream &os) {
1344
1345
1346
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
1347
1348
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
1349
1350
1351
1352
1353
1354
1355
1356
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

1357
1358
1359
/*!
 * \brief Emit lane `i` of a vector built element by element.
 *
 * Called once per lane, in order. The fragments concatenate into one
 * expression:
 *  - 8-bit integer vectors (except 2 or 3 lanes): OR-ed, byte-shifted terms.
 *  - float16 / bfloat16: make_<vectype>(__pack_*(v0,v1),...); even lanes open
 *    a pack call and odd lanes close it.
 *  - otherwise: make_<vectype>(v0,v1,...).
 */
void CodeGenTileLangHIP::PrintVecElemLoadExpr(DataType t, int i,
                                              const std::string &value,
                                              std::ostream &os) {
  ICHECK_GT(t.lanes(), 1);
  const bool is_last_lane = (i == t.lanes() - 1);

  const bool byte_integer = t.bits() == 8 && (t.is_int() || t.is_uint());
  if (byte_integer && t.lanes() != 2 && t.lanes() != 3) {
    const int shift = i * 8;
    if (i != 0) {
      os << "|";
    }
    os << "((0x000000ff << " << shift << ") & (" << value << " << " << shift
       << "))";
    return;
  }

  const char *pack_fn = nullptr;
  if (t.is_float16()) {
    pack_fn = "__pack_half2(";
  } else if (t.is_bfloat16()) {
    pack_fn = "__pack_bfloat162(";
  }

  // Lane 0 opens the make_<vectype>(...) constructor.
  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << '(';
  }

  if (pack_fn != nullptr) {
    if (i % 2 == 0) {
      os << pack_fn << value;
      return;
    }
    os << "," << value << ")";
  } else {
    os << value;
  }
  os << (is_last_lane ? ")" : ",");
}

1424
void CodeGenTileLangHIP::AddFunction(const PrimFunc &f) {
1425
1426
1427
1428
1429
1430
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
1431
  ICHECK(global_symbol.has_value())
1432
1433
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
1434
1435
1436
1437
1438
1439
  std::unordered_set<const VarNode *> non_restrict;
  if (auto opt =
          f->GetAttr<ffi::Array<tir::Var>>(tl::attr::kNonRestrictParams)) {
    for (const tir::Var &v : opt.value())
      non_restrict.insert(v.get());
  }
1440
1441
1442
1443
1444
1445
1446
1447

  this->PrintFuncPrefix(stream);
  CodeGenC::PrintType(f->ret_type, stream);
  this->PrintExtraAttrs(f, stream);
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
1448
1449
    if (i != 0)
      stream << ", ";
1450
1451
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
1452
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
          stream << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      CodeGenC::PrintType(GetType(v), stream);
1467
1468
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
1469
1470
1471
1472
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

1473
      if (no_alias && !non_restrict.count(v.get())) {
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
        PrintRestrict(v, stream);
      }
    } else {
      CodeGenC::PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";
  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}

1490
1491
} // namespace codegen
} // namespace tvm