Commit 60974197 authored by Lei Wang, committed by LeiWang1999
Browse files

[Enhancement] Extend pythonic_expr to support dtype mapping in utils.py (#641)

- Updated the `pythonic_expr` function to accept an optional `dtype_map` parameter; when it is provided, the dtype name emitted for a C-style cast is looked up in this map instead of being written verbatim.
- Refactored calls to `pythonic_expr` in `TLCUDASourceWrapper` and `TLNVRTCSourceWrapper` to use the new mapping feature by passing the wrapper's `_TYPE_MAP`, so cast types in the generated kernel-launch code are translated through it.
- Enhanced code clarity by consolidating repeated calls to `pythonic_expr` into a private method within the wrapper class.
parent 156ff85e
from __future__ import annotations
import re
from typing import Union, Optional, Literal
from typing import Union, Optional, Literal, Dict
from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
......@@ -103,7 +103,7 @@ def get_annotated_mod(
return dispatch[model_type](mod)
def pythonic_expr(expr: tvm.tir.PrimExpr) -> str:
def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: Optional[Dict[str, str]] = None) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
......@@ -154,7 +154,10 @@ def pythonic_expr(expr: tvm.tir.PrimExpr) -> str:
elif isinstance(node, tvm.tir.Cast):
# C-style cast has high precedence
value_str, _ = node_to_result_map[node.value]
s = f"({node.dtype}){value_str}"
if dtype_map is None:
s = f"({node.dtype}){value_str}"
else:
s = f"({dtype_map[node.dtype]}){value_str}"
p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
elif isinstance(
node,
......
......@@ -223,6 +223,9 @@ class TLCUDASourceWrapper(object):
self.libpath: Optional[str] = None
self.lib_code: Optional[str] = self.update_lib_code(source)
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
    """Render ``expr`` with ``pythonic_expr`` using this wrapper's dtype map.

    Args:
        expr: The expression to convert to a string.

    Returns:
        The string produced by ``pythonic_expr`` when it is given ``expr``
        together with ``self._TYPE_MAP`` as its ``dtype_map`` argument.
    """
    # Forward the wrapper's own type table so every call site shares it.
    type_names = self._TYPE_MAP
    return pythonic_expr(expr, type_names)
def is_tma_descriptor_arg(self, arg_name: str) -> bool:
    """Report whether ``arg_name`` is a key of ``self.prim_func.buffer_map``.

    Args:
        arg_name: Name of the argument to look up.

    Returns:
        ``True`` if ``arg_name`` is contained in the prim func's buffer map,
        ``False`` otherwise.
    """
    # NOTE(review): judging by the method name, membership in the buffer map
    # is what marks an argument as a TMA descriptor here -- confirm with callers.
    known_buffers = self.prim_func.buffer_map
    return arg_name in known_buffers
......@@ -303,13 +306,13 @@ class TLCUDASourceWrapper(object):
index = code.index("{", index)
block_str = "dim3({}, {}, {})".format(
pythonic_expr(block_info[0]),
pythonic_expr(block_info[1]),
pythonic_expr(block_info[2]),
self._pythonic_expr(block_info[0]),
self._pythonic_expr(block_info[1]),
self._pythonic_expr(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
pythonic_expr(grid_info[0]), pythonic_expr(grid_info[1]),
pythonic_expr(grid_info[2]))
self._pythonic_expr(grid_info[0]), self._pythonic_expr(grid_info[1]),
self._pythonic_expr(grid_info[2]))
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map
......@@ -352,7 +355,7 @@ class TLCUDASourceWrapper(object):
# as size_in_bytes maybe a symbolic expression
num_bytes = persisting_l2_cache_max_size
init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
buffer_name, float(hit_ratio), pythonic_expr(num_bytes))
buffer_name, float(hit_ratio), self._pythonic_expr(num_bytes))
return init_l2_persistent_map
......@@ -388,10 +391,10 @@ class TLCUDASourceWrapper(object):
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
global_dim = [pythonic_expr(i) for i in global_dim]
global_stride = [pythonic_expr(i) for i in global_stride]
box_dim = [pythonic_expr(i) for i in box_dim]
element_strides = [pythonic_expr(i) for i in element_strides]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]
# Extract remaining parameters
try:
......@@ -656,9 +659,10 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
kernel_launch_code += self.generate_tma_descriptor_args(
desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format(
function_name, pythonic_expr(grid_info[0]), pythonic_expr(grid_info[1]),
pythonic_expr(grid_info[2]), pythonic_expr(block_info[0]),
pythonic_expr(block_info[1]), pythonic_expr(
function_name, self._pythonic_expr(grid_info[0]),
self._pythonic_expr(grid_info[1]), self._pythonic_expr(grid_info[2]),
self._pythonic_expr(block_info[0]), self._pythonic_expr(block_info[1]),
self._pythonic_expr(
block_info[2]), smem_str, arg_names, arg_types, device_index)
# Wrap the kernel dispatch logic in an external C function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment