Unverified Commit 7913fb1d authored by Chaofan Lin, committed by GitHub

[Bugfix] Fix dummy kernel compilation (#962)



* [Bugfix] Fix visit EvaluateNode in BufferGemmCollector

* address comment

* lint

* fix

* Add TileLang SplitHostDevice pass and tighten issue 830 test names

* lint fix

* Enhance kernel value unpacking.

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 6031416f
@@ -115,7 +115,12 @@ public:
 private:
   void VisitStmt_(const EvaluateNode *op) {
-    auto call = Downcast<Call>(op->value);
+    const CallNode *call_node = op->value.as<CallNode>();
+    // The value of an EvaluateNode may not be a call.
+    if (!call_node) {
+      return;
+    }
+    auto call = GetRef<Call>(call_node);
     if (call->op.same_as(Gemm::Get())) {
       auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
       ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
......
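For context on the guard added above: the value held by a `tir.Evaluate` statement need not be a `Call`, so the unconditional downcast could fail on evaluated plain expressions. A minimal sketch of the failing precondition (standard TVM Python API; not part of this commit):

```python
from tvm import tir

# An Evaluate statement may wrap a bare expression rather than a Call.
ev = tir.Evaluate(tir.IntImm("int32", 0))
assert not isinstance(ev.value, tir.Call)  # the unguarded downcast would ICHECK here
```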
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file split_host_device.cc
* \brief Split device function from host.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "tir/analysis/var_use_def_analysis.h"
namespace tvm {
namespace tl {

namespace tir = tvm::tir;

class HostDeviceSplitter : public tir::StmtMutator {
public:
  explicit HostDeviceSplitter(IRModule *device_mod,
                              std::function<GlobalVar()> var_supply)
      : device_mod_(device_mod), var_supply_(std::move(var_supply)) {}

  tir::Stmt VisitStmt_(const tir::AttrStmtNode *op) final {
    if (op->attr_key == tvm::attr::kTarget) {
      found_device_region_ = true;
      auto device_target = op->node.as<tvm::Target>().value().WithoutHost();
      return SplitDeviceFunc(op->body, device_target);
    }
    return tir::StmtMutator::VisitStmt_(op);
  }

  tir::Stmt ForceSplit(tir::Stmt body, tvm::Target device_target) {
    return SplitDeviceFunc(std::move(body), std::move(device_target));
  }

  bool found_device_region() const { return found_device_region_; }

private:
  bool found_device_region_{false};

  tir::Stmt SplitDeviceFunc(tir::Stmt body, tvm::Target device_target) {
    auto [params, buffers_to_declare] =
        [&]() -> std::tuple<Array<tir::Var>, Array<tir::Buffer>> {
      tir::VarUseDefAnalyzer use_def(/*defined_vars=*/{},
                                     /*visit_thread_extent=*/true);
      use_def(body);

      // Sort first by variable type, then by variable name.
      std::vector<tir::Var> params{use_def.undefined_.begin(),
                                   use_def.undefined_.end()};
      std::sort(params.begin(), params.end(),
                [](const tir::Var &a, const tir::Var &b) {
                  auto sort_key = [](const tir::Var &var) {
                    return std::tuple{
                        !var->dtype.is_handle(),
                        var->name_hint,
                    };
                  };
                  return sort_key(a) < sort_key(b);
                });
      return {params, use_def.undefined_buffers_};
    }();

    // CodeGenCPU is used for some device-side targets, such as
    // "ext_dev", and expects to be able to return an int32_t status
    // code.
    bool can_propagate_errors = [&]() {
      auto kind = device_target->GetTargetDeviceType();
      return kind == kDLCPU || kind == kDLExtDev || kind == kDLHexagon;
    }();

    IntImm success(DataType::Int(32), 0);
    Type kernel_ret_type;
    if (can_propagate_errors) {
      kernel_ret_type = PrimType(DataType::Int(32));
      body = tir::SeqStmt::Flatten(body, tir::Evaluate(ret(success)));
    } else {
      kernel_ret_type = VoidType();
    }

    for (tir::Buffer buf : buffers_to_declare) {
      body = tir::DeclBuffer(buf, std::move(body));
    }

    tir::PrimFunc device_func(params, body, kernel_ret_type);
    device_func =
        WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target},
                                           {tir::attr::kNoAlias, true},
                                           {tir::attr::kIsGlobalFunc, true}});

    GlobalVar kernel_symbol_global = var_supply_();
    (*device_mod_)->Add(kernel_symbol_global, device_func);
    Array<PrimExpr> args =
        params.Map([](const tir::Var &var) -> PrimExpr { return var; });
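
    // When the target can propagate errors, bind the kernel's return value
    // with a LetStmt and assert that it equals the success code; otherwise
    // emit a plain void call to the kernel.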
    if (can_propagate_errors) {
      tir::Var kernel_error_code("kernel_error_code", success->dtype);
      tir::Call kernel_call(success->dtype, kernel_symbol_global, args);
      tir::AssertStmt assert_success(
          kernel_error_code == success,
          tir::StringImm("Error executing compute kernel"), tir::Evaluate(0));
      tir::LetStmt let_check(kernel_error_code, kernel_call, assert_success);
      return let_check;
    } else {
      return tir::Evaluate(
          tir::Call(DataType::Void(), kernel_symbol_global, args));
    }
  }

  // Target IR module that receives the split-out device functions.
  IRModule *device_mod_;
  // Generates a fresh GlobalVar for each kernel.
  std::function<GlobalVar()> var_supply_;
};

tir::PrimFunc SplitHostDevice(tir::PrimFunc func, IRModule *device_mod,
                              std::function<GlobalVar()> var_supply) {
  HostDeviceSplitter splitter(device_mod, std::move(var_supply));
  if (auto body = splitter(func->body); !body.same_as(func->body)) {
    func.CopyOnWrite()->body = body;
  } else if (!splitter.found_device_region()) {
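    // No kTarget-annotated device region was found. If this is a no-op entry
    // function (e.g. an empty kernel), force a split so that a device kernel
    // is still emitted.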
    if (auto target = func->GetAttr<Target>(tvm::attr::kTarget)) {
      auto device_target = target.value().WithoutHost();
      if (device_target.defined() &&
          func->HasNonzeroAttr(tir::attr::kIsEntryFunc) &&
          tir::is_no_op(func->body)) {
        if (auto forced = splitter.ForceSplit(func->body, device_target);
            !forced.same_as(func->body)) {
          func.CopyOnWrite()->body = forced;
        }
      }
    }
  }
  return func;
}

namespace transform {

tvm::transform::Pass SplitHostDevice() {
  auto pass_func = [](IRModule mod, tvm::transform::PassContext ctx) {
    tvm::GlobalVarSupply global_var_supply(mod);
    IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
    IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));

    for (const auto &[gvar, base_func] : mod->functions) {
      if (auto opt = base_func.as<tir::PrimFunc>()) {
        tir::PrimFunc func = opt.value();
        auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
        auto name_prefix = global_symbol.value_or(gvar->name_hint);
        auto kernel_name = name_prefix + "_kernel";
        auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar {
          return global_var_supply->FreshGlobal(kernel_name, false);
        };
        func = ::tvm::tl::SplitHostDevice(std::move(func), &device_mod,
                                          var_supply);
        if (!func.same_as(base_func)) {
          updates->Add(gvar, func);
        }
      }
    }

    mod->Update(updates);
    mod->Update(device_mod);
    return tir::transform::ConvertSSA()(mod);
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "tl.SplitHostDevice",
                                          {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.SplitHostDevice", SplitHostDevice);
});

} // namespace transform
} // namespace tl
} // namespace tvm
# ruff: noqa
import torch
import tilelang
import tilelang.testing
import tilelang.language as T


@tilelang.jit
def _empty_kernel():

    @T.prim_func
    def empty_kernel():
        with T.Kernel(1, threads=32) as thread_idx:
            pass

    return empty_kernel


def test_empty_kernel_lowering():
    kernel = _empty_kernel()
    kernel()


@tilelang.jit
def _empty_with_dead_code_kernel():
    num_tokens = T.symbolic("num_tokens")

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]):
        with T.Kernel(num_tokens, threads=32) as pid:
            y = x[pid]

    return buggy_kernel


@tilelang.testing.requires_cuda
def test_empty_with_dead_code_kernel():
    kernel = _empty_with_dead_code_kernel()
    x = torch.randn((128,), dtype=torch.float32, device="cuda")
    kernel(x)


@tilelang.jit
def _empty_kernel_with_binding_variants(use_tuple_binding: bool = False):

    @T.prim_func
    def kernel_with_tuple_kernel_binding():
        with T.Kernel(1, threads=32) as (pid,):
            print(pid)

    @T.prim_func
    def kernel_with_scalar_kernel_binding():
        with T.Kernel(1, threads=32) as pid:
            print(pid)

    return kernel_with_tuple_kernel_binding if use_tuple_binding else kernel_with_scalar_kernel_binding


def test_empty_kernel_with_binding_variants():
    kernel = _empty_kernel_with_binding_variants()
    kernel()

    tuple_kernel = _empty_kernel_with_binding_variants(use_tuple_binding=True)
    tuple_kernel()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -193,7 +193,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     if allow_global_thread_synchronization():
         mod = tilelang.transform.ThreadSync("global")(mod)
     mod = tilelang.transform.AnnotateDeviceRegions()(mod)
-    mod = tir.transform.SplitHostDevice()(mod)
+    mod = tilelang.transform.SplitHostDevice()(mod)
     # MergeSharedMemoryAllocations must be applied after SplitHostDevice
     # because the merged allocation site is at the beginning of each device function
     enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
......
@@ -9,6 +9,18 @@ from tvm.ffi import register_object
 from tilelang import _ffi_api
 import threading
 
+# Ensure single-dimension kernel bindings can be unpacked like iterables;
+# see https://github.com/tile-ai/tilelang/issues/830.
+if not hasattr(Var, "__iter__"):
+
+    def _var_iter(self):
+        yield self
+
+    Var.__iter__ = _var_iter  # type: ignore[attr-defined]
+
+if not hasattr(Var, "__len__"):
+    Var.__len__ = lambda self: 1  # type: ignore[attr-defined]
+
 
 class FrameStack:
     """
@@ -68,6 +80,17 @@ def _get_current_stack() -> FrameStack:
     return _local.kernel_launch_frame_stack
 
 
+def _normalize_bindings(bindings: List[Var]) -> Union[Var, List[Var]]:
+    """
+    Return a bare Var when we only have a single binding, so that users may
+    write either `with T.Kernel(...) as pid:` or `with T.Kernel(...) as (pid,)`.
+    Otherwise, keep the list semantics for multi-dimensional launches.
+    """
+    if len(bindings) == 1:
+        return bindings[0]
+    return bindings
+
+
 @register_object("tl.KernelLaunchFrame")
 class KernelLaunchFrame(TIRFrame):
     """
@@ -83,9 +106,6 @@ class KernelLaunchFrame(TIRFrame):
         """
         super().__enter__()
         _get_current_stack().push(self)
-        # If we have exactly 5 frames, return the single iter_var.var.
-        if len(self.frames) == 5:
-            return self.frames[0].iter_var.var
         last_block_frame = self.frames[-1]
         assert isinstance(last_block_frame,
@@ -95,11 +115,11 @@
         if maybe_cpu:
             # CPU kernel frame: return the for-frame items.
-            return [frame.vars[0] for frame in self.frames[0:-1]]
+            return _normalize_bindings([frame.vars[0] for frame in self.frames[0:-1]])
         else:
             # Otherwise, return the iter_var.var objects, excluding the last 4 frames
             # (threadIdx.x, threadIdx.y, threadIdx.z, and the block frame with attributes).
-            return [frame.iter_var.var for frame in self.frames[0:-4]]
+            return _normalize_bindings([frame.iter_var.var for frame in self.frames[0:-4]])
 
     def __exit__(self, ptype, value, trace):
         """
@@ -234,6 +254,31 @@ def Kernel(
     -------
     res : Tuple[frame.LaunchThreadFrame]
         The result LaunchThreadFrame.
 
+    Examples
+    --------
+    Create a 1-D CUDA kernel launch and unpack the single block index:
+
+    .. code-block:: python
+
+        with T.Kernel(T.ceildiv(N, 128), threads=128) as bx:
+            # bx is the blockIdx.x binding (also iterable as (bx,))
+            ...
+
+    Launch a 2-D grid while requesting two thread dimensions:
+
+    .. code-block:: python
+
+        with T.Kernel(grid_x, grid_y, threads=(64, 2)) as (bx, by):
+            tx, ty = T.get_thread_bindings()
+            ...
+
+    Emit a CPU kernel where thread bindings are skipped:
+
+    .. code-block:: python
+
+        with T.Kernel(loop_extent, is_cpu=True) as (i,):
+            ...
     """
 
     attrs: dict = {}
......
@@ -282,6 +282,17 @@ def AnnotateDeviceRegions():
     return _ffi_api.AnnotateDeviceRegions()  # type: ignore
 
 
+def SplitHostDevice():
+    """Split host/device functions, even for empty kernels.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.SplitHostDevice()  # type: ignore
+
+
 def VectorizeLoop(enable_vectorize: bool = True):
     """VectorizeLoop
......
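Downstream, the new pass slots into the lowering pipeline exactly as in the `OptimizeForTarget` change above. A usage sketch, where `mod` stands for an `IRModule` produced by earlier lowering stages (an assumption, not a runnable snippet on its own):

```python
import tilelang

mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tilelang.transform.SplitHostDevice()(mod)  # splits even empty kernels
```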