Commit fe5f7b3b authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Bugfix] Support duplicate tma desc declaration (#228)

* [Refactor] Enhance argument handling in TLCUDASourceWrapper and TLCPUSourceWrapper

- Updated `func_call_args` to accept an optional `desc_name_map` parameter for improved descriptor handling.
- Modified `generate_tma_descriptor_args` to utilize the `desc_name_map`, ensuring correct mapping of descriptor names.
- Cleaned up the logic for generating function call arguments and TMA descriptor initialization, enhancing code clarity and maintainability.

* lint fix
parent a1da26f2
......@@ -127,13 +127,15 @@ class TLCUDASourceWrapper(object):
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
def func_call_args(s, function_args):
def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int):
match = matches[i]
if match != name + "_desc":
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
desc_decls = []
if desc_name_map is not None:
desc_name_map[match] = name
if i > 0:
desc_decls.append(matches[i - 1])
if i < len(matches) - 1:
......@@ -159,7 +161,7 @@ class TLCUDASourceWrapper(object):
return str(p).replace("//", "/")
_call_str = """"""
_call_str += self.generate_tma_descriptor_args()
desc_name_map: Dict[str, str] = {}
for function_name, function_info in function_informations.items():
block_info = function_info["block_info"]
grid_info = function_info["grid_info"]
......@@ -173,8 +175,7 @@ class TLCUDASourceWrapper(object):
# Identify the start of the function body to insert arguments
index = code.index("{", index)
call_args = ", ".join(func_call_args(declaration, function_args))
call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
block_str = "dim3({}, {}, {})".format(
legalize_c(block_info[0]),
......@@ -188,21 +189,27 @@ class TLCUDASourceWrapper(object):
block_str, smem_str,
call_args)
_call_str = self.generate_tma_descriptor_args(desc_name_map) + _call_str
# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
return host_func
def generate_tma_descriptor_args(self) -> str:
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
for _, args in self.tma_descriptor_args.items():
for handle_name, name in desc_name_map.items():
desc_name = name + "_desc"
assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}"
args = self.tma_descriptor_args[desc_name]
# Skip __tvm_tensormap_create_tiled
if len(args) < 3:
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
desc_name, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
tensor_rank = int(tensor_rank)
# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
......@@ -234,7 +241,7 @@ class TLCUDASourceWrapper(object):
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
tma_descripter_init += TMA_DESC_INIT_FUNC.format(desc_name, dtype, tensor_rank,
tma_descripter_init += TMA_DESC_INIT_FUNC.format(handle_name, dtype, tensor_rank,
globalAddress, ",".join(global_dim),
",".join(global_stride),
",".join(box_dim),
......@@ -341,8 +348,6 @@ class TLCUDASourceWrapper(object):
"dynamic_smem_buf": self.dynamic_smem_buf[function_name],
}
# TODO(Lei): Sort function_informations by invoke order
# Create the host function wrapper for the CUDA kernel
host_func = self.create_dispatch_func(code, function_informations)
# Combine the source, initialization function, and host function to form the complete library code
......@@ -456,24 +461,12 @@ class TLCPUSourceWrapper(object):
def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
def func_call_args(s, function_args):
# Extract the function call arguments matching the function definition
def maybe_desc(name: str, matches: List[str], i: int):
match = matches[i]
if match != name + "_desc":
return False
desc_decls = []
if i > 0:
desc_decls.append(matches[i - 1])
if i < len(matches) - 1:
desc_decls.append(matches[i + 1])
return any([decl == "CUtensorMap" for decl in desc_decls])
pattern = r"[,\s]*(?:\w+\s*\*+\s*\s+)?(\w+)"
matches = re.findall(pattern, s)
call_args = []
for i, match in enumerate(matches):
for match in matches:
for arg in function_args:
if arg["name"] == match or maybe_desc(arg["name"], matches, i):
if arg["name"] == match:
call_args.append(match)
return call_args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment