/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_c_host.cc
 */
#include "codegen_c_host.h"

#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/codegen.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

// For escaping strings embedded into generated C sources
#include "support/str_escape.h"

namespace tvm {
namespace tl {

/*!
 * \brief Construct the host code generator and reserve a fresh name for the
 *        module (library) context variable.
 */
CodeGenCHost::CodeGenCHost() {
  module_name_ = name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx);
}

/*!
 * \brief Reset per-module state and emit the preamble of the generated C file.
 *
 * \param output_ssa Forwarded unchanged to CodeGenC::Init.
 * \param emit_asserts Whether AssertStmt nodes produce runtime checks.
 * \param emit_fwd_func_decl Whether forward declarations are generated.
 * \param target_str Target description, echoed into a leading `//` comment.
 * \param devices Accepted for interface compatibility; not read here.
 */
void CodeGenCHost::Init(bool output_ssa, bool emit_asserts,
                        bool emit_fwd_func_decl, std::string target_str,
                        const std::unordered_set<std::string> &devices) {
  emit_asserts_ = emit_asserts;
  emit_fwd_func_decl_ = emit_fwd_func_decl;
  declared_globals_.clear();

  decl_stream << "// tilelang target: " << target_str << "\n";
  decl_stream << "#define TVM_EXPORTS\n";
  decl_stream << "#include \"tvm/runtime/base.h\"\n";
  decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n";
  decl_stream << "#include \"tvm/ffi/c_api.h\"\n";
  decl_stream << "#include <math.h>\n";
  // snprintf for richer assert messages with actual values
  decl_stream << "#include <stdio.h>\n";
  // The generated code spells the boolean type as `bool` (see PrintType).
  decl_stream << "#include <stdbool.h>\n";

  CodeGenCHost::InitGlobalContext();
  tvm::codegen::CodeGenC::Init(output_ssa);
}

/*!
 * \brief Emit the definition of the library context global, initialised to
 *        NULL.
 */
void CodeGenCHost::InitGlobalContext() {
  decl_stream << "void* " << tvm::ffi::symbol::tvm_ffi_library_ctx
              << " = NULL;\n";
}

/*!
 * \brief Emit a NULL-initialised `void*` global named after module_name_.
 */
void CodeGenCHost::DefineModuleName() {
  decl_stream << "void* " << module_name_ << " = NULL;\n";
}

/*!
 * \brief Add a function without emitting forward declarations.
 *
 * Equivalent to the three-argument overload with emit_fwd_func_decl = false.
 */
void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar,
                               const tvm::tir::PrimFunc &func) {
  AddFunction(gvar, func, /*emit_fwd_func_decl=*/false);
}

/*!
 * \brief Generate code for one PrimFunc and, for the first entry function
 *        seen, an additional wrapper exported under the FFI main symbol.
 *
 * \param gvar Global variable naming the function.
 * \param func The function to lower to C.
 * \param emit_fwd_func_decl Overwrites emit_fwd_func_decl_ for this and any
 *        later call.
 */
void CodeGenCHost::AddFunction(const tvm::GlobalVar &gvar,
                               const tvm::tir::PrimFunc &func,
                               bool emit_fwd_func_decl) {
  auto global_symbol =
      func->GetAttr<tvm::ffi::String>(tvm::attr::kGlobalSymbol);
  if (global_symbol) {
    function_names_.push_back(global_symbol.value());
  }

  emit_fwd_func_decl_ = emit_fwd_func_decl;
  tvm::codegen::CodeGenC::AddFunction(gvar, func);

  const bool is_entry = func->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc);
  if (!is_entry || has_main_func_) {
    return;
  }

  ICHECK(global_symbol.has_value())
      << "CodeGenCHost: The entry func must have the global_symbol "
         "attribute, "
      << "but function " << gvar << " only has attributes " << func->attrs;

  // The wrapper simply forwards all four arguments to the real entry symbol.
  function_names_.push_back(tvm::ffi::symbol::tvm_ffi_main);
  stream << "// CodegenC: NOTE: Auto-generated entry function\n";
  PrintFuncPrefix(stream);
  PrintType(func->ret_type, stream);
  stream << " " << tvm::ffi::symbol::tvm_ffi_main
         << "(void* self, void* args,int num_args, void* result) {\n";
  stream << "  return " << static_cast<std::string>(global_symbol.value())
         << "(self, args, num_args, result);\n";
  stream << "}\n";
  has_main_func_ = true;
}

