/*!
 * \file target/codegen.cc
 */
#include "codegen_cuda.h"

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include <tvm/arith/analyzer.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/stmt_functor.h>

#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "arith/pattern_match.h"
#include "target/source/ptx.h"

namespace tvm {
namespace codegen {

static std::string GetFP8Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "_2";
  } else if (lanes == 4) {
    vec = "_4";
  } else if (lanes == 8) {
    vec = "_8";
  } else if (lanes == 16) {
    vec = "_16";
  } else {
    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
                  "for FP8";
  }
  if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() ||
      type.is_float8_e4m3()) {
    stream << "fp8_e4" << vec << "_t";
  } else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz()) {
    stream << "fp8_e5" << vec << "_t";
  } else {
    LOG(FATAL) << "Unsupported FP8 type " << type << " in CUDA codegen";
  }
  return stream.str();
}

std::string GetFP6Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "x2";
  } else if (lanes == 4) {
    vec = "x4";
  } else if (lanes == 8) {
    vec = "x8";
  } else if (lanes == 16) {
    vec = "x16";
  } else {
    LOG(FATAL)
        << "Only support scalar and vector types of width (2, 4, 8, 16) for FP6";
  }
  stream << "__nv_fp6";
  std::string suffix;
  if (type.code() == DataType::kFloat6_e2m3fn) {
    suffix = "_e2m3";
  } else if (type.code() == DataType::kFloat6_e3m2fn) {
    suffix = "_e3m2";
  } else {
    LOG(FATAL) << "Unsupported FP6 type in CUDA codegen";
  }
  stream << vec << suffix;
  return stream.str();
}

std::string GetFP4Type(DataType type) {
  std::stringstream stream;
  int32_t lanes = type.lanes();
  std::string vec;
  if (type.is_scalar()) {
    vec = "";
  } else if (lanes == 2) {
    vec = "x2";
  } else if (lanes == 4) {
    vec = "x4";
  } else if (lanes == 8) {
    vec = "x8";
  } else if (lanes == 16) {
    vec = "x16";
  } else {
    LOG(FATAL)
        << "Only support scalar and vector types of width (2, 4, 8, 16) for FP4";
  }
  stream << "__nv_fp4";
  std::string suffix;
  if (type.code() == DataType::kFloat4_e2m1fn) {
    suffix = "_e2m1";
  } else {
    LOG(FATAL) << "Unsupported FP4 type in CUDA codegen";
  }
  stream << vec << suffix;
  return stream.str();
}

CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
  restrict_keyword_ = "__restrict__";
  vid_global_barrier_state_ =
      name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
  ICHECK_EQ(vid_global_barrier_state_,
            runtime::symbol::tvm_global_barrier_state);
}

void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) {
  os << "extern \"C\" __global__ ";
}

class LaunchConfigExtractor : public tir::StmtVisitor {
 private:
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->var->name_hint == "threadIdx.x" ||
          iv->thread_tag == "threadIdx.x") {
        threadIdx_x_ext = op->value;
      } else if (iv->var->name_hint == "threadIdx.y" ||
                 iv->thread_tag == "threadIdx.y") {
        threadIdx_y_ext = op->value;
      } else if (iv->var->name_hint == "threadIdx.z" ||
                 iv->thread_tag == "threadIdx.z") {
        threadIdx_z_ext = op->value;
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

 public:
  PrimExpr threadIdx_x_ext = Integer(1);
  PrimExpr threadIdx_y_ext = Integer(1);
  PrimExpr threadIdx_z_ext = Integer(1);
};

void CodeGenTileLangCUDA::PrintExtraAttrs(const PrimFunc &f) {
  LaunchConfigExtractor extractor;
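  // Illustrative note (example values, not taken from this file): if the
  // kernel binds threadIdx.x = 128 and threadIdx.y = 2, the simplified product
  // computed below is 256 and the signature gains " __launch_bounds__(256, 1)";
  // if the extents do not fold to a constant, nothing is printed.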
extractor(f->body); arith::Analyzer analyzer; PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * extractor.threadIdx_z_ext); if (const IntImmNode *const threadIdx_ext_int = threadIdx_ext.as()) { if (threadIdx_ext_int->value == 1) { // unable to extract the number of threads per block, hence directly // return return; } stream << " __launch_bounds__(" << threadIdx_ext_int->value << ", 1)"; } } std::string CodeGenTileLangCUDA::Finish() { if (need_mma_h_) { decl_stream << "#include \n"; } if (enable_fp8_) { decl_stream << "#include \n"; } if (need_math_constants_h_) { decl_stream << "#include \n"; } if (need_cooperative_groups_) { decl_stream << "#include \n"; } decl_stream << "#include \n"; if (enable_sparse_gemm_) { decl_stream << "#include \n"; } decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#ifdef ENABLE_BF16\n"; decl_stream << "#include \n"; decl_stream << "#endif\n"; if (need_global_barrier_) { decl_stream << "__device__ unsigned " << vid_global_barrier_state_ << " = 0;\n"; } decl_stream << "\n"; return CodeGenC::Finish(); } void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) { if (op->kind == tir::ForKind::kUnrolled) { PrintIndent(); stream << "#pragma unroll\n"; } std::string extent = PrintExpr(arith::Analyzer().Simplify(op->extent + op->min)); PrintIndent(); std::string vid = AllocVarID(op->loop_var.get()); std::string start = PrintExpr(op->min); stream << "for ("; PrintType(op->loop_var.dtype(), stream); stream << ' ' << vid << " = " << start << "; " << vid << " < " << extent << "; ++" << vid << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); PrintIndent(); stream << "}\n"; } void CodeGenTileLangCUDA::BindThreadIndex(const IterVar &iv) { ICHECK(!var_idmap_.count(iv->var.get())); var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); } void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { ICHECK(t.is_scalar()) << "do not yet support vector types"; os << "void*"; return; } if (t.is_void()) { os << "void"; return; } if (t == tl::cuTensorMapType()) { os << "CUtensorMap"; return; } bool fail = false; if (t.is_float()) { switch (t.bits()) { case 16: enable_fp16_ = true; if (t.is_scalar()) { os << "half_t"; } else if (lanes <= 8) { // Emit CUDA code to access fp16 vector elements. // // half4 is stored as uint2 // // h4.x is emitted as *(half2*)(&(u2.x)).x // h4.y is emitted as *(half2*)(&(u2.x)).y // h4.z is emitted as *(half2*)(&(u2.y)).x // h4.w is emitted as *(half2*)(&(u2.y)).y // ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; os << "uint" << lanes / 2; } else { fail = true; } break; case 32: if (lanes <= 4) { os << "float"; } else if (lanes <= 8) { // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8. 
// // float8 is stored as ulonglong4 // // f8.v1 is emitted as *(float2*)(&(ul4.x)).x // f8.v2 is emitted as *(float2*)(&(ul4.x)).y // ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4"; os << "ulonglong" << lanes / 2; } else { fail = true; } break; case 64: os << "double"; break; default: fail = true; break; } if (!fail && (t.is_scalar() || t.bits() == 16)) return; if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return; if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } } else if (t.is_bfloat16()) { enable_bf16_ = true; if (t.is_scalar()) { os << "bfloat16_t"; } else if (lanes <= 8) { ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; os << "uint" << lanes / 2; } else { fail = true; } if (!fail) return; } else if (t.is_float8()) { enable_fp8_ = true; os << GetFP8Type(t); return; } else if (t.is_float6()) { enable_fp6_ = true; if (t.lanes() <= 4) { os << GetFP6Type(t); } else { fail = true; } return; } else if (t.is_float4()) { enable_fp4_ = true; if (t.lanes() <= 4) { os << GetFP4Type(t); } else { fail = true; } return; } else if (t == DataType::Bool()) { os << "bool"; return; } else if (t.is_vector_bool()) { // CUDA does not support bool vectors. // Use ushort vectors to represent instead. int n = t.lanes(); if (n <= 4) { os << "ushort" << n; return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { os << "u"; } switch (t.bits()) { case 1: { if (t.is_scalar()) { os << "int"; return; } else if (t.lanes() == 8) { os << "int8_t"; return; } else if (t.lanes() == 16) { os << "int16_t"; return; } else if (t.lanes() == 32) { os << "int"; return; } else { LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; } } case 4: { if (t.is_scalar()) { os << "int"; return; } else if (t.lanes() == 4) { os << "int16_t"; return; } else if (t.lanes() == 8) { // directly 8 4-bit int in integer. os << "int"; return; } else if (t.lanes() == 16) { os << "int2"; return; } else if (t.lanes() == 32) { os << "int4"; return; } else if (t.lanes() == 64) { os << "int8"; return; } else { LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; } } case 8: { if (t.lanes() == 4) { // directly 4 8 bit int in integer. enable_int8_ = true; // We use int for int8x4 instead of char4 because using char4 is // likely to produce extra instructions to pack four int8 elements // into 32-bit data. os << "int"; return; } else if (t.lanes() == 8) { enable_int8_ = true; os << "int2"; return; } else if (t.lanes() == 16) { enable_int8_ = true; os << "int4"; return; } else if (!t.is_uint() && t.is_scalar()) { os << "signed char"; break; } else { os << "char"; break; } } case 16: { if (t.is_scalar()) { os << "short"; } else if (t.lanes() <= 4) { os << "short" << lanes; } else if (t.lanes() <= 8) { // Emit CUDA code to access int16 vector elements. // // short4 is stored as int2 // // s4.x is emitted as *(short2*)(&(i2.x)).x // s4.y is emitted as *(short2*)(&(i2.x)).y // s4.z is emitted as *(short2*)(&(i2.y)).x // s4.w is emitted as *(short2*)(&(i2.y)).y // ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for shorT type with lanes > 4"; os << "int" << t.lanes() / 2; } else { fail = true; } if (!fail) { return; } break; } case 32: { if (t.is_scalar()) { os << "int"; } else if (t.lanes() <= 4) { os << "int" << t.lanes(); } else if (t.lanes() <= 8) { // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. 
// // int8 is stored as longlong4 // // i8.v1 is emitted as *(int2*)(&(l4.x)).x // i8.v2 is emitted as *(int2*)(&(l4.x)).y // ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4"; os << "longlong" << lanes / 2; } else { fail = true; } if (!fail) { return; } break; } case 64: { if (t.is_scalar()) { os << "int64_t"; } else if (t.lanes() == 2) { os << "longlong2"; } else if (t.lanes() == 3) { os << "longlong3"; } else if (t.lanes() == 4) { os << "longlong4"; } return; } default: fail = true; break; } if (!fail && lanes == 1) { return; } if (!fail && (lanes >= 2 && lanes <= 4)) { os << lanes; return; } } LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } void CodeGenTileLangCUDA::PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream &os) { // NOLINT(*) // Declare the result. std::string sret = name_supply_->FreshName("_"); this->PrintIndent(); this->PrintType(t, stream); stream << ' ' << sret << ";\n"; int ssa_scope = BeginScope(); { // Unpack into individual ops. std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); for (int i = 0, lanes = t.lanes(); i < lanes; ++i) { std::ostringstream value_temp; if (isalpha(op[0])) { value_temp << op << "("; PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); value_temp << ", "; PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); value_temp << ")"; } else { value_temp << "("; PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp); value_temp << op; PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp); value_temp << ")"; } PrintVecElemStore(sret, t, i, value_temp.str()); } } EndScope(ssa_scope); os << sret; } void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t, int i, std::ostream &os) { // NOLINT(*) if (t.is_scalar()) { os << vec; return; } static const char access[] = {'x', 'y', 'z', 'w'}; ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { std::string type_name = t.is_int() ? "char" : "unsigned char"; if (t.lanes() == 2 || t.lanes() == 3) { os << vec << "." << access[i % t.lanes()]; } else { std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))"; } } else if (t.is_float16()) { os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else if (t.is_bfloat16()) { os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { if (t.is_int()) { type_name = "short"; } else if (t.is_uint()) { type_name = "ushort"; } } else if (t.bits() == 32) { if (t.is_int()) { type_name = "int"; } else if (t.is_uint()) { type_name = "uint"; } else if (t.is_float()) { type_name = "float"; } } ICHECK(!type_name.empty()); os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else { os << vec << "." << access[i]; } } void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t, int i, const std::string &value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (t.lanes() == 2 || t.lanes() == 3) { stream << vec << '.' 
<< access[i % t.lanes()] << "=" << "(" << value << ");\n"; } else { std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]); stream << ac << "="; // Do not read the first undef lane. if (i != 0) { stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |"; } stream << "(" << value << " << " << i % 4 * 8 << ");\n"; } } else if (t.is_float16()) { stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } else if (t.is_bfloat16()) { stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } else if (t.lanes() > 4 && t.lanes() <= 8) { std::string type_name; if (t.bits() == 16) { if (t.is_int()) { type_name = "short"; } else if (t.is_uint()) { type_name = "ushort"; } } else if (t.bits() == 32) { if (t.is_int()) { type_name = "int"; } else if (t.is_uint()) { type_name = "uint"; } else if (t.is_float()) { type_name = "float"; } } ICHECK(!type_name.empty()); stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " << value << ";\n"; } else { stream << vec << "." << access[i] << " = " << value << ";\n"; } } void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) { auto args = op->args; const std::string &sync = args[0].as()->value; if (sync == "warp") { // DO nothing. } else if (sync == "shared" || sync == "shared.dyn") { this->PrintIndent(); if (args.size() == 1) { this->stream << "__syncthreads();\n"; } else if (args.size() == 2) { auto barrier_id = args[1].as()->value; this->stream << "tl::__sync_thread_partial<" << barrier_id << ">();\n"; } else if (args.size() == 3) { auto barrier_id = args[1].as()->value; auto thread_count = args[2].as()->value; this->stream << "tl::__sync_thread_partial<" << barrier_id << ", " << thread_count << ">();\n"; } else { LOG(FATAL) << "Invalid number of arguments for storage sync: " << args.size(); } } else if (sync == "global") { if (!need_global_barrier_) { need_global_barrier_ = true; } // global synchronizer std::string is_load = PrintExpr(op->args[1]); std::string num_blocks = PrintExpr(op->args[2]); this->PrintIndent(); // In theory only threadfence is needed // but we observed problems with only threadfence this->stream << "__threadfence_system();\n"; this->PrintIndent(); this->stream << "if (" << is_load << ") {\n"; int wb = this->BeginScope(); this->PrintIndent(); this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n"; this->PrintIndent(); std::string ptr = name_supply_->FreshName("pf"); this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n"; this->PrintIndent(); this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n"; this->PrintIndent(); this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; this->EndScope(wb); this->PrintIndent(); this->stream << "}\n"; this->PrintIndent(); this->stream << "__syncthreads();\n"; } } void CodeGenTileLangCUDA::PrintStorageScope(const std::string &scope, std::ostream &os) { // NOLINT(*) ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. 
You must pass " "all global arrays as input instead"; if (scope == "shared" || scope == "shared.barrier") { os << "__shared__ "; } else if (scope == "shared.dyn") { os << "extern __shared__ __align__(1024) "; } } std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from, DataType target) { if (from == target) return value; std::ostringstream os; os << "(("; this->PrintType(target, os); os << ")"; if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { os << "("; if (target.is_uint()) { os << "u"; } os << "int)"; } os << value << ")"; return os.str(); } void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { DataType from_ty = op->value.dtype(); DataType target_ty = op->dtype; ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); // Emit simple C-style type conversion. if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); // We could emit make_float4 like calls, but the emitted code looks // too compact to read. Emit this as vectorized unary ops. std::string sret = name_supply_->FreshName("_"); this->PrintIndent(); this->PrintType(target_ty, stream); stream << ' ' << sret << ";\n"; std::string src = SSAGetID(PrintExpr(op->value), from_ty); // Handle bfloat16 special cases with supported ops bool used_bf16_op = false; if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) { std::ostringstream func_name; if (from_ty.is_bfloat16()) func_name << "bf16"; else if (from_ty.is_float()) func_name << "float"; if (from_ty.lanes() > 1) func_name << from_ty.lanes(); func_name << "2"; if (target_ty.is_bfloat16()) func_name << "bf16"; else if (target_ty.is_float()) func_name << "float"; else if (target_ty == DataType::Int(16)) func_name << "int16"; if (target_ty.lanes() > 1) func_name << target_ty.lanes(); auto fname = func_name.str(); if (bf16_supported_ops_.count(fname)) { used_bf16_op = true; stream << "#ifdef ENABLE_BF16\n"; PrintIndent(); stream << "reinterpret_cast<"; if (target_ty.is_bfloat16()) stream << "__nv_bfloat16"; else PrintType(target_ty.element_of(), stream); if (target_ty.lanes() > 1) stream << target_ty.lanes(); stream << " &>(" << sret << ") = fastertransformer::" << fname << "(reinterpret_cast<"; if (from_ty.is_bfloat16()) stream << "__nv_bfloat16"; else PrintType(from_ty.element_of(), stream); if (from_ty.lanes() > 1) stream << from_ty.lanes(); stream << " const &>(" << src << "));\n"; stream << "#else\n"; } } // Fallback: elementwise cast for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; val << "("; PrintType(target_ty.element_of(), val); val << ")("; PrintVecElemLoad(src, from_ty, i, val); val << ")"; PrintVecElemStore(sret, target_ty, i, val.str()); } if (used_bf16_op) { stream << "#endif\n"; } os << sret; } void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array &args, bool skip_first_arg, std::ostream &os) { // NOLINT(*) DataType ret_dtype = GetRuntimeDataType(ret_type); if (ret_dtype.is_fixed_length_vector()) { // // Emit an unsupported vector call // // v = intrin_f((float4*)A[0], (float4*)B[0]) // // as // // float4 __ret; // { // float4 __arg0 = ((float4*)A)[0]; // float4 __arg1 = ((float4*)B)[0]; // __ret.x = intrin_f(__arg0.x, __arg1.x); // __ret.y = intrin_f(__arg0.y, __arg1.y); // __ret.z = intrin_f(__arg0.z, __arg1.z); // __ret.w = intrin_f(__arg0.w, __arg1.w); // } // v = __ret; // // Declare the result vector. 
std::string sret = name_supply_->FreshName("_"); this->PrintIndent(); this->PrintType(ret_dtype, stream); stream << ' ' << sret << ";\n"; { // Load arguments. std::vector sargs; size_t arg_begin = static_cast(skip_first_arg); for (size_t i = arg_begin; i < args.size(); ++i) { std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype()); sargs.push_back(std::move(val)); } // Emit a scalar call for each lane. for (int i = 0; i < ret_dtype.lanes(); ++i) { std::ostringstream scall; scall << global_symbol << "("; for (size_t j = 0; j < sargs.size(); ++j) { if (j > 0) scall << ", "; PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall); } scall << ")"; PrintVecElemStore(sret, ret_dtype, i, scall.str()); } } os << sret; } else { CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os); } } // Print a reference expression to a buffer. std::string CodeGenTileLangCUDA::GetBufferRef(DataType t, const BufferNode *buffer, PrimExpr index) { const VarNode *buffer_var = buffer->data.get(); std::ostringstream os; std::string vid = GetVarID(buffer_var); std::string scope; if (alloc_storage_scope_.count(buffer_var)) { scope = alloc_storage_scope_.at(buffer_var); } // bool is_vol = IsVolatile(buffer_var); // always false for tl cutlass backend. bool is_vol = false; auto ptr_cast = [this, is_vol, scope](DataType pointed_to) { std::ostringstream ptr_os; ptr_os << "("; if (is_vol) { ptr_os << "volatile "; } if (!scope.empty() && IsScopePartOfType()) { PrintStorageScope(scope, ptr_os); } PrintType(pointed_to, ptr_os); ptr_os << "*)"; return ptr_os.str(); }; DataType buffer_element_dtype = buffer->dtype; std::string buffer_str = vid; if (!HandleTypeMatch(buffer_var, buffer_element_dtype) || is_vol) { std::stringstream temp; temp << "(" << ptr_cast(buffer_element_dtype) << vid << ")"; buffer_str = temp.str(); } if (scope.empty()) { scope = GetPtrStorageScope(buffer->data); } if (scope == "local.var") { os << vid; return os.str(); } std::string index_str = PrintExpr(index); if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { // This is a special case, because CodegenCUDA::PrintType() // returns "int" for bool and for 4-bit integers. In most cases, // we divide by the number of lanes to determine the index. // However, the backing type for scalar int4 and scalar bool is // int32. Therefore, we need to divide by the ratio of their // sizes in that case. int div_factor = (t.lanes() == 1) ? (32 / t.bits()) : t.lanes(); os << "*(" << "(" << ptr_cast(t) << vid << ")" << " + " << index_str << " / " << div_factor << ")"; } else if (t == buffer_element_dtype) { os << buffer_str << "[" << index_str << "]"; } else { os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")"; } return os.str(); } void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { auto print_extern_call_stmt = [&](std::string name, size_t start = 0, size_t end = 0) { // Cache context into a private ss, otherwise the let node may generate // within the function call arguments. 
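    // For illustration (assumed call, not a real op in this file): with
    // name = "tl::foo", start = 1, end = 0 and args = {op_name, a, b}, the
    // lambda below prints the single statement "tl::foo(a, b);".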
std::ostringstream ss; for (size_t i = start; i < op->args.size() - end; i++) { if (i > start) ss << ", "; ss << this->PrintExpr(op->args[i]); } this->PrintIndent(); this->stream << name << "("; this->stream << ss.str(); this->stream << ");\n"; }; auto print_mbarrier_obj = [&](PrimExpr barrier_id) { std::ostringstream ss; if (barrier_id.as()) { // incase the barrier_id is an integer, we need to print the barrier_id as // an integer ss << mbarrier_name_ << "[" << barrier_id << "]"; } else { // otherwise may be a T.get_mbarrier() call or BufferLoad Node // we need to print the barrier_id as a string ss << this->PrintExpr(barrier_id); } return ss.str(); }; if (op->op.same_as(builtin::ptx_cp_async())) { std::string dst = this->PrintExpr(op->args[0]); std::string dst_offset = this->PrintExpr(op->args[1]); std::string src = this->PrintExpr(op->args[2]); std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); // use size of argument list to indicate whether or not to use predicated // cp.async if (op->args.size() == 5) { this->PrintIndent(); this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" << dst_offset << ", " << src << "+" << src_offset << ");\n"; } else { std::string condition = this->PrintExpr(op->args[5]); this->PrintIndent(); this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst << "+" << dst_offset << ", " << src << "+" << src_offset << ", " << condition << ");\n"; } } else if (op->op.same_as(builtin::ptx_commit_group())) { print_extern_call_stmt("tl::cp_async_commit"); } else if (op->op.same_as(builtin::ptx_wait_group())) { int n = Downcast(op->args[0])->value; std::string func_name = "tl::cp_async_wait<" + std::to_string(n) + ">"; print_extern_call_stmt(func_name, 1); } else if (op->op.same_as(builtin::create_barriers())) { this->PrintIndent(); int barrier_count = Downcast(op->args[0])->value; auto mbarrier_storage_name = mbarrier_name_ + "_mem"; this->stream << "__shared__ uint64_t " << mbarrier_storage_name << "[" << barrier_count << "];\n"; this->PrintIndent(); this->stream << "auto " << mbarrier_name_ << " = reinterpret_cast<" << mbarrier_dtype_ << "*>(" << mbarrier_storage_name << ");\n"; } else if (op->op.same_as(tl::get_mbarrier())) { ICHECK_EQ(op->args.size(), 1); std::string barrier_id = this->PrintExpr(op->args[0]); os << mbarrier_name_ + "[" + barrier_id + "]"; } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { if (op->args.size() == 1) { this->PrintIndent(); auto mbarrier_obj = print_mbarrier_obj(op->args[0]); this->stream << mbarrier_obj << ".arrive();\n"; } else if (op->args.size() == 3) { this->PrintIndent(); auto mbarrier_obj = print_mbarrier_obj(op->args[0]); auto cta_id = this->PrintExpr(op->args[1]); auto pred = this->PrintExpr(op->args[2]); this->stream << mbarrier_obj << ".arrive(" << cta_id << ", " << pred << ");\n"; } else { LOG(FATAL) << "Invalid parameter for tl::arrive_barrier " << op->args.size(); } } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { ICHECK_EQ(op->args.size(), 2); this->PrintIndent(); auto mbarrier_obj = print_mbarrier_obj(op->args[0]); auto arrive_count = this->PrintExpr(op->args[1]); this->stream << mbarrier_obj << ".init(" << arrive_count << ");\n"; } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { if (op->args.size() == 2) { this->PrintIndent(); auto mbarrier_obj = print_mbarrier_obj(op->args[0]); auto transaction_bytes = this->PrintExpr(op->args[1]); this->stream << mbarrier_obj << ".arrive_and_expect_tx(" << 
transaction_bytes << ");\n"; } else if (op->args.size() == 4) { this->PrintIndent(); auto mbarrier_obj = print_mbarrier_obj(op->args[0]); auto transaction_bytes = this->PrintExpr(op->args[1]); auto cta_id = this->PrintExpr(op->args[2]); auto pred = this->PrintExpr(op->args[3]); this->stream << mbarrier_obj << ".arrive_and_expect_tx(" << transaction_bytes << ", " << cta_id << ", " << pred << ");\n"; } else { LOG(FATAL) << "Invalid parameter for tl::arrive_barrier_expect_tx " << op->args.size(); } } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { print_extern_call_stmt("tl::mbarrier_cp_async_arrive"); } else if (op->op.same_as(tl::mbarrier_expect_tx())) { ICHECK_EQ(op->args.size(), 2); this->PrintIndent(); auto mbarrier_obj = print_mbarrier_obj(op->args[0]); auto transaction_bytes = this->PrintExpr(op->args[1]); this->stream << mbarrier_obj << ".expect_transaction(" << transaction_bytes << ");\n"; } else if (op->op.same_as(tl::mbarrier_wait_parity())) { ICHECK_EQ(op->args.size(), 2); this->PrintIndent(); auto mbarrier_obj = print_mbarrier_obj(op->args[0]); auto phase = this->PrintExpr(op->args[1]); this->stream << mbarrier_obj << ".wait(" << phase << ");\n"; } else if (op->op.same_as(tl::sync_thread_partial())) { print_extern_call_stmt("cutlass::arch::NamedBarrier::sync"); } else if (op->op.same_as(tl::no_set_max_nreg())) { return; } else if (op->op.same_as(tl::tma_load())) { std::ostringstream ss; ICHECK_GE(op->args.size(), 2); auto eviction_policy = this->eviction_policy_names_ [op->args[op->args.size() - 1].as()->value]; // Simplify the code by using the default eviction policy if (eviction_policy != "EVICT_NORMAL") { ss << "tl::tma_load("; } else { ss << "tl::tma_load("; } auto desc = op->args[0]; ss << this->PrintExpr(desc) << ", "; ss << print_mbarrier_obj(op->args[1]) << ", "; for (size_t i = 2; i < op->args.size() - 1; i++) { if (i > 2) ss << ", "; ss << this->PrintExpr(op->args[i]); } ss << ");\n"; this->PrintIndent(); this->stream << ss.str(); } else if (op->op.same_as(tl::tma_load_im2col())) { std::stringstream ss; auto eviction_policy = this->eviction_policy_names_ [op->args[op->args.size() - 1].as()->value]; if (eviction_policy != "EVICT_NORMAL") { ss << "tl::tma_load_im2col"; } else { ss << "tl::tma_load_im2col"; } print_extern_call_stmt(ss.str(), 0, 1); } else if (op->op.same_as(tl::tma_store())) { std::stringstream ss; auto eviction_policy = this->eviction_policy_names_ [op->args[op->args.size() - 1].as()->value]; if (eviction_policy != "EVICT_NORMAL") { ss << "tl::tma_store"; } else { ss << "tl::tma_store"; } print_extern_call_stmt(ss.str(), 0, 1); } else if (op->op.same_as(tl::ptx_ldmatirx())) { int trans = Downcast(op->args[0])->value; int num = Downcast(op->args[1])->value; std::string func_name = "tl::ptx_ldmatrix_x" + std::to_string(num); if (trans == 1) func_name += "_trans"; print_extern_call_stmt(func_name, 2); } else if (op->op.same_as(tl::ptx_stmatrix())) { int trans = Downcast(op->args[0])->value; int num = Downcast(op->args[1])->value; std::string func_name = "tl::ptx_stmatrix_x" + std::to_string(num); if (trans == 1) func_name += "_trans"; print_extern_call_stmt(func_name, 2); } else if (op->op.same_as(tl::fence_proxy_async())) { print_extern_call_stmt("tl::fence_proxy_async"); } else if (op->op.same_as(tl::tma_store_arrive())) { print_extern_call_stmt("tl::tma_store_arrive"); } else if (op->op.same_as(tl::tma_store_wait())) { print_extern_call_stmt("tl::tma_store_wait<0>"); } else if (op->op.same_as(tl::set_max_nreg())) { this->PrintIndent(); int 
nreg = Downcast(op->args[0])->value; int is_inc = Downcast(op->args[1])->value; std::string func_name = is_inc ? "tl::warpgroup_reg_alloc" : "tl::warpgroup_reg_dealloc"; this->stream << func_name << "<" << std::to_string(nreg) << ">();\n"; } else if (op->op.same_as(tl::wait_wgmma())) { this->PrintIndent(); int num_mma = Downcast(op->args[0])->value; this->stream << "tl::wait_wgmma<" << std::to_string(num_mma) << ">();\n"; } else if (op->op.same_as(tl::pack_b16())) { os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", " << this->PrintExpr(op->args[1]) << ")"; } else if (op->op.same_as(tl::sync_grid())) { this->need_cooperative_groups_ = true; this->PrintIndent(); this->stream << "cooperative_groups::grid_group grid = " "cooperative_groups::this_grid();\n"; this->PrintIndent(); this->stream << "grid.sync();\n"; } else if (op->op.same_as(tl::loop_break())) { this->PrintIndent(); this->stream << "break;\n"; } else if (op->op.same_as(builtin::tvm_fill_fragment())) { need_mma_h_ = true; ICHECK_EQ(op->args.size(), 6U); os << "nvcuda::wmma::fill_fragment("; this->PrintExpr(op->args[0], os); os << "["; this->PrintExpr(op->args[4], os); os << "], "; this->PrintExpr(op->args[5], os); os << ")"; } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { need_mma_h_ = true; ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::load_matrix_sync("; this->PrintExpr(op->args[0], os); os << "["; this->PrintExpr(op->args[4], os); os << "], "; this->PrintExpr(op->args[5], os); os << ", "; this->PrintExpr(op->args[6], os); os << ")"; } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { need_mma_h_ = true; ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::store_matrix_sync("; this->PrintExpr(op->args[5], os); os << ", "; this->PrintExpr(op->args[0], os); os << "["; this->PrintExpr(op->args[4], os); os << "], "; this->PrintExpr(op->args[6], os); if (const StringImmNode *str = op->args[7].as()) { os << ", nvcuda::wmma::mem_" << str->value; } else { LOG(FATAL) << "Invalid parameters"; } os << ")"; } else if (op->op.same_as(builtin::tvm_mma_sync())) { need_mma_h_ = true; ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::mma_sync("; for (int i = 0; i < 4; ++i) { this->PrintExpr(op->args[i * 2], os); os << "["; this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } } else if (op->op.same_as(builtin::tvm_bmma_sync())) { need_mma_h_ = true; ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::bmma_sync("; for (int i = 0; i < 4; ++i) { this->PrintExpr(op->args[i * 2], os); os << "["; this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } } else if (op->op.same_as(builtin::ptx_mma())) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col // arg 2: B layout: row/col // arg 3: A precision: fp16, fp64, ... // arg 4: B precision: fp16, fp64, ... // arg 5: C precision: fp32, fp64, ... 
// arg 6: A multiplicand // arg 7: A multiplicand index // arg 8: B multiplicand // arg 9: B multiplicand index // arg 10: C accumulator // arg 11: C accumulator index // arg 12: saturate // arg 13: (optional) 1-bit operator (xor or and) ICHECK(op->args.size() == 13U || op->args.size() == 14U); std::string shape = Downcast(op->args[0])->value; std::string A_layout = Downcast(op->args[1])->value; std::string B_layout = Downcast(op->args[2])->value; std::string A_dtype = Downcast(op->args[3])->value; std::string B_dtype = Downcast(op->args[4])->value; std::string C_dtype = Downcast(op->args[5])->value; std::string a_ref = this->PrintExpr(op->args[6]); std::string a_bias = this->PrintExpr(op->args[7]); std::string b_ref = this->PrintExpr(op->args[8]); std::string b_bias = this->PrintExpr(op->args[9]); std::string c_ref = this->PrintExpr(op->args[10]); std::string c_bias = this->PrintExpr(op->args[11]); bool saturate = Downcast(op->args[12])->value; std::string bit_op = op->args.size() > 13 ? Downcast(op->args[13])->value : ""; std::string asm_code = PrintMMAAssembly( shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); this->stream << asm_code; } else if (op->op.same_as(builtin::ptx_mma_sp())) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col // arg 2: B layout: row/col // arg 3: A precision: fp16, fp32, ... // arg 4: B precision: fp16, fp32, ... // arg 5: C precision: fp16, fp32, ... // arg 6: A multiplicand pointer // arg 7: A multiplicand index // arg 8: B multiplicand pointer // arg 9: B multiplicand index // arg 10: C accumulator pointer // arg 11: C accumulator index // arg 12: metadata // arg 13: metadata index // arg 14: sparse_selector // arg 15: saturate ICHECK_EQ(op->args.size(), 16U); std::string shape = Downcast(op->args[0])->value; std::string A_layout = Downcast(op->args[1])->value; std::string B_layout = Downcast(op->args[2])->value; std::string A_dtype = Downcast(op->args[3])->value; std::string B_dtype = Downcast(op->args[4])->value; std::string C_dtype = Downcast(op->args[5])->value; std::string a_ref = this->PrintExpr(op->args[6]); std::string a_offset = this->PrintExpr(op->args[7]); std::string b_ref = this->PrintExpr(op->args[8]); std::string b_offset = this->PrintExpr(op->args[9]); std::string c_ref = this->PrintExpr(op->args[10]); std::string c_offset = this->PrintExpr(op->args[11]); std::string metadata = this->PrintExpr(op->args[12]); std::string metadata_offset = this->PrintExpr(op->args[13]); std::string sparse_selector = this->PrintExpr(op->args[14]); bool saturate = Downcast(op->args[15])->value; std::string asm_code = PrintMMAAssembly( shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate); this->stream << asm_code; } else if (op->op.same_as(builtin::ptx_ldmatrix())) { // arg 0: whether the matrix is loaded in column major format or not. // arg 1: number of matrices to load. // arg 2: The data type in the matrix, .b16 is the only accepted data type. // arg 3: pointer to local buffer. // arg 4: The offset of the element to store in the local buffer. // arg 5: pointer to the shared memory buffer to load. // arg 6: The offset of the start element of the row to load in shared // memory. 
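  // Example call (illustrative operand names): ptx_ldmatrix(0, 4, ".b16",
  // A_local, 0, A_shared, row_offset) emits an ldmatrix.sync.aligned.m8n8.x4
  // load of four 8x8 16-bit tiles from shared memory into registers.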
ICHECK_EQ(op->args.size(), 7U); bool trans = Downcast(op->args[0])->value; int num = Downcast(op->args[1])->value; std::string type = Downcast(op->args[2])->value; std::string local_ptr = this->PrintExpr(op->args[3]); std::string local_elem_offset = this->PrintExpr(op->args[4]); std::string smem_ptr = this->PrintExpr(op->args[5]); if (trans && op->dtype.bits() == 8) { // Since ldmatrix assumes that a matrix element is 16 bit, it cannot // properly transpose an int8 matrix. std::string smem_stride = this->PrintExpr(op->args[6]); ICHECK(num == 4); os << "for (int i = 0; i < 16; ++i) {\n"; os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n"; os << "}\n"; } else { std::string smem_elem_offset = this->PrintExpr(op->args[6]); need_cast_smem_ptr_to_int_ = true; this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, smem_ptr, smem_elem_offset); } } else if (op->op.same_as(builtin::mma_store())) { int m = Downcast(op->args[0])->value; int n = Downcast(op->args[1])->value; std::string dst = this->PrintExpr(op->args[2]); std::string src = this->PrintExpr(op->args[3]); std::string src_offset = this->PrintExpr(op->args[4]); PrimExpr stride = op->args[5]; ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; // Each thread in a warp holds a certain number of elements of an MMA // output. For example, if we compute a 16x16 tile using MMA, each thread // holds 8 elements in its registers. So conceptually, a warp memory is // organized as a 32x8 block. A map from a 16x16 tile to a 32x8 block of // memory is specified by the index map below. // To store the 32x8 output back to a 16x16 tile in shared or global memory, // we invert this map to determine the output location for each 8 element. const auto index_map_func = ffi::Function::GetGlobal( "tir.index_map.shared_16x16_to_mma_32x8_layout"); IndexMap index_map; if (!index_map_func) { Var i, j; // The index map is defined as follows: index_map = IndexMap( {i, j}, {4 * FloorMod(i, 8) + FloorDiv(FloorMod(j, 8), 2), 4 * FloorDiv(j, 8) + FloorDiv(i, 8) * 2 + FloorMod(j, 2)}); } else { index_map = IndexMap::FromFunc(2, *index_map_func); } arith::Analyzer analyzer; auto inverse_index_map = index_map.Inverse({Range(0, m), Range(0, n)}, &analyzer); auto indices_16x16 = inverse_index_map->final_indices; // "//" and "%" in the index map are translated to FloorDiv/Mod, but the // plain Div/Mod are fine. FloorDiv/Mod are supposed to be lowered before // they reach codegen, so manually replace them to the plain ones here. 
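  // Sketch of the effect (illustrative, assuming the default index map above):
  // the inverse map turns (threadIdx.x, local_id) back into a (row, col)
  // position of the 16x16 tile, so the emitted loop is roughly
  //   for (int local_id = 0; local_id < 8; ++local_id)
  //     dst[row * stride + col] = src[src_offset + local_id];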
class LowerFloorDivMod : public ExprMutator { public: PrimExpr VisitExpr_(const FloorDivNode *op) { return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); } PrimExpr VisitExpr_(const FloorModNode *op) { return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); } }; auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]); var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x"; var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; if (op->dtype.bits() == 16) { os << "for (int local_id = 0; local_id < 8; local_id+=2) {\n"; os << "*((uint *)&" << dst << "[" + this->PrintExpr(dst_ind) + "])" << " = " << "*((uint *)&" << src << "[" << src_offset << " + local_id]);\n"; os << "}\n"; } else { os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; os << dst << "[" + this->PrintExpr(dst_ind) + "]" << " = " << src << "[" << src_offset << " + local_id];\n"; os << "}\n"; } } else if (op->op.same_as(builtin::mma_fill())) { std::string num_elem = this->PrintExpr(op->args[0]); std::string dst = this->PrintExpr(op->args[1]); std::string dst_offset = this->PrintExpr(op->args[2]); os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; os << dst << "[" << dst_offset << " + i] = 0.0;"; os << "}\n"; } else if (op->op.same_as(builtin::ptx_cp_async())) { std::string dst = this->PrintExpr(op->args[0]); std::string dst_offset = this->PrintExpr(op->args[1]); std::string src = this->PrintExpr(op->args[2]); std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); need_cast_smem_ptr_to_int_ = true; // use size of argument list to indicate whether or not to use predicated // cp.async if (op->args.size() == 5) { this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); } else { this->stream << PrintPredicatedCpAsyncAssembly( dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5])); } } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { need_cast_smem_ptr_to_int_ = true; std::string dst = this->PrintExpr(op->args[0]); std::string dst_offset = this->PrintExpr(op->args[1]); std::string src = this->PrintExpr(op->args[2]); std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); int barrier_id = Downcast(op->args[5])->value; CHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier); } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; } else if (op->op.same_as(builtin::ptx_wait_group())) { int n = Downcast(op->args[0])->value; this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n"; } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; CHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string thread_count = this->PrintExpr(op->args[1]); this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count); } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; CHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintArriveBarrierAsm(barrier); } 
else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; CHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; std::string byte_count = this->PrintExpr(op->args[1]); this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count); } else if (op->op.same_as(builtin::ptx_wait_barrier())) { need_cast_smem_ptr_to_int_ = true; int barrier_id = Downcast(op->args[0])->value; CHECK(barrier_id < barrier_count_); std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintWaitBarrierAsm(barrier); } else if (op->op.same_as(builtin::ptx_ldg32())) { /* asm volatile ( "{.reg .pred p;\n" " setp.ne.b32 p, %2, 0;\n" // " @p ld.global.nc.f32 %0, [%1];}\n"t " @p ld.global.nc.L2::128B.f32 %0, [%1];}\n" : "=f"(reg) : "l"(addr), "r"((int)guard) ); */ // get local std::string reg = this->PrintExpr(op->args[0]); // get guard std::string guard = this->PrintExpr(op->args[1]); const BufferLoadNode *addr_buffer = op->args[2].as(); std::string global_addr = this->PrintExpr(addr_buffer->indices[0]); std::string global_buffer = this->PrintExpr(addr_buffer->buffer->data); std::string local_addr = this->PrintExpr(op->args[3]); this->stream << "asm volatile (\n"; this->stream << "\"{.reg .pred p;\\n\"\n"; this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n"; this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n"; this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n"; // stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ; stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n"; stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" << guard << ")\n"; stream << ");\n"; } else if (op->op.same_as(builtin::reinterpret())) { DataType tgt_dtype = op->dtype; DataType src_dtype = op->args[0]->dtype; PrimExpr value = op->args[0]; // Handle float4_e2m1fn reinterpret if (!src_dtype.is_float4_e2m1fn() && !tgt_dtype.is_float4_e2m1fn()) { return CodeGenC::VisitExpr_(op, os); } if (src_dtype == tgt_dtype || tgt_dtype.lanes() * tgt_dtype.bits() == src_dtype.lanes() * src_dtype.bits()) { return CodeGenC::VisitExpr_(op, os); } CHECK_EQ(tgt_dtype.lanes(), src_dtype.lanes()) << "E2M1 float4 reinterpret expects source and target to have the same " "number of lanes. " << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; CHECK_EQ(tgt_dtype.bytes(), src_dtype.bytes()) << "E2M1 float4 reinterpret expects source and target to have the same " "number of bytes. " << "Source dtype: " << src_dtype << ", Target dtype: " << tgt_dtype; int lanes = tgt_dtype.lanes(); int ssa_scope = BeginScope(); if (lanes == 1) { // The case of lane=1 is same as the normal reinterpret, // except that we allow the src and dst dtype to have different number of // bits. std::string rhs = SSAGetID(PrintExpr(value), src_dtype); os << "(*("; this->PrintType(tgt_dtype, os); os << " *)(&(" << rhs << ")))"; } else if (lanes == 2) { if (tgt_dtype.is_float4_e2m1fn()) { // We view the source as an uint16, and then extract bits of two fp4 // numbers, and finally reinterpret the result as fp4x2. 
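        // Worked example (illustrative): if the two source bytes are 0x0A and
        // 0x0B (viewed as the uint16 0x0B0A), then
        // (x & 0xF) | ((x >> 4) & 0xF0) == 0xBA, i.e. the two low nibbles
        // packed into one fp4x2 byte.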
value = tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value}); tir::Var temp_var("temp_var", DataType::UInt(16)); value = tir::Let(temp_var, value, tir::Cast(DataType::UInt(8), (temp_var & IntImm(DataType::UInt(16), 0xF)) | ((temp_var >> 4) & IntImm(DataType::UInt(16), 0xF0)))); } else { value = tir::Cast( DataType::UInt(16), tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value})); tir::Var temp_var("temp_var", DataType::UInt(16)); value = tir::Let(temp_var, value, (temp_var & IntImm(DataType::UInt(16), 0xF)) | ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4)); } os << PrintExpr( tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); } else if (lanes == 4) { if (tgt_dtype.is_float4_e2m1fn()) { // We view the source as an uint32, and then extract bits of four fp4 // numbers, and finally reinterpret the result as fp4x4. value = tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value}); tir::Var temp_var("temp_var", DataType::UInt(32)); value = tir::Let( temp_var, value, tir::Cast( DataType::UInt(16), (temp_var & IntImm(DataType::UInt(32), 0xF)) | ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) | ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) | ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000)))); } else { value = tir::Cast(DataType::UInt(32), tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value})); tir::Var temp_var("temp_var", DataType::UInt(32)); value = tir::Let( temp_var, value, (temp_var & IntImm(DataType::UInt(32), 0xF)) | ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) | ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) | ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12)); } os << PrintExpr( tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); } else { LOG(FATAL) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; } EndScope(ssa_scope); } else if (op->op.same_as(builtin::thread_return())) { os << "return"; } else if (op->op.same_as(tl::tl_gemm())) { ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments , but got " << op->args.size(); auto op_instance = Downcast(op->args[0]); this->PrintCallExtern(GetType(GetRef(op)), op_instance->value, op->args, true, os); } else if (op->op.same_as(tl::tl_gemm_sp())) { ICHECK(op->args.size() == 5) << "tl_gemm_sp expects 5 arguments , but got " << op->args.size(); auto op_instance = Downcast(op->args[0]); enable_sparse_gemm_ = true; this->PrintCallExtern(GetType(GetRef(op)), op_instance->value, op->args, true, os); } else if (op->op.same_as(tl::tl_shuffle_elect())) { os << "tl::tl_shuffle_elect<" << PrintExpr(op->args[0]) << ">()"; } else { CodeGenC::VisitExpr_(op, os); } } void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) { if (op->attr_key == tir::attr::fragment_shape) { const VarNode *buffer = op->node.as(); const StringImmNode *shape_str = op->value.as(); fragment_shapes[buffer] = shape_str->value; } else if (op->attr_key == tir::attr::fragment_layout) { const VarNode *buffer = op->node.as(); const StringImmNode *layout_str = op->value.as(); fragment_layouts[buffer] = layout_str->value; } else if (op->attr_key == tir::attr::async_commit_queue_scope) { const IntImmNode *queue_id = op->value.as(); ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; this->VisitStmt(op->body); auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); this->VisitExpr(commit_group, this->stream); return; } else if (op->attr_key == tir::attr::async_wait_queue_scope) { auto 
wait_attrs = GetAsyncWaitAttributes(op); auto queue_id = wait_attrs.first.as(); ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; auto wait_cnt = wait_attrs.second; auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); this->VisitExpr(wait_group, this->stream); auto inner = op->body.as(); ICHECK(inner); this->VisitStmt(inner->body); return; } else if (op->attr_key == "threadblock_swizzle_pattern") { this->PrintIndent(); const StringImmNode *pattern = op->value.as(); ICHECK(pattern); this->stream << "const dim3 blockIdx = " << pattern->value << "();\n"; this->VisitStmt(op->body); return; } CodeGenC::VisitStmt_(op); } void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { ICHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); std::string scope = GetPtrStorageScope(op->buffer_var); const VarNode *buffer = op->buffer_var.as(); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || op->dtype == DataType::BFloat(16)) << "Matrix_a and matrix_b only support half or char or unsigned char " << "or uint4 or int4 or int1 type for now"; } else { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } PrintWmmaScope(scope, op->dtype, buffer, stream); } else { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); } if (scope == "shared.dyn") { stream << ' ' << vid << "[];\n"; } else { size_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now, but get " << constant_size << " for " << op->buffer_var->name_hint; if (scope.find("wmma.") == 0) { constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); } if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1)) && scope == "shared") { constant_size = constant_size / (32 / op->dtype.bits()); } if (scope == "shared") { stream << ' ' << vid << '[' << constant_size << "];\n"; } else if (scope == "shared.barrier") { auto v_id_mem = vid + "_mem"; stream << ' ' << v_id_mem << "[" << constant_size << "];\n"; PrintIndent(); stream << "auto " << vid << " = reinterpret_cast<" << mbarrier_dtype_ << "*>(" << v_id_mem << ");\n"; } else if (scope == "local") { stream << ' ' << vid << '[' << constant_size << "];\n"; } else if (scope == "local.var") { stream << ' ' << vid << " = " << PrintExpr(tir::make_const(op->dtype, 0)) << ";\n"; } else { ICHECK(false) << "Unsupported scope: " << scope; } } RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); } void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) { if (is_const_int(op->value)) return; const CallNode *call = op->value.as(); if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) { PrintIndent(); stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n"; PrintIndent(); stream << "if (threadIdx.x == 0) {\n"; PrintIndent(); stream << " " << vid_global_barrier_expect_ << " = 0;\n"; PrintIndent(); stream << "}\n"; } else { CodeGenC::VisitStmt_(op); } } void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) { int 
lanes = static_cast(Downcast(op->lanes)->value); CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef(op) << " with " << lanes << " lanes is not allowed."; os << "(make_"; PrintType(op->dtype, os); os << "("; for (int i = 0; i < lanes; i++) { os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i << ")"; if (i != lanes - 1) os << ", "; } os << "))"; } void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, std::ostream &os) { // NOLINT(*) ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported."; ICHECK(!op->predicate.defined()) << "Predicated buffer load is not supported."; DataType value_dtype = op->dtype; PrimExpr index = op->indices[0]; Var buffer_var = op->buffer->data; DataType element_dtype = op->buffer->dtype; int lanes = op->dtype.lanes(); // delcare type. if (value_dtype.lanes() == element_dtype.lanes()) { std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); HandleVolatileLoads(ref, op, os); } else { bool can_vector_load = false; arith::PVar base; if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { const RampNode *ramp = index.as(); ICHECK(ramp); can_vector_load = true; // arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); // The condition: {k * coeff + base} divisible by the alignment for any k // if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() // == 0) { // can_vector_load = true; // } } if (value_dtype.is_float4_e2m1fn() && lanes != 1) { // A float4_e2m1fn element has 4 bits, which is an incomplete byte. // So we cannot vector load it. can_vector_load = false; } if (can_vector_load) { std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval()); HandleVolatileLoads(ref, op, os); } else { std::ostringstream svalue_expr; std::string sindex = SSAGetID(PrintExpr(index), index.dtype()); std::string vid = GetVarID(buffer_var.get()); DataType elem_type = op->dtype.element_of(); for (int i = 0; i < lanes; ++i) { std::ostringstream value_temp; if (!HandleTypeMatch(buffer_var.get(), elem_type)) { value_temp << "(("; if (buffer_var.get()->dtype.is_handle()) { auto it = alloc_storage_scope_.find(buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, value_temp); } } PrintType(elem_type, value_temp); value_temp << "*)" << vid << ')'; } else { value_temp << vid; } value_temp << '['; PrintVecElemLoad(sindex, index.dtype(), i, value_temp); value_temp << ']'; PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr); } os << svalue_expr.str(); } } } void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, std::ostream &os) { // NOLINT(*) int lanes = static_cast(Downcast(op->lanes)->value); if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { // make_int8x4 const int64_t *p = as_const_int(op->value); ICHECK(p); int64_t v = *p & 0xFF; v = (v << 24) | (v << 16) | (v << 8) | v; if (op->dtype.is_uint()) { os << "(uint)" << v; } else { os << "(int)" << v; } return; } if (op->dtype.is_float16()) { std::string v = PrintExpr(op->value); os << "make_"; PrintType(op->dtype, os); os << '('; for (int i = 0; i < lanes / 2; ++i) { if (i != 0) os << ", "; os << "__pack_half2(" << v << ", " << v << ")"; } os << ')'; return; } if (op->dtype.is_bfloat16()) { std::string v = PrintExpr(op->value); os << "make_"; PrintType(op->dtype, os); os << '('; for (int i = 0; i < lanes / 2; ++i) { if (i != 0) os << ", "; os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; } os << ')'; return; } if 
  if (op->dtype.is_float() && op->dtype.bits() == 32 && op->dtype.lanes() == 8) {
    // float32x8 is emitted as a ulonglong4, each component holding a packed
    // float2 with the broadcast value in both halves.
    std::string v = PrintExpr(op->value);
    os << "make_ulonglong4(";
    for (int i = 0; i < 4; ++i) {
      if (i != 0) os << ", ";
      os << "*(unsigned long long*)&make_float2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
    bool fail = false;
    const int64_t *p = as_const_int(op->value);
    ICHECK(p);
    int64_t v = *p & 0xF;

    if (lanes == 4) {
      // Pack four 4-bit values into a 16-bit literal,
      // e.g. broadcasting 5 gives 0x5555 = 21845, emitted as "(uint16_t)21845".
      v = (v << 12) | (v << 8) | (v << 4) | v;
      if (op->dtype.is_uint()) {
        os << "(uint16_t)" << v;
      } else {
        os << "(int16_t)" << v;
      }
    } else {
      // Pack eight 4-bit values into a 32-bit literal.
      v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) |
          (v << 8) | (v << 4) | v;
      if (lanes == 8) {
        if (op->dtype.is_uint()) {
          os << "(uint)" << v;
        } else {
          os << "(int)" << v;
        }
      } else if (lanes == 16 || lanes == 32) {
        os << "make_";
        PrintType(op->dtype, os);
        os << '(';
        for (int i = 0; i < lanes / 8; ++i) {
          if (i != 0) os << ", ";
          if (op->dtype.is_uint()) {
            os << "(uint)" << v;
          } else {
            os << "(int)" << v;
          }
        }
        os << ')';
      } else {
        fail = true;
      }
    }

    if (!fail) {
      return;
    }
  }

  // Generic fallback: repeat the scalar value once per lane.
  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < lanes; ++i) {
    if (i != 0) os << ", ";
    os << v;
  }
  os << ')';
}

inline void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangCUDA *p) { // NOLINT(*)
  // Type code is kBFloat16
  if (op->dtype.is_bfloat16()) {
    os << "bfloat16_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
  // Type code is a float8 or float4 variant
  if (op->dtype.is_float8() || op->dtype.is_float4()) {
    p->PrintType(op->dtype, os);
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
  // Type code is kFloat
  switch (op->dtype.bits()) {
    case 64:
    case 32: {
      std::ostringstream temp;
      if (std::isinf(op->value)) {
        if (op->value < 0) {
          temp << "-";
        }
        temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
        p->need_math_constants_h_ = true;
      } else if (std::isnan(op->value)) {
        temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
        p->need_math_constants_h_ = true;
      } else {
        temp << std::scientific << op->value;
        if (op->dtype.bits() == 32) temp << 'f';
      }
      p->MarkConst(temp.str());
      os << temp.str();
      break;
    }
    case 16: {
      // Emit fp16 constants as half_t(<fp32 constant>).
      os << "half_t" << '(';
      FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
      PrintConst(const_f32.get(), os, p);
      os << ')';
      break;
    }
    default:
      LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
}

void CodeGenTileLangCUDA::VisitExpr_(const FloatImmNode *op,
                                     std::ostream &os) { // NOLINT(*)
  PrintConst(op, os, this);
}

void CodeGenTileLangCUDA::PrintWmmaScope(const std::string &scope, DataType t,
                                         const VarNode *variable,
                                         std::ostream &os) {
  std::stringstream type;
  PrintType(t, type);
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
  std::string shape_str = fragment_shapes.at(variable);
  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
    type.str(std::string());
    if (t.is_int()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::s4";
      } else if (t.bits() == 1) {
        type << "nvcuda::wmma::experimental::precision::b1";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    } else if (t.is_uint()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::u4";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    }
  }
  if (scope == "wmma.matrix_a") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str
       << ", " << type.str() << ", nvcuda::wmma::" << layout_str << ">";
  } else if (scope == "wmma.matrix_b") {
    std::string layout_str = fragment_layouts[variable];
    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str
       << ", " << type.str() << ", nvcuda::wmma::" << layout_str << ">";
  } else if (scope == "wmma.accumulator") {
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str
       << ", " << type.str() << ">";
  }
}

int32_t CodeGenTileLangCUDA::GetWmmaFragmentSize(const std::string &scope,
                                                 const VarNode *variable,
                                                 int32_t size) {
  ICHECK(fragment_shapes.count(variable))
      << "Cannot find shape of the wmma fragment " << variable->name_hint;
  std::string shape_str = fragment_shapes.at(variable);
  std::pair<int32_t, int32_t> dim = GetWmmaFragmentDimSize(shape_str, scope);
  if (dim.first * dim.second != 0)
    return size / dim.first / dim.second;
  else
    return 0;
}
void CodeGenTileLangCUDA::HandleVolatileLoads(const std::string &value,
                                              const BufferLoadNode *op,
                                              std::ostream &os) {
  // Cast away volatile qualifier for fp16 types. That is, only loads and
  // stores are volatile. The loaded objects are not marked as volatile.
  //
  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) &&
      IsVolatile(op->buffer->data.get())) {
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

void CodeGenTileLangCUDA::PrintVecElemLoadExpr(DataType t, int i,
                                               const std::string &value,
                                               std::ostream &os) {
  ICHECK_GT(t.lanes(), 1);
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (!(t.lanes() == 2 || t.lanes() == 3)) {
      // OR the byte for lane i into its position within a packed 32-bit word.
      if (i != 0) {
        os << "|";
      }
      os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8
         << "))";
      return;
    }
  }

  if (t.is_float16()) {
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      os << "__pack_half2(" << value;
    } else {
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  if (t.is_bfloat16()) {
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      os << "__pack_bfloat162(" << value;
    } else {
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << "(";
  }
  os << value;
  if (i != t.lanes() - 1) {
    os << ",";
  } else {
    os << ")";
  }
  return;
}

void CodeGenTileLangCUDA::PrintFunctionSignature(const String &function_name,
                                                 const PrimFunc &func,
                                                 std::ostream &os) {
  PrintFuncPrefix(os);
  CodeGenC::PrintType(func->ret_type, os);
  CodeGenC::PrintExtraAttrs(func, os);
  bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
  os << " " << function_name << "(";

  for (size_t i = 0; i < func->params.size(); ++i) {
    tir::Var v = func->params[i];
    std::string vid = AllocVarID(v.get());

    if (i > 0) {
      os << ", ";
    }

    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (ptr->storage_scope == "grid_constant") {
          os << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, os);
          os << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, os);
      }

      CodeGenC::PrintType(GetType(v), os);
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, os);
      }
    } else {
      CodeGenC::PrintType(GetType(v), os);
    }
    os << ' ' << vid;
  }
  os << ")";

  // Register handle data type
  // TODO(tvm-team): consider simply keeping type info in the
  // type annotation (via a normalizing rewrite).
  for (const auto &param : func->params) {
    if (auto *ptr = param->type_annotation.as<PointerTypeNode>()) {
      if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
        RegisterHandleType(param.get(), prim->dtype);
      }
    }
  }
}
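// Example signature produced by PrintFunctionSignature (a sketch under assumed
// inputs, not output captured from this backend): for a PrimFunc carrying the
// kNoAlias attribute, one global float* parameter and one "grid_constant"
// CUtensorMap parameter, the emitted text would look roughly like
//   extern "C" __global__ void main_kernel(float* __restrict__ A,
//       __grid_constant__ const CUtensorMap desc)
// The names main_kernel, A and desc are made up for illustration.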
void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar,
                                      const PrimFunc &f) {
  // If the function has already been forward-declared, this is a
  // no-op.
  CodeGenC::DeclareFunction(gvar, f);
  // clear previous generated state.
  this->InitFuncState(f);
  // reserve keywords
  ReserveKeywordsAsUnique();

  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(global_symbol.defined())
      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

  this->PrintFuncPrefix(stream);
  CodeGenC::PrintType(f->ret_type, stream);
  this->PrintExtraAttrs(f);
  this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

  for (size_t i = 0; i < f->params.size(); ++i) {
    tir::Var v = f->params[i];
    std::string vid = AllocVarID(v.get());
    if (i != 0) stream << ", ";
    if (v.dtype().is_handle()) {
      // work around for grid constant parameters.
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (ptr->storage_scope == "grid_constant") {
          stream << "__grid_constant__ const ";
          CodeGenC::PrintType(ptr->element_type, stream);
          stream << ' ' << vid;
          continue;
        }
      }

      auto it = alloc_storage_scope_.find(v.get());
      if (it != alloc_storage_scope_.end()) {
        PrintStorageScope(it->second, stream);
      }

      CodeGenC::PrintType(GetType(v), stream);
      if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
        if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
          RegisterHandleType(v.get(), prim->dtype);
        }
      }

      if (no_alias) {
        PrintRestrict(v, stream);
      }
    } else {
      CodeGenC::PrintType(GetType(v), stream);
    }
    stream << ' ' << vid;
  }
  stream << ") {\n";

  this->PreFunctionBody(f);
  int func_scope = this->BeginScope();
  this->PrintStmt(f->body);
  this->EndScope(func_scope);
  this->PrintIndent();
  this->stream << "}\n\n";
}

} // namespace codegen
} // namespace tvm
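// Usage sketch (illustrative only, not part of this file): a driver owning a
// CodeGenTileLangCUDA instance is expected to forward each lowered PrimFunc
// through AddFunction() and then collect the complete CUDA source via
// Finish(), roughly as follows. Init() is assumed to be inherited from
// CodeGenC; gvar and prim_func are assumed to come from the lowered IRModule.
//
//   tvm::codegen::CodeGenTileLangCUDA cg;
//   cg.Init(/*output_ssa=*/false);
//   cg.AddFunction(gvar, prim_func);
//   std::string cuda_source = cg.Finish();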