Commit fe5f7b3b authored by Lei Wang, committed by LeiWang1999

[Bugfix] Support duplicate tma desc declaration (#228)

* [Refactor] Enhance argument handling in TLCUDASourceWrapper and TLCPUSourceWrapper

- Updated `func_call_args` to accept an optional `desc_name_map` parameter for improved descriptor handling.
- Modified `generate_tma_descriptor_args` to use the `desc_name_map`, so each kernel-side descriptor handle is mapped back to its source tensor.
- Cleaned up the logic for generating function call arguments and TMA descriptor initialization, improving code clarity and maintainability.

* lint fix
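
Context for the fix: when more than one kernel in a module takes a TMA descriptor for the same buffer, the generated parameter names are evidently de-duplicated with a numeric suffix (e.g. `A_desc_1`), which the old exact-match test rejected. A minimal sketch of the old vs. new matching rule (tensor name `A` is hypothetical):

```python
name = "A"

def old_rule(match: str) -> bool:
    return match == name + "_desc"

def new_rule(match: str) -> bool:
    return match == name + "_desc" or match.startswith(name + "_desc_")

# A de-duplicated handle from a second kernel sharing the same tensor:
print(old_rule("A_desc_1"), new_rule("A_desc_1"))  # False True
print(old_rule("A_desc"), new_rule("A_desc"))      # True True
```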
parent a1da26f2
```diff
@@ -127,13 +127,15 @@ class TLCUDASourceWrapper(object):
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
 
-        def func_call_args(s, function_args):
+        def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
             # Extract the function call arguments matching the function definition
             def maybe_desc(name: str, matches: List[str], i: int):
                 match = matches[i]
-                if match != name + "_desc":
+                if not (match == name + "_desc" or match.startswith(name + "_desc_")):
                     return False
                 desc_decls = []
+                if desc_name_map is not None:
+                    desc_name_map[match] = name
                 if i > 0:
                     desc_decls.append(matches[i - 1])
                 if i < len(matches) - 1:
```
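
As a standalone illustration (hypothetical declaration string, using the same token regex that appears later in this diff), `maybe_desc` only accepts a candidate whose neighbouring token is `CUtensorMap`, and now also records which tensor each accepted handle belongs to. Here `desc_name_map` is passed explicitly; in the wrapper the nested function closes over `func_call_args`' parameter:

```python
import re
from typing import Dict, List, Optional

def maybe_desc(name: str, matches: List[str], i: int,
               desc_name_map: Optional[Dict[str, str]] = None) -> bool:
    match = matches[i]
    if not (match == name + "_desc" or match.startswith(name + "_desc_")):
        return False
    desc_decls = []
    if desc_name_map is not None:
        desc_name_map[match] = name  # remember handle -> tensor name
    if i > 0:
        desc_decls.append(matches[i - 1])
    if i < len(matches) - 1:
        desc_decls.append(matches[i + 1])
    # Only a real descriptor parameter is declared as CUtensorMap.
    return any(decl == "CUtensorMap" for decl in desc_decls)

declaration = "const CUtensorMap A_desc_1, float* A"
matches = re.findall(r"[,\s]*(?:\w+\s*\*+\s*\s+)?(\w+)", declaration)
# matches == ['const', 'CUtensorMap', 'A_desc_1', 'A']
desc_name_map: Dict[str, str] = {}
hits = [m for i, m in enumerate(matches) if maybe_desc("A", matches, i, desc_name_map)]
print(hits, desc_name_map)  # ['A_desc_1'] {'A_desc_1': 'A'}
```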
```diff
@@ -159,7 +161,7 @@ class TLCUDASourceWrapper(object):
             return str(p).replace("//", "/")
 
         _call_str = """"""
-        _call_str += self.generate_tma_descriptor_args()
+        desc_name_map: Dict[str, str] = {}
        for function_name, function_info in function_informations.items():
             block_info = function_info["block_info"]
             grid_info = function_info["grid_info"]
@@ -173,8 +175,7 @@ class TLCUDASourceWrapper(object):
             # Identify the start of the function body to insert arguments
             index = code.index("{", index)
 
-            call_args = ", ".join(func_call_args(declaration, function_args))
+            call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
             block_str = "dim3({}, {}, {})".format(
                 legalize_c(block_info[0]),
```
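
Note on ordering: the map starts out empty, is filled as a side effect while the per-kernel call strings are assembled, and only afterwards (next hunk) is the descriptor initialization generated from it and prepended. A toy version of that flow, with the real helper stubbed out and hypothetical declarations:

```python
from typing import Dict, List

def collect_call_args(decl: str, desc_name_map: Dict[str, str]) -> List[str]:
    # Stub standing in for func_call_args: record every "A_desc*" token it sees.
    out = []
    for tok in decl.replace(",", " ").split():
        if tok.startswith("A_desc"):
            desc_name_map[tok] = "A"
        if tok == "A" or tok.startswith("A_desc"):
            out.append(tok)
    return out

desc_name_map: Dict[str, str] = {}
call_str = ""
for decl in ["CUtensorMap A_desc, A", "CUtensorMap A_desc_1, A"]:
    call_str += "kernel({});\n".format(", ".join(collect_call_args(decl, desc_name_map)))
# Descriptor init is emitted last, once per handle seen across *all* kernels,
# then prepended so every handle exists before any kernel launch.
init_str = "".join(f"CUtensorMap {h};\n" for h in desc_name_map)
print(init_str + call_str)
```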
```diff
@@ -188,21 +189,27 @@ class TLCUDASourceWrapper(object):
                 block_str, smem_str,
                 call_args)
 
+        _call_str = self.generate_tma_descriptor_args(desc_name_map) + _call_str
+
         # Wrap the kernel dispatch logic in an external C function
         host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
         return host_func
 
-    def generate_tma_descriptor_args(self) -> str:
+    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
         tma_descripter_init = ""
         if self.tma_descriptor_args is None:
             return tma_descripter_init
 
-        for _, args in self.tma_descriptor_args.items():
+        for handle_name, name in desc_name_map.items():
+            desc_name = name + "_desc"
+            assert desc_name in self.tma_descriptor_args, f"TMA descriptor {desc_name} not found in {self.tma_descriptor_args}"
+            args = self.tma_descriptor_args[desc_name]
             # Skip __tvm_tensormap_create_tiled
             if len(args) < 3:
                 raise ValueError(
                     f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
 
-            desc_name, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
+            _, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
             tensor_rank = int(tensor_rank)
 
             # Validate tensor_rank
             if not isinstance(tensor_rank, int) or tensor_rank <= 0:
```
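
With the map in hand, `generate_tma_descriptor_args` now iterates over the handles that actually appear in kernel signatures and looks each one up under the tensor's canonical `<name>_desc` key, so two handles backed by the same tensor reuse one argument list. A reduced sketch with made-up descriptor args:

```python
from typing import Dict

# Hypothetical stand-in for self.tma_descriptor_args (values invented).
tma_descriptor_args = {
    "A_desc": ["__tvm_tensormap_create_tiled", "A_desc", "float32", 2, "A_handle"],
}
desc_name_map: Dict[str, str] = {"A_desc": "A", "A_desc_1": "A"}

for handle_name, name in desc_name_map.items():
    desc_name = name + "_desc"
    assert desc_name in tma_descriptor_args, f"TMA descriptor {desc_name} not found"
    args = tma_descriptor_args[desc_name]
    # The stored descriptor name is ignored; the handle from the kernel
    # signature is what the generated init code must declare.
    _, dtype, tensor_rank, global_address, *rest = args[1:]
    print(f"init {handle_name}: dtype={dtype}, rank={int(tensor_rank)}, addr={global_address}")
```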
```diff
@@ -234,7 +241,7 @@ class TLCUDASourceWrapper(object):
                     "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
                 ) from e
 
-            tma_descripter_init += TMA_DESC_INIT_FUNC.format(desc_name, dtype, tensor_rank,
+            tma_descripter_init += TMA_DESC_INIT_FUNC.format(handle_name, dtype, tensor_rank,
                                                              globalAddress, ",".join(global_dim),
                                                              ",".join(global_stride),
                                                              ",".join(box_dim),
@@ -341,8 +348,6 @@ class TLCUDASourceWrapper(object):
                 "dynamic_smem_buf": self.dynamic_smem_buf[function_name],
             }
 
-        # TODO(Lei): Sort function_informations by invoke order
-
         # Create the host function wrapper for the CUDA kernel
         host_func = self.create_dispatch_func(code, function_informations)
         # Combine the source, initialization function, and host function to form the complete library code
```
```diff
@@ -456,24 +461,12 @@ class TLCPUSourceWrapper(object):
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
 
         def func_call_args(s, function_args):
-            # Extract the function call arguments matching the function definition
-            def maybe_desc(name: str, matches: List[str], i: int):
-                match = matches[i]
-                if match != name + "_desc":
-                    return False
-                desc_decls = []
-                if i > 0:
-                    desc_decls.append(matches[i - 1])
-                if i < len(matches) - 1:
-                    desc_decls.append(matches[i + 1])
-                return any([decl == "CUtensorMap" for decl in desc_decls])
-
             pattern = r"[,\s]*(?:\w+\s*\*+\s*\s+)?(\w+)"
             matches = re.findall(pattern, s)
             call_args = []
-            for i, match in enumerate(matches):
+            for match in matches:
                 for arg in function_args:
-                    if arg["name"] == match or maybe_desc(arg["name"], matches, i):
+                    if arg["name"] == match:
                         call_args.append(match)
             return call_args
```
...
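
On the CPU path the `maybe_desc` machinery was dead code (there is no `CUtensorMap` parameter to match in C kernels), so the helper reduces to a plain name match; roughly:

```python
import re

def func_call_args(s, function_args):
    # CPU kernels carry no TMA descriptors, so a direct name match is enough.
    matches = re.findall(r"[,\s]*(?:\w+\s*\*+\s*\s+)?(\w+)", s)
    names = {arg["name"] for arg in function_args}
    return [m for m in matches if m in names]

print(func_call_args("float* A, float* B", [{"name": "A"}, {"name": "B"}]))  # ['A', 'B']
```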