"tests/L0/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "795a5e5bbac3b3d682c71566b198856c94bf089d"
Unverified Commit b9a51c43 authored by Lei Wang, committed by GitHub

[TMA] Bugfix when a shared buffer is issued with both a TMA store and a TMA load (#857)

- Updated `init_desc_arg_map` to use `Var` as the key instead of `String` in `lower_hopper_intrin.cc`, so descriptors whose `Var`s share a name hint no longer overwrite each other (see the sketch below).
- Enhanced `func_call_args` method in `TLCUDASourceWrapper` to accept additional parameters for better argument mapping.
- Added assertions to ensure consistency between function parameters and arguments during kernel launches.
- Modified `generate_tma_descriptor_args` to look up TMA descriptor initialization arguments through a handle-name-to-`Var` mapping.
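The failure mode this targets: when one shared buffer is read by a TMA load and written by a TMA store, two descriptor `Var`s can end up sharing a name hint, so a `String`-keyed `init_desc_arg_map` keeps only the last entry. Below is a minimal, illustrative sketch of that collision; it is not TileLang source, it uses a plain Python `dict` in place of the C++ `Map`, and the buffer/descriptor names are made up.

```python
# Illustrative sketch only: two descriptor Vars for the same shared buffer.
# Requires a TVM install; the names below are hypothetical.
from tvm import tir

load_desc = tir.Var("A_shared_desc", "handle")   # descriptor used by the TMA load
store_desc = tir.Var("A_shared_desc", "handle")  # descriptor used by the TMA store

# Old behaviour: keyed by name -> the second insert overwrites the first,
# so one descriptor loses its initialization arguments.
by_name = {}
by_name[load_desc.name] = ["init args for the load descriptor"]
by_name[store_desc.name] = ["init args for the store descriptor"]
assert len(by_name) == 1  # collision

# Fixed behaviour: keyed by the Var itself -> both entries survive, because
# the two Vars are distinct objects even though their name hints match.
by_var = {}
by_var[load_desc] = ["init args for the load descriptor"]
by_var[store_desc] = ["init args for the store descriptor"]
assert len(by_var) == 2  # no collision
```

The wrapper-side changes mirror this on the host: `func_call_args` now records which kernel parameter corresponds to each matched descriptor argument (`desc_name_var_map`), and `generate_tma_descriptor_args` looks the initialization arguments up through that mapping instead of reconstructing a `<name>_desc` string.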
parent 058a670b
@@ -25,7 +25,7 @@ public:
PrimFuncNode *fptr = f.CopyOnWrite();
LowerHopperIntrin substituter(disable_shuffle_elect);
fptr->body = substituter.VisitStmt(f->body);
-Map<String, Array<PrimExpr>> init_desc_arg_map;
+Map<Var, Array<PrimExpr>> init_desc_arg_map;
for (const auto &[call, var] : substituter.desc_map_) {
// Should allocate 128 bytes for TensorMap on stack
Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
@@ -46,7 +46,7 @@ public:
Call(DataType::Handle(), builtin::tvm_call_packed(), init_desc_args);
fptr->body =
LetStmt(var, alloc_desc, SeqStmt({Evaluate(init_desc), fptr->body}));
-init_desc_arg_map.Set(var->name_hint, init_desc_args);
+init_desc_arg_map.Set(var, init_desc_args);
}
f = WithAttr(std::move(f), "tma_descriptor_args", init_desc_arg_map);
return f;
@@ -8,6 +8,7 @@ from .utils import (match_declare_kernel, match_declare_kernel_cpu, is_cuda_targ
import re
import logging
import textwrap
+from tvm.tir.stmt_functor import post_order_visit
PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """
cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
@@ -260,7 +261,11 @@ class TLCUDASourceWrapper(object):
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
-def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
+def func_call_args(s,
+                   function_args,
+                   function_params,
+                   desc_name_map: Optional[Dict[str, str]] = None,
+                   desc_name_var_map: Optional[Dict[str, tvm.tir.Var]] = None):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int):
match = matches[i]
@@ -280,8 +285,15 @@ class TLCUDASourceWrapper(object):
call_args = []
for i, match in enumerate(matches):
for arg in function_args:
if arg["name"] == match or maybe_desc(arg["name"], matches, i):
if arg["name"] == match:
call_args.append(match)
+elif maybe_desc(arg["name"], matches, i):
+    call_args.append(match)
+    assert len(call_args) <= len(
+        function_params
+    ), f"Function {function_name} has {len(function_params)} parameters, but {len(call_args)} arguments"
+    desc_name_var_map[match] = function_params[len(call_args) - 1]
return call_args
has_l2_persistent_map = False
@@ -294,10 +306,12 @@ class TLCUDASourceWrapper(object):
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
desc_name_map: Dict[str, str] = {}
+desc_name_var_map: Dict[str, tvm.tir.Var] = {}
for function_name, function_info in function_informations.items():
block_info = function_info["block_info"]
grid_info = function_info["grid_info"]
dynamic_smem_buf = function_info["dynamic_smem_buf"]
+function_params = function_info["function_params"]
# Find the location of the global kernel function in the code
index = match_declare_kernel(code, function_name + "(")
@@ -321,7 +335,11 @@ class TLCUDASourceWrapper(object):
kernel_launch_code += init_l2_persistent_map
if self.use_cooperative_groups[function_name]:
-args_list = func_call_args(declaration, function_args, desc_name_map)
+args_list = func_call_args(declaration, function_args, function_params,
+                           desc_name_map, desc_name_var_map)
+assert len(function_params) == len(
+    args_list
+), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
args_array = [f"(void*)&{arg}" for arg in args_list]
call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
kernel_launch_code += call_args
@@ -329,14 +347,20 @@ class TLCUDASourceWrapper(object):
kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
function_name, grid_str, block_str, function_name + "_args", smem_str)
else:
call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
args_list = func_call_args(declaration, function_args, function_params,
desc_name_map, desc_name_var_map)
assert len(function_params) == len(
args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
call_args = ", ".join(args_list)
kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
function_name, grid_str, block_str, smem_str, call_args)
kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
-init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map)
+init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map,
+                                                              desc_name_var_map)
kernel_launch_code = init_tma_descriptor_args + kernel_launch_code
# Wrap the kernel dispatch logic in an external C function
@@ -362,15 +386,17 @@ class TLCUDASourceWrapper(object):
return init_l2_persistent_map
-def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
+def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
+                                 desc_name_var_map: Dict[str, tvm.tir.Var]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
-for handle_name, name in desc_name_map.items():
-    desc_name = name + "_desc"
-    assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}"
-    args = self.tma_descriptor_args[desc_name]
+for handle_name, _ in desc_name_map.items():
+    assert handle_name in desc_name_var_map, f"Handle name {handle_name} not found in desc_name_var_map"
+    desc_var = desc_name_var_map[handle_name]
+    assert desc_var in self.tma_descriptor_args, f"TMA descriptor {desc_var} not found in {self.tma_descriptor_args}"
+    args = self.tma_descriptor_args[desc_var]
# Skip __tvm_tensormap_create_tiled
if len(args) < 3:
raise ValueError(
@@ -536,12 +562,35 @@ class TLCUDASourceWrapper(object):
# Do not update function with dispatch host function
if (function_name not in self.block_info) or (function_name not in self.grid_info):
continue
+assert function_name in self.device_mod, f"Function {function_name} not found in device module"
+device_func = self.device_mod[function_name]
+kernel_params_cnt = len(device_func.params)
+function_params: List[str] = None
+
+def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
+    nonlocal function_params
+    if isinstance(node, tvm.tir.Call):
+        if not (hasattr(node, "op") and
+                node.op == tvm.ir.Op.get("tir.tvm_call_packed")):
+            return
+        args = node.args
+        if not args or args[0] != fn:
+            return
+        if len(args) < 1 + param_cnt:
+            raise AssertionError(
+                "tvm_call_packed should have at least 1 argument and match device function parameters"
+            )
+        function_params = args[1:1 + param_cnt]
+
+post_order_visit(self.host_func.body, visitor)
+assert function_params is not None, "function_params should not be None"
function_informations[function_name] = {
"function_name": function_name,
"block_info": self.block_info[function_name],
"grid_info": self.grid_info[function_name],
"dynamic_smem_buf": self.dynamic_smem_buf[function_name],
"function_params": function_params,
}
# Create the host function wrapper for the CUDA kernel
@@ -579,6 +628,19 @@ class TLCUDASourceWrapper(object):
return function
raise ValueError("Cannot find primary function in the module.")
+
+@property
+def host_func(self):
+    if len(self.host_mod.get_global_vars()) == 1:
+        return self.host_mod[self.host_mod.get_global_vars()[0]]
+    elif "main" in self.host_mod:
+        return self.host_mod["main"]
+    else:
+        for _, function in self.host_mod.functions.items():
+            attr = function.attrs
+            if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
+                return function
+        raise ValueError("Cannot find primary function in the module.")
class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
"""
@@ -636,7 +698,6 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
function_args.append({"name": dyn_sym, "type": "ctypes.c_int"})
function_args.append(self.get_stream_type())
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['name']}" for arg in function_args])