/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file lower_device_kernel_launch.cc * \brief Split device function from host. */ #include #include #include #include #include #include #include #include "runtime/thread_storage_scope.h" #include "tir/transforms/ir_utils.h" namespace tvm { namespace tl { using namespace tir; namespace { struct KernelInfo { // The device on which the PrimFunc runs Target target; // The externally visible symbol which may refer to the PrimFunc // when launching a device kernel. String global_symbol; // The parameters accepted by the PrimFunc. Used to rewrite // `launch_args` to be in terms of the calling scope. Array params; // The launch parameters that should annotate the PrimFunc, if the // kernel is ever called from the host. Array launch_params; // Additional arguments which must be provided to the host-side // PackedFunc. These may be in terms of the function's parameters // (e.g. a function that computes the average of `N` elements, and // which must be launched with `N` CUDA threads). Array launch_args; // The extent of each thread Map thread_extent; // The amount of dynamic shared memory used Optional dyn_shmem_size{std::nullopt}; }; /*! * \brief Visitor class to collect device-side program information. */ class DeviceInfoCollector : public StmtVisitor { public: static KernelInfo Collect(const GlobalVar &gvar, const PrimFunc &func) { DeviceInfoCollector collector; collector.info_.target = func->GetAttr(tvm::attr::kTarget).value().WithoutHost(); collector.info_.params = func->params; collector(func->body); // The dynamic shared memory is required to be the last of the // kernel launch parameters if (collector.dyn_shmem_size) { collector.info_.launch_params.push_back( tvm::runtime::launch_param::kUseDynamicSharedMemoryTag); } collector.info_.global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol) .value_or(gvar->name_hint); collector.info_.launch_args = collector.info_.launch_params.Map( [&](const auto ¶m) { return collector.GetArgument(param); }); collector.info_.dyn_shmem_size = collector.dyn_shmem_size; collector.info_.thread_extent = collector.thread_extent; return collector.info_; } private: PrimExpr GetArgument(const String &launch_param) const { if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { CHECK(dyn_shmem_size.defined()) << "Compute kernel requires launch parameter \"" << launch_param << "\", but PrimFunc did not contain Allocate node with shared " "dynamic scope."; return dyn_shmem_size.value(); } auto extent = thread_extent.Get(launch_param); CHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param << "\", but PrimFunc does not contain AttrStmt \"" << tir::attr::thread_extent << "\" defining this thread extent"; return extent.value(); } void VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == tir::attr::thread_extent) { IterVar iv = Downcast(op->node); ICHECK_NE(iv->thread_tag.length(), 0U); // thread_extent can appear multiple times // use the first appearance as def. if (!defined_thread.count(iv.get())) { defined_thread.insert(iv.get()); info_.launch_params.push_back(iv->thread_tag); thread_extent.Set(iv->thread_tag, op->value); } } StmtVisitor::VisitStmt_(op); } void VisitStmt_(const AllocateNode *op) final { auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { ICHECK(!dyn_shmem_size.defined()) << "Only one dynamic shared memory allocation is allowed."; ICHECK_GT(op->extents.size(), 0); PrimExpr dyn_size = Integer(1); for (const auto &extent : op->extents) { dyn_size *= extent; } dyn_size *= op->dtype.bytes() * op->dtype.lanes(); dyn_shmem_size = dyn_size; } StmtVisitor::VisitStmt_(op); } // The collected results KernelInfo info_; // recording what thread axis have been visited. std::unordered_set defined_thread; // The extent of each thread Map thread_extent; // The amount of dynamic shared memory used Optional dyn_shmem_size{std::nullopt}; }; class ReturnRemover : public StmtExprMutator { public: static Stmt Apply(const Stmt &stmt) { ReturnRemover mutator; return mutator(stmt); } private: using Parent = StmtExprMutator; Stmt VisitStmt_(const EvaluateNode *op) override { if (auto *call = op->value.as()) { if (call->op.same_as(builtin::ret())) { ICHECK_EQ(call->args.size(), 1); auto as_int = call->args[0].as(); ICHECK(as_int && as_int->value == 0) << "Device kernel may only contain successful return, T.ret(0)"; return Evaluate(0); } } return Parent::VisitStmt_(op); } PrimExpr VisitExpr_(const CallNode *op) override { if (op->op.same_as(builtin::ret())) { LOG(FATAL) << "Call to builtin::ret() should only appear within an " "Evaluate node"; } return Parent::VisitExpr_(op); } }; } // namespace class DeviceKernelMutator : public StmtExprMutator { public: using Parent = StmtExprMutator; explicit DeviceKernelMutator( std::unordered_map device_info_map) : device_info_map_(std::move(device_info_map)) {} PrimFunc RewriteKernelLaunchSite(const GlobalVar &gvar, PrimFunc func) { ICHECK(!current_target_.defined()); auto it = device_info_map_.find(gvar.get()); ICHECK(it != device_info_map_.end()); current_target_ = it->second.target; auto body = VisitStmt(func->body); if (!body.same_as(func->body)) { func.CopyOnWrite()->body = body; } current_target_ = std::nullopt; return func; } PrimFunc UpdateKernelAttributes(const GlobalVar &gvar, PrimFunc func) const { bool is_kernel_launch = device_kernel_launch_.count(gvar.get()); bool is_call_extern = extern_function_call_.count(gvar.get()); CHECK(!is_kernel_launch || !is_call_extern) << "Function " << gvar << " has multiple callees, " << "and would need to be lowered into a call_extern at some call " "sites, " << "and a device kernel launch at others. " << "This case is not yet supported."; if (is_kernel_launch || is_call_extern) { func = WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, Bool(true)); } if (is_kernel_launch) { const auto &info = device_info_map_.at(gvar.get()); // Kernel launches provide an int32 error code to the caller, // but do not accept any return type from the callee. { auto write_ptr = func.CopyOnWrite(); write_ptr->ret_type = VoidType(); write_ptr->body = ReturnRemover::Apply(write_ptr->body); } func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, Integer(tvm::CallingConv::kDeviceKernelLaunch)}, {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, {tvm::attr::kGlobalSymbol, info.global_symbol}}); } // @lei: workaround as we may require c host codegen, so we need to set the // global symbol for cpu backend. func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); const auto &info = device_info_map_.at(gvar.get()); const auto &thread_extent = info.thread_extent; func = WithAttr(std::move(func), "thread_extent", thread_extent); if (info.dyn_shmem_size.defined()) { func = WithAttr(std::move(func), "dyn_shared_memory_buf", info.dyn_shmem_size.value()); } return func; } private: PrimExpr VisitExpr_(const CallNode *op) override { auto node = Downcast(Parent::VisitExpr_(op)); auto *gvar = op->op.as(); if (!gvar) return std::move(node); auto it = device_info_map_.find(gvar); ICHECK(it != device_info_map_.end()) << "CallNode attempted subroutine call to " << gvar->name_hint << ", but " << gvar->name_hint << " did not appear within the IRModule"; const KernelInfo &dev_info = it->second; auto caller_target = current_target_.value(); auto callee_target = dev_info.target; bool same_target = caller_target->str() == callee_target->str(); if (same_target) { // Calls within the same target may be handled at codegen time // as internal subroutine calls. return std::move(node); } bool same_device_type = caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); if (same_device_type) { // Calls to another target using the same device (e.g. LLVM // calling a custom TIRToRuntime target) do not require a kernel // launch, but need to be replaced with call_extern. extern_function_call_.insert(gvar); Array args; args.push_back(StringImm(gvar->name_hint)); for (const auto &arg : node->args) { args.push_back(arg); } return Call(node->dtype, builtin::call_extern(), args); } ICHECK(dev_info.launch_params.defined()) << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " << dev_info.target << ", but subroutine " << gvar->name_hint << " did not have the tir::attr::kKernelLaunchParams attribute " << "required for cross-target kernel launch"; // Collected kernel information may be in terms of the callee's // arguments, but we need expressions for them in terms of the // caller's parameters. The param_map allows substitution of // parameter values into the thread extents, to generate // expressions that are valid within the caller. Map param_map = [&]() { Map param_map; CHECK_EQ(node->args.size(), dev_info.params.size()) << "Function " << gvar->name_hint << " accepts " << dev_info.params.size() << " arguments as input, but is called using " << node->args.size() << " arguments"; for (size_t i = 0; i < node->args.size(); i++) { param_map.Set(dev_info.params[i], node->args[i]); } return param_map; }(); device_kernel_launch_.insert(gvar); Array call_args; call_args.push_back(StringImm(dev_info.global_symbol)); for (PrimExpr arg : node->args) { call_args.push_back(arg); } for (const auto &launch_arg : dev_info.launch_args) { call_args.push_back(Substitute(launch_arg, param_map)); } auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype; return Call(dtype, builtin::tvm_call_packed(), call_args); } Optional current_target_; std::unordered_map device_info_map_; std::unordered_set device_kernel_launch_; std::unordered_set extern_function_call_; }; namespace transform { tvm::transform::Pass LowerDeviceKernelLaunch() { auto pass_func = [](IRModule mod, const tir::transform::PassContext &ctx) -> IRModule { auto mutator = [&mod]() { std::unordered_map device_info_map; for (const auto &[gvar, base_func] : mod->functions) { if (auto prim_func = base_func.as()) { device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func.value()); } } return DeviceKernelMutator(std::move(device_info_map)); }(); { IRModule updates; for (const auto &[gvar, base_func] : mod->functions) { if (auto *ptr = base_func.as()) { auto prim_func = mutator.RewriteKernelLaunchSite(gvar, GetRef(ptr)); if (!prim_func.same_as(base_func)) { updates->Add(gvar, prim_func); } } } if (!updates->functions.empty()) { mod.CopyOnWrite()->Update(updates); } } { IRModule updates; for (const auto &[gvar, base_func] : mod->functions) { if (auto *ptr = base_func.as()) { auto prim_func = mutator.UpdateKernelAttributes(gvar, GetRef(ptr)); if (!prim_func.same_as(base_func)) { updates->Add(gvar, prim_func); } } } if (!updates->functions.empty()) { mod.CopyOnWrite()->Update(updates); } } return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "tl.LowerDeviceKernelLaunch", {}); } TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch", LowerDeviceKernelLaunch); }); } // namespace transform } // namespace tl } // namespace tvm