/*!
 * \brief Write a forward declaration for \p global_symbol to fwd_decl_stream.
 *
 * Nothing is written when forward declarations are disabled or when the
 * symbol is already among the names returned by GetFunctionNames().
 *
 * \param global_symbol Name of the function to declare.
 * \param arg_types Parameter types, printed in order and comma separated.
 * \param ret_type Return type of the declaration.
 */
void CodeGenCHost::GenerateForwardFunctionDeclarations(
    tvm::ffi::String global_symbol,
    const tvm::ffi::Array<tvm::Type> &arg_types, const tvm::Type &ret_type) {
  if (!emit_fwd_func_decl_) {
    return;
  }
  for (const auto &func_already_defined : GetFunctionNames()) {
    if (global_symbol == func_already_defined) {
      return;
    }
  }

  this->PrintFuncPrefix(fwd_decl_stream);
  this->PrintType(ret_type, fwd_decl_stream);
  fwd_decl_stream << " " << global_symbol << "(";
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (i > 0) {
      fwd_decl_stream << ", ";
    }
    tvm::codegen::CodeGenSourceBase::PrintType(arg_types[i], fwd_decl_stream);
  }
  fwd_decl_stream << ");\n";
}

/*!
 * \brief Emit the linkage prefix placed before every generated function:
 *        `extern "C"` guarded by `__cplusplus`.
 */
void CodeGenCHost::PrintFuncPrefix(std::ostream &os) { // NOLINT(*)
  os << "#ifdef __cplusplus\n"
     << "extern \"C\"\n"
     << "#endif\n";
}

/*!
 * \brief Print the C spelling of a TVM data type.
 *
 * Handles print as `void*`, void as `void`, Bool as `bool`; floats map to
 * half/float/double; bfloat16 to `__bf16`; integers to the (u)intN_t family
 * (1-bit integers use the 32-bit type). Vector types with 2..16 lanes get the
 * lane count appended to the scalar name. Anything else is fatal.
 *
 * \param t The data type to print.
 * \param os Destination stream.
 */
void CodeGenCHost::PrintType(tvm::DataType t, std::ostream &os) { // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    ICHECK_EQ(lanes, 1) << "does not support vector types";
    os << "void*";
    return;
  }
  if (t.is_void()) {
    os << "void";
    return;
  }
  if (t == tvm::DataType::Bool()) {
    os << "bool";
    return;
  }

  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
    case 16:
      os << "half";
      break;
    case 32:
      os << "float";
      break;
    case 64:
      os << "double";
      break;
    default:
      fail = true;
      break;
    }
    if (!fail && lanes == 1)
      return;
    if (!fail && (lanes >= 2 && lanes <= 16)) {
      os << lanes;
      return;
    }
  }
  if (t.is_bfloat16()) {
    os << "__bf16";
    return;
  }
  if (t.is_int() || t.is_uint()) {
    if (t.is_uint()) {
      os << 'u';
    }
    switch (t.bits()) {
    case 8:
      os << "int8_t";
      break;
    case 16:
      os << "int16_t";
      break;
    case 32:
      os << "int32_t";
      break;
    case 64:
      os << "int64_t";
      break;
    case 1:
      os << "int32_t";
      break;
    default:
      fail = true;
      break;
    }
    if (!fail && lanes == 1)
      return;
    if (!fail && (lanes >= 2 && lanes <= 16)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to C type";
}

/*!
 * \brief Print a broadcast as a cast of a parenthesised list that repeats the
 *        scalar expression once per lane.
 */
void CodeGenCHost::VisitExpr_(const tvm::tir::BroadcastNode *op,
                              std::ostream &os) { // NOLINT(*)
  std::string v = PrintExpr(op->value);
  int lanes = op->dtype.lanes();
  os << "((";
  PrintType(op->dtype, os);
  os << ")(";
  for (int i = 0; i < lanes; ++i) {
    if (i != 0)
      os << ", ";
    os << v;
  }
  os << "))";
}

/*!
 * \brief Emit code that lazily resolves a packed function handle.
 *
 * The generated code calls TVMBackendGetFuncFromEnv only while the cached
 * handle is still NULL, and returns -1 from the enclosing function when the
 * lookup reports failure.
 *
 * \param func_name Name passed to the lookup as a string literal.
 * \param packed_func_name Name of the global that caches the handle.
 */
void CodeGenCHost::PrintGetFuncFromBackend(
    const std::string &func_name, const std::string &packed_func_name) {
  this->PrintIndent();
  this->stream << "if (" << packed_func_name << " == NULL) {\n";
  int packed_func_if_scope = this->BeginScope();
  this->PrintIndent();
  this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \""
               << func_name << "\""
               << ", &" << packed_func_name << ") != 0) {\n";
  int get_func_env_scope = this->BeginScope();
  this->PrintIndent();
  this->stream << "return -1;\n";
  this->EndScope(get_func_env_scope);
  this->PrintIndent();
  this->stream << "}\n";
  this->EndScope(packed_func_if_scope);
  this->PrintIndent();
  this->stream << "}\n";
}

/*!
 * \brief Emit a call for tvm_call_packed_lowered / tvm_call_cpacked_lowered.
 *
 * Reads the callee name from args[0], the argument stack from args[1] and the
 * [begin, end) argument range from args[2] / args[3]. A zero-initialised
 * TVMFFIAny result slot is declared first; the generated code returns -1 when
 * the call yields a non-zero status.
 *
 * \param op The call node; its first argument must be a StringImm and its
 *        third and fourth arguments must be IntImm.
 */
void CodeGenCHost::PrintCallPacked(const tvm::tir::CallNode *op) {
  using namespace tvm::tir;
  ICHECK_GE(op->args.size(), 4U)
      << "tvm_call_[c]packed_lowered expects at least 4 arguments";
  const StringImmNode *func_name = op->args[0].as<StringImmNode>();
  ICHECK(func_name != nullptr)
      << "tvm_call_[c]packed_lowered expects first argument as function name";

  // Validate before dereferencing: `as<>` yields nullptr on a type mismatch.
  const IntImmNode *begin_node = op->args[2].as<IntImmNode>();
  const IntImmNode *end_node = op->args[3].as<IntImmNode>();
  ICHECK(begin_node != nullptr && end_node != nullptr)
      << "tvm_call_[c]packed_lowered expects constant integer begin/end "
         "argument indices";
  int64_t begin = begin_node->value;
  int64_t end = end_node->value;
  int64_t num_args = end - begin;
  ICHECK_GE(num_args, 0);

  std::string packed_func_name;
  if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
    packed_func_name = GetPackedName(op);
    this->PrintGetFuncFromBackend(func_name->value, packed_func_name);
  } else {
    // directly use the original symbol
    ICHECK(op->op.same_as(builtin::tvm_call_cpacked_lowered()));
    packed_func_name =
        tvm::ffi::symbol::tvm_ffi_symbol_prefix + func_name->value;
  }

  std::string args_stack = PrintExpr(op->args[1]);
  this->PrintIndent();
  std::string result = name_supply_->FreshName("result");
  this->stream << "TVMFFIAny " << result << ";\n";
  this->PrintIndent();
  // must make sure type_index is set to none
  this->stream << result << ".type_index = kTVMFFINone;\n";
  this->PrintIndent();
  this->stream << result << ".zero_padding = 0;\n";
  this->PrintIndent();
  this->stream << result << ".v_int64 = 0;\n";
  this->PrintIndent();
  if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
    this->stream << "if (TVMFFIFunctionCall(" << packed_func_name << ", ";
  } else {
    this->stream << "if (" << packed_func_name << "(NULL, ";
  }
  this->stream << "(TVMFFIAny*) " << args_stack << ", " << num_args << ", "
               << "&" << result << ") != 0) {\n";
  int func_call_scope = this->BeginScope();
  this->PrintIndent();
  this->stream << "return -1;\n";
  this->EndScope(func_call_scope);
  this->PrintIndent();
  this->stream << "}\n";
}

/*!
 * \brief Return the name of the static global caching a packed function
 *        handle, declaring it on first use.
 *
 * The cache key is "<callee>_packed"; the emitted identifier is a fresh name
 * derived from that key and is remembered in declared_globals_.
 *
 * \param op Call node whose first argument is the callee name (StringImm).
 * \return The unique identifier of the cached handle.
 */
std::string CodeGenCHost::GetPackedName(const tvm::tir::CallNode *op) {
  using namespace tvm::tir;
  const StringImmNode *s = op->args[0].as<StringImmNode>();
  ICHECK(s != nullptr)
      << "tvm_call_packed_lowered expects first argument as function name";
  std::string func_name = s->value;
  std::string packed_func_name = func_name + "_packed";
  std::string unique_name;
  auto it = declared_globals_.find(packed_func_name);
  if (it != declared_globals_.end()) {
    unique_name = it->second;
  } else {
    unique_name = name_supply_->FreshName(packed_func_name);
    declared_globals_[packed_func_name] = unique_name;
    decl_stream << "static void* " << unique_name << " = NULL;\n";
  }
  return unique_name;
}

