// NOTE(review): the reviewed text of this file had lost every angle-bracketed
// span (the system include target and all template argument lists). They are
// restored below from the TVM API; confirm them against the build.
#include "codegen_cutedsl.h"
#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>

#include <string>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace codegen {

/*!
 * \brief Collect per-function launch metadata for every function in a module.
 *
 * For each PrimFunc this records the argument types (one entry per parameter,
 * in parameter order) and, when the kKernelLaunchParams attribute is present,
 * its launch-parameter tags. Handle parameters whose pointer type annotation
 * has storage scope "grid_constant" are recorded as a 64-bit, single-lane
 * kDLGridConstant type instead of their own dtype.
 *
 * \param mod Module to scan; every function in it must be a PrimFunc.
 * \return Map from each function's global symbol to its FunctionInfo.
 */
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;

  for (const auto &kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    const auto f = Downcast<tir::PrimFunc>(kv.second);

    runtime::FunctionInfo info;
    for (const auto &param : f->params) {
      if (param->dtype.is_handle()) {
        // Only pointer-annotated handles can carry the grid_constant scope;
        // `as` yields nullptr for any other annotation.
        const auto *ptr = param->type_annotation.as<PointerTypeNode>();
        if (ptr && ptr->storage_scope == "grid_constant") {
          info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1));
          continue;
        }
      }
      info.arg_types.push_back(param.dtype());
    }

    if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
            tir::attr::kKernelLaunchParams)) {
      for (const auto &tag : opt.value()) {
        info.launch_param_tags.push_back(tag);
      }
    }

    // The global symbol is the map key, so fail with a clear message instead
    // of dereferencing an empty optional.
    const auto global_symbol =
        f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
    ICHECK(global_symbol.has_value())
        << "ExtractFuncInfo: PrimFunc is missing the global_symbol attribute";
    fmap[static_cast<std::string>(global_symbol.value())] = std::move(info);
  }
  return fmap;
}

/*!
 * \brief Generate CuTeDSL source for a module and wrap it in a CUDA module
 *        without invoking a compiler.
 *
 * Every function must be a PrimFunc whose calling convention is
 * CallingConv::kDeviceKernelLaunch. If a global function named
 * "tilelang_callback_cutedsl_postproc" is registered, the generated code is
 * passed through it (together with the target) and replaced by its result.
 *
 * \param mod    Module holding the device kernels to emit.
 * \param target Build target; only forwarded to the optional post-processing
 *               callback.
 * \return Module created by runtime::CUDAModuleCreate. The first two arguments
 *         are the literal "ptx" strings and the generated code is passed as
 *         the last argument.
 */
ffi::Module BuildTileLangCuTeDSLWithoutCompile(IRModule mod, Target target) {
  CodeGenTileLangCuTeDSL cg;

  for (const auto &kv : mod->functions) {
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "CodeGenTileLangCuTeDSL: Can only take PrimFunc";
    const auto gvar = Downcast<GlobalVar>(kv.first);
    const auto f = Downcast<tir::PrimFunc>(kv.second);
    const auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenTileLangCuTeDSL: expect calling_conv equals "
           "CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(gvar, f);
  }

  std::string code = cg.Finish();

  // Optional user hook that may rewrite the generated source.
  if (const auto postproc =
          ffi::Function::GetGlobal("tilelang_callback_cutedsl_postproc")) {
    code = (*postproc)(code, target).cast<std::string>();
  }

  return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod),
                                   std::move(code));
}

// Register the builder under its global FFI name at static-initialization time.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("target.build.tilelang_cutedsl_without_compile",
                        BuildTileLangCuTeDSLWithoutCompile);
}

} // namespace codegen
} // namespace tvm