#if defined(__linux__) #include #include #endif #include #include #include "codegen_hip.h" #include "runtime/rocm/rocm_module.h" #include #ifndef kTVMGridConstant #define kTVMGridConstant 130 #endif namespace tvm { namespace codegen { static std::unordered_map ExtractFuncInfo(const IRModule &mod) { std::unordered_map fmap; for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); runtime::FunctionInfo info; for (size_t i = 0; i < f->params.size(); ++i) { if (f->params[i]->dtype.is_handle()) { auto ptr = f->params[i]->type_annotation.as(); if (ptr && ptr->storage_scope == "grid_constant") { info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1)); continue; } } info.arg_types.push_back(f->params[i].dtype()); } if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { for (const auto &tag : opt.value()) { info.launch_param_tags.push_back(tag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol.value())] = info; } return fmap; } runtime::Module BuildTileLangHIP(IRModule mod, Target target) { bool output_ssa = false; CodeGenTileLangHIP cg; cg.Init(output_ssa); for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenTileLangHIP: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); cg.AddFunction(f); } std::string code = cg.Finish(); // Use the new FFI API to get registered functions using ffi::Function; if (auto f = Function::GetGlobal("tilelang_callback_hip_postproc")) { code = (*f)(code, target).cast(); } std::string fmt = "ptx"; std::string ptx; if (auto f = Function::GetGlobal("tilelang_callback_hip_compile")) { ptx = (*f)(code, target).cast(); if (ptx[0] != '/') fmt = "hsaco"; } else { ICHECK(false) << "tilelang_callback_hip_compile is not set"; } return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string()); } runtime::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) { bool output_ssa = false; CodeGenTileLangHIP cg; cg.Init(output_ssa); for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenTileLangHIP: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch); cg.AddFunction(f); } std::string code = cg.Finish(); // Use the new FFI API to get registered functions using ffi::Function; if (auto f = Function::GetGlobal("tilelang_callback_hip_postproc")) { code = (*f)(code, target).cast(); } return ROCMModuleCreate("ptx", "fmt", ExtractFuncInfo(mod), code, std::string()); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.tilelang_hip", BuildTileLangHIP) .def("target.build.tilelang_hip_without_compile", BuildTileLangHIPWithoutCompile); }); } // namespace codegen } // namespace tvm