/*!
 * \brief Lower host-specific builtin calls; defer everything else to CodeGenC.
 *
 * - tvm_stack_alloca: declares a local TVMFFIAny array large enough for the
 *   requested number of "shape", "tvm_ffi_any" or "array" elements and prints
 *   its name as the expression value.
 * - tvm_call_packed_lowered / tvm_call_cpacked_lowered: see PrintCallPacked.
 * - tvm_throw_last_error: emits `return -1;`.
 */
void CodeGenCHost::VisitExpr_(const tvm::tir::CallNode *op,
                              std::ostream &os) { // NOLINT(*)
  using namespace tvm::tir;
  if (op->op.same_as(builtin::tvm_stack_alloca())) {
    ICHECK_GE(op->args.size(), 2U)
        << "tvm_stack_alloca expects a kind and an element count";
    std::string stack_name = name_supply_->FreshName("stack");
    const StringImmNode *type_node = op->args[0].as<StringImmNode>();
    ICHECK(type_node != nullptr)
        << "tvm_stack_alloca expects first argument as a string kind";
    const std::string type = type_node->value;
    const IntImmNode *num = op->args[1].as<IntImmNode>();
    ICHECK(num != nullptr);
    // A negative count would wrap around in the unsigned arithmetic below.
    ICHECK_GE(num->value, 0);
    static_assert(alignof(TVMFFIAny) % alignof(DLTensor) == 0, "invariant");
    // Storage is counted in whole TVMFFIAny slots, rounding up.
    size_t unit = sizeof(TVMFFIAny);
    size_t size = 0;
    if (type == "shape") {
      size = (num->value * sizeof(ffi::Shape::index_type) + unit - 1) / unit;
    } else if (type == "tvm_ffi_any") {
      size = (num->value * sizeof(TVMFFIAny) + unit - 1) / unit;
    } else if (type == "array") {
      size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
    } else {
      LOG(FATAL) << "Unknown stack alloca type " << type;
    }
    this->PrintIndent();
    this->stream << "TVMFFIAny " << stack_name << "[" << size << "];\n";
    os << stack_name;
  } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
    this->PrintCallPacked(op);
  } else if (op->op.same_as(builtin::tvm_call_cpacked_lowered())) {
    this->PrintCallPacked(op);
  } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
    this->PrintIndent();
    this->stream << "return -1;\n";
  } else {
    tvm::codegen::CodeGenC::VisitExpr_(op, os);
  }
}

/*!
 * \brief Emit a runtime check for an AssertStmt, then print its body.
 *
 * When asserts are enabled the generated code raises a "RuntimeError" via
 * TVMFFIErrorSetRaisedFromCStr and returns -1 if the condition is false. For
 * equality conditions the message is extended with both operand values.
 * The body is printed regardless of emit_asserts_.
 */
void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
  if (emit_asserts_) {
    std::string cond = PrintExpr(op->condition);
    PrintIndent();
    stream << "if (!(" << cond << ")) {\n";
    int assert_if_scope = this->BeginScope();
    {
      // Prepare the base error message
      const auto *msg_node = op->message.as<tvm::tir::StringImmNode>();
      ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm";
      const std::string &raw_msg = msg_node->value;
      // Escaped so the text can be embedded inside a C string literal.
      const std::string esc_msg = tvm::support::StrEscape(
          raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
          /*escape_whitespace_special_chars=*/true);

      // If the assertion is an equality check, append the actual LHS/RHS values
      if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
        std::string lhs = PrintExpr(eq->a);
        std::string rhs = PrintExpr(eq->b);
        PrintIndent();
        stream << "char __tvm_assert_msg_buf[512];\n";
        PrintIndent();
        // NOTE(review): both operands are cast to long long whatever their
        // dtype is — confirm this is acceptable for non-integer comparisons.
        stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
                  "got: %lld\", \""
               << esc_msg << "\", (long long)(" << lhs << "), (long long)("
               << rhs << "));\n";
        PrintIndent();
        stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
                  "__tvm_assert_msg_buf);\n";
      } else {
        PrintIndent();
        stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \""
               << esc_msg << "\");\n";
      }
    }
    PrintIndent();
    stream << "return -1;\n";
    this->EndScope(assert_if_scope);
    PrintIndent();
    stream << "}\n";
  }
  this->PrintStmt(op->body);
}

/*! \brief Print min(a, b) as a ternary built on the `<` comparison. */
void CodeGenCHost::VisitExpr_(const tvm::tir::MinNode *op,
                              std::ostream &os) { // NOLINT(*)
  PrintTernaryCondExpr(op, "<", os);
}

/*! \brief Print max(a, b) as a ternary built on the `>` comparison. */
void CodeGenCHost::VisitExpr_(const tvm::tir::MaxNode *op,
                              std::ostream &os) { // NOLINT(*)
  PrintTernaryCondExpr(op, ">", os);
}

/*!
 * \brief Print `((a) cmp (b) ? (a) : (b))` for a binary node.
 *
 * Each operand is first rendered to a string and passed through SSAGetID, and
 * the resulting identifier is what appears (twice) in the ternary.
 *
 * \tparam T Node type exposing operands `a` and `b`.
 * \param op The binary node.
 * \param compare Comparison operator text placed between the operands.
 * \param os Destination stream.
 */
template <typename T>
inline void CodeGenCHost::PrintTernaryCondExpr(const T *op, const char *compare,
                                               std::ostream &os) { // NOLINT(*)
  std::ostringstream temp_a;
  VisitExpr(op->a, temp_a);
  std::string a_id = SSAGetID(temp_a.str(), op->a.dtype());
  std::ostringstream temp_b;
  VisitExpr(op->b, temp_b);
  std::string b_id = SSAGetID(temp_b.str(), op->b.dtype());

  os << "((" << a_id << ") " << compare << " (" << b_id << ") "
     << "? (" << a_id << ") : (" << b_id << "))";
}

} // namespace tl
} // namespace tvm

namespace tvm {
namespace tl {

using tvm::codegen::CodeGenSourceBase;
using tvm::codegen::CSourceModuleCreate;
using tvm::ffi::Array;
using tvm::ffi::Map;
using tvm::ffi::Module;
using tvm::ffi::String;

// Build function that mirrors TVM's host C codegen, registered under a
// TileLang-specific name.
/*!
 * \brief Generate a C source module for every PrimFunc in \p mod.
 *
 * Functions are processed in a deterministic order: sorted by the
 * "runner_function" flag first (false before true) and by name hint second.
 * All functions are declared before any of them is defined.
 *
 * \param mod Module to compile; every function must be a PrimFunc.
 * \param target Target whose string form is recorded in the output and whose
 *        "constants-byte-alignment" attribute (default 16) is honoured.
 * \return A C source module exporting the generated function names.
 */
::tvm::ffi::Module BuildTileLangCHost(::tvm::IRModule mod,
                                      ::tvm::Target target) {
  using DeviceContextMap =
      ::tvm::ffi::Map<::tvm::GlobalVar, ::tvm::ffi::String>;

  bool output_ssa = false;
  bool emit_asserts = true;
  bool emit_fwd_func_decl = true;

  std::unordered_set<std::string> devices;
  // Look the attribute up once and reuse the result.
  const auto device_contexts_attr =
      mod->GetAttr<DeviceContextMap>("device_contexts");
  if (device_contexts_attr != nullptr) {
    DeviceContextMap device_contexts = device_contexts_attr.value();
    for (auto const &context : device_contexts) {
      devices.insert(context.second.data());
    }
  }

  CodeGenCHost cg;
  cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices);
  cg.SetConstantsByteAlignment(
      target->GetAttr<::tvm::Integer>("constants-byte-alignment").value_or(16));

  auto is_aot_executor_fn = [](::tvm::tir::PrimFunc const &func) -> bool {
    return func->GetAttr<::tvm::Bool>("runner_function", ::tvm::Bool(false))
        .value();
  };

  std::vector<std::pair<::tvm::GlobalVar, ::tvm::tir::PrimFunc>> funcs;
  for (const auto &[gvar, base_func] : mod->functions) {
    ICHECK(base_func->IsInstance<::tvm::tir::PrimFuncNode>())
        << "CodegenCHost: Can only take PrimFunc";
    auto prim_func = ::tvm::Downcast<::tvm::tir::PrimFunc>(base_func);
    funcs.push_back({gvar, prim_func});
  }

  // Sort so the emitted source does not depend on map iteration order.
  auto sort_key = [&is_aot_executor_fn](const auto &kv) {
    return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
  };
  std::sort(funcs.begin(), funcs.end(),
            [&sort_key](const auto &kv_a, const auto &kv_b) {
              return sort_key(kv_a) < sort_key(kv_b);
            });

  for (const auto &[gvar, prim_func] : funcs) {
    cg.DeclareFunction(gvar, prim_func);
  }
  for (const auto &[gvar, prim_func] : funcs) {
    cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
  }

  std::string code = cg.Finish();
  return ::tvm::codegen::CSourceModuleCreate(code, "c", cg.GetFunctionNames());
}

// Register the builder with the global function registry.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("target.build.tilelang_c", BuildTileLangCHost);
}

} // namespace tl
} // namespace tvm