"git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "bae2333791d354756c73ea65e98589cd26e57c94"
Commit 8ccf6ea2 authored by Lei Wang, committed by GitHub
Browse files

[Refactor] Enhance GPU Kernel Launch with Environment Thread Creation (#178)

- Introduce `CreateEnvThread` function to generate environment threads for GPU kernel launches
- Modify `KernelLaunch` to use `CreateEnvThread` for block and thread indices
- Improve thread variable naming with shorter, more descriptive identifiers (bx, by, bz, tx, ty, tz)
- Ensure proper thread environment setup within PrimFunc context
parent 7bde63d5
...@@ -18,6 +18,21 @@ constexpr const char *tilelang_is_cpu_kernel_frame = ...@@ -18,6 +18,21 @@ constexpr const char *tilelang_is_cpu_kernel_frame =
using namespace script::ir_builder::tir; using namespace script::ir_builder::tir;
/*!
 * \brief Create a named thread-index variable and register it on the
 *        PrimFunc frame that is currently being built.
 *
 * An IterVar of type kThreadIndex is constructed around a fresh
 * Var(name, dtype) with a null Range as its domain, then stored in the
 * `env_threads` map of the PrimFuncFrame located through the current
 * IRBuilder, keyed by the variable.
 *
 * \param name       Name given to the underlying Var.
 * \param thread_tag Thread tag attached to the IterVar (callers in this file
 *                   pass tags such as "blockIdx.x" / "threadIdx.x").
 * \param dtype      Data type of the underlying Var.
 * \return The Var held by the newly created IterVar.
 *
 * \note Raises LOG(FATAL) when no PrimFuncFrame is found.
 */
static Var CreateEnvThread(String name, String thread_tag, DataType dtype) {
  using namespace tvm::tir;
  using namespace tvm::script::ir_builder;

  // Domain is passed as a null Range at construction time.
  IterVar thread_iv(Range{nullptr}, Var(name, dtype),
                    tvm::tir::IterVarType::kThreadIndex, thread_tag);
  Var thread_var = thread_iv->var;

  Optional<PrimFuncFrame> enclosing_func =
      IRBuilder::Current()->FindFrame<PrimFuncFrame>();
  if (!enclosing_func) {
    LOG(FATAL) << "EnvThread can only be used inside a PrimFunc";
  }

  // Record the binding on the frame's env_threads map.
  enclosing_func.value()->env_threads.Set(thread_var, thread_iv);
  return thread_var;
}
static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) { static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) {
using namespace tvm::tir; using namespace tvm::tir;
Var var = Var(name); Var var = Var(name);
...@@ -160,19 +175,34 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size, ...@@ -160,19 +175,34 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
// Launch GPU Kernel // Launch GPU Kernel
ICHECK(grid_size.size() <= 3); ICHECK(grid_size.size() <= 3);
if (grid_size.size() > 0) if (grid_size.size() > 0)
n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0])); n->frames.push_back(LaunchThread(
CreateEnvThread("bx", "blockIdx.x", grid_size[0].dtype()),
grid_size[0]));
if (grid_size.size() > 1) if (grid_size.size() > 1)
n->frames.push_back(LaunchThread("blockIdx.y", grid_size[1])); n->frames.push_back(LaunchThread(
CreateEnvThread("by", "blockIdx.y", grid_size[1].dtype()),
grid_size[1]));
if (grid_size.size() > 2) if (grid_size.size() > 2)
n->frames.push_back(LaunchThread("blockIdx.z", grid_size[2])); n->frames.push_back(LaunchThread(
CreateEnvThread("bz", "blockIdx.z", grid_size[2].dtype()),
grid_size[2]));
if (block_size.defined()) { if (block_size.defined()) {
ICHECK(block_size.size() <= 3); ICHECK(block_size.size() <= 3);
if (block_size.size() > 0) if (block_size.size() > 0) {
n->frames.push_back(LaunchThread("threadIdx.x", block_size[0])); n->frames.push_back(LaunchThread(
if (block_size.size() > 1) CreateEnvThread("tx", "threadIdx.x", block_size[0].dtype()),
n->frames.push_back(LaunchThread("threadIdx.y", block_size[1])); block_size[0]));
if (block_size.size() > 2) }
n->frames.push_back(LaunchThread("threadIdx.z", block_size[2])); if (block_size.size() > 1) {
n->frames.push_back(LaunchThread(
CreateEnvThread("ty", "threadIdx.y", block_size[1].dtype()),
block_size[1]));
}
if (block_size.size() > 2) {
n->frames.push_back(LaunchThread(
CreateEnvThread("tz", "threadIdx.z", block_size[2].dtype()),
block_size[2]));
}
} else { } else {
n->frames.push_back(Block("")); n->frames.push_back(Block(""));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment