"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "f91579aab6e224c23aceaeaa0a29d9dde83f09ed"
Commit 60974197 authored by Lei Wang, committed by LeiWang1999
Browse files

[Enhancement] Extend pythonic_expr to support dtype mapping in utils.py (#641)

- Updated the `pythonic_expr` function to accept an optional `dtype_map` parameter, allowing for more flexible type conversions.
- Refactored calls to `pythonic_expr` in `TLCUDASourceWrapper` to utilize the new mapping feature, improving type handling in kernel generation.
- Enhanced code clarity by consolidating repeated calls to `pythonic_expr` into a private method within the wrapper class.
parent 156ff85e
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Union, Optional, Literal from typing import Union, Optional, Literal, Dict
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import IRModule, tir from tvm import IRModule, tir
from tvm.target import Target from tvm.target import Target
...@@ -103,7 +103,7 @@ def get_annotated_mod( ...@@ -103,7 +103,7 @@ def get_annotated_mod(
return dispatch[model_type](mod) return dispatch[model_type](mod)
def pythonic_expr(expr: tvm.tir.PrimExpr) -> str: def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: Optional[Dict[str, str]] = None) -> str:
""" """
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
...@@ -154,7 +154,10 @@ def pythonic_expr(expr: tvm.tir.PrimExpr) -> str: ...@@ -154,7 +154,10 @@ def pythonic_expr(expr: tvm.tir.PrimExpr) -> str:
elif isinstance(node, tvm.tir.Cast): elif isinstance(node, tvm.tir.Cast):
# C-style cast has high precedence # C-style cast has high precedence
value_str, _ = node_to_result_map[node.value] value_str, _ = node_to_result_map[node.value]
s = f"({node.dtype}){value_str}" if dtype_map is None:
s = f"({node.dtype}){value_str}"
else:
s = f"({dtype_map[node.dtype]}){value_str}"
p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE) p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
elif isinstance( elif isinstance(
node, node,
......
...@@ -223,6 +223,9 @@ class TLCUDASourceWrapper(object): ...@@ -223,6 +223,9 @@ class TLCUDASourceWrapper(object):
self.libpath: Optional[str] = None self.libpath: Optional[str] = None
self.lib_code: Optional[str] = self.update_lib_code(source) self.lib_code: Optional[str] = self.update_lib_code(source)
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
return pythonic_expr(expr, self._TYPE_MAP)
def is_tma_descriptor_arg(self, arg_name: str) -> bool: def is_tma_descriptor_arg(self, arg_name: str) -> bool:
return arg_name in self.prim_func.buffer_map return arg_name in self.prim_func.buffer_map
...@@ -303,13 +306,13 @@ class TLCUDASourceWrapper(object): ...@@ -303,13 +306,13 @@ class TLCUDASourceWrapper(object):
index = code.index("{", index) index = code.index("{", index)
block_str = "dim3({}, {}, {})".format( block_str = "dim3({}, {}, {})".format(
pythonic_expr(block_info[0]), self._pythonic_expr(block_info[0]),
pythonic_expr(block_info[1]), self._pythonic_expr(block_info[1]),
pythonic_expr(block_info[2]), self._pythonic_expr(block_info[2]),
) )
grid_str = "dim3({}, {}, {})".format( grid_str = "dim3({}, {}, {})".format(
pythonic_expr(grid_info[0]), pythonic_expr(grid_info[1]), self._pythonic_expr(grid_info[0]), self._pythonic_expr(grid_info[1]),
pythonic_expr(grid_info[2])) self._pythonic_expr(grid_info[2]))
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name) init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map kernel_launch_code += init_l2_persistent_map
...@@ -352,7 +355,7 @@ class TLCUDASourceWrapper(object): ...@@ -352,7 +355,7 @@ class TLCUDASourceWrapper(object):
# as size_in_bytes maybe a symbolic expression # as size_in_bytes maybe a symbolic expression
num_bytes = persisting_l2_cache_max_size num_bytes = persisting_l2_cache_max_size
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format( init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
buffer_name, float(hit_ratio), pythonic_expr(num_bytes)) buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
return init_l2_persistent_map return init_l2_persistent_map
...@@ -388,10 +391,10 @@ class TLCUDASourceWrapper(object): ...@@ -388,10 +391,10 @@ class TLCUDASourceWrapper(object):
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank] box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank] element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
global_dim = [pythonic_expr(i) for i in global_dim] global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [pythonic_expr(i) for i in global_stride] global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [pythonic_expr(i) for i in box_dim] box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [pythonic_expr(i) for i in element_strides] element_strides = [self._pythonic_expr(i) for i in element_strides]
# Extract remaining parameters # Extract remaining parameters
try: try:
...@@ -656,9 +659,10 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -656,9 +659,10 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
kernel_launch_code += self.generate_tma_descriptor_args( kernel_launch_code += self.generate_tma_descriptor_args(
desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format( desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format(
function_name, pythonic_expr(grid_info[0]), pythonic_expr(grid_info[1]), function_name, self._pythonic_expr(grid_info[0]),
pythonic_expr(grid_info[2]), pythonic_expr(block_info[0]), self._pythonic_expr(grid_info[1]), self._pythonic_expr(grid_info[2]),
pythonic_expr(block_info[1]), pythonic_expr( self._pythonic_expr(block_info[0]), self._pythonic_expr(block_info[1]),
self._pythonic_expr(
block_info[2]), smem_str, arg_names, arg_types, device_index) block_info[2]), smem_str, arg_names, arg_types, device_index)
# Wrap the kernel dispatch logic in an external C function # Wrap the kernel dispatch logic in an external C function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment