#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>

#include <string>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace codegen {

/*!
 * \brief Collect per-kernel runtime metadata for every PrimFunc in \p mod.
 *
 * For each function the argument dtypes are recorded in parameter order.
 * A handle parameter whose pointer type annotation has storage scope
 * "grid_constant" is recorded as DataType(kDLGridConstant, 64, 1) instead of
 * its plain dtype. Launch-parameter tags are copied from the
 * tir::attr::kKernelLaunchParams attribute when that attribute is present.
 *
 * \param mod Module whose functions must all be PrimFuncs (checked).
 * \return Map from each function's global symbol to its FunctionInfo.
 */
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;

  for (const auto &kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    auto f = Downcast<tir::PrimFunc>(kv.second);

    runtime::FunctionInfo info;
    for (const auto &param : f->params) {
      if (param->dtype.is_handle()) {
        auto ptr = param->type_annotation.as<PointerTypeNode>();
        if (ptr && ptr->storage_scope == "grid_constant") {
          info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1));
          continue;
        }
      }
      info.arg_types.push_back(param.dtype());
    }
    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }

    // Report which function lacks a global symbol before dereferencing the
    // Optional; the symbol is the key the runtime module uses for lookup.
    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
    ICHECK(global_symbol) << "ExtractFuncInfo: PrimFunc "
                          << kv.first->name_hint << " has no "
                          << tvm::attr::kGlobalSymbol << " attribute";
    fmap[static_cast<std::string>(global_symbol.value())] = std::move(info);
  }
  return fmap;
}

/*!
 * \brief Generate CUDA source for every PrimFunc in \p mod.
 *
 * Every function must be a PrimFunc using CallingConv::kDeviceKernelLaunch
 * (both checked). If a global function named
 * "tilelang_callback_cuda_postproc" is registered, it is called with
 * (code, target) and its string result replaces the generated source.
 *
 * Shared by BuildTileLangCUDA and BuildTileLangCUDAWithoutCompile so both
 * produce identical source.
 */
static std::string GenerateTileLangCUDASource(const IRModule &mod,
                                              const Target &target) {
  bool output_ssa = false;
  CodeGenTileLangCUDA cg;
  cg.Init(output_ssa);

  for (const auto &kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto gvar = Downcast<GlobalVar>(kv.first);
    auto f = Downcast<tir::PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenTileLangCUDA: expect calling_conv equals "
           "CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(gvar, f);
  }

  std::string code = cg.Finish();
  if (const auto f =
          ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) {
    code = (*f)(code, target).cast<std::string>();
  }
  return code;
}

/*!
 * \brief Lower \p mod to CUDA, compile it and wrap it as a CUDA runtime module.
 *
 * Compilation is delegated to the global function
 * "tilelang_callback_cuda_compile", which must be registered. Its result is
 * passed to CUDAModuleCreate as "ptx" when the first character is '/', and as
 * "cubin" otherwise.
 *
 * \param mod Module of device-kernel PrimFuncs.
 * \param target Compilation target forwarded to the callbacks.
 * \return CUDA runtime module holding the compiled image and the source.
 */
runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  std::string code = GenerateTileLangCUDASource(mod, target);

  std::string fmt = "ptx";
  std::string ptx;
  if (const auto f =
          ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) {
    ptx = (*f)(code, target).cast<std::string>();
    // Indexing an empty string would read the terminating '\0' and silently
    // classify the result as a cubin; reject it here instead.
    ICHECK(!ptx.empty())
        << "BuildTileLangCUDA: tilelang_callback_cuda_compile returned an "
           "empty result";
    // Presumably textual PTX begins with a "//" header comment, so any other
    // leading byte is taken to mean a binary cubin image. TODO confirm
    // against the compile callback.
    if (ptx[0] != '/') {
      fmt = "cubin";
    }
  } else {
    ICHECK(false) << "BuildTileLangCUDA: global function "
                     "tilelang_callback_cuda_compile is not registered";
  }
  return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}

/*!
 * \brief Lower \p mod to CUDA source and wrap it in a CUDA runtime module
 * without compiling it.
 *
 * The literal string "ptx" is passed as the module's data with format "ptx".
 * No real device image is produced here. NOTE(review): this looks like a
 * placeholder for source-only consumers; confirm that nothing tries to load
 * it on a device.
 */
runtime::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
  std::string code = GenerateTileLangCUDASource(mod, target);
  return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("target.build.tilelang_cuda", BuildTileLangCUDA)
      .def("target.build.tilelang_cuda_without_compile",
           BuildTileLangCUDAWithoutCompile);
});

} // namespace codegen
} // namespace tvm