Commit 22aed721 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Support more flexible layout host pythonic expr (#623)

* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr

- Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
- Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function

- Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
- Replaced individual checks for binary operations with a mapping approach, streamlining the code and enhancing performance in expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function

- Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
- Updated the visitor logic to intelligently add parentheses based on operator precedence, improving the accuracy of expression representation.
- Included a docstring for better clarity on the function's purpose and usage.

* test fix
parent 5101e6bc
......@@ -3,9 +3,6 @@ import tilelang.language as T
import tilelang.testing
from tilelang import tvm as tvm
tilelang.testing.set_random_seed(0)
tilelang.disable_cache()
@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8})
def matmul_dynamic_mnk(
......
......@@ -27,6 +27,6 @@ setuptools
einops
attrs
decorator
flash-attn
flash-attn<=2.8.0
scipy
tornado
\ No newline at end of file
......@@ -104,37 +104,108 @@ def get_annotated_mod(
def pythonic_expr(expr: tvm.tir.PrimExpr) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
Args:
expr: The TVM PrimExpr to convert.
Returns:
A string representation of the expression.
"""
if not isinstance(expr, tvm.tir.PrimExpr):
return str(expr)
python_str = ""
node_to_str_map = {} # Stores string representation for each node
def _pythonic_visitor(node):
# 1. Define operator precedence (higher value means higher precedence)
# Based on Python's operator precedence
PRECEDENCE = {
tvm.tir.Call: 20, # Includes min, max
tvm.tir.Cast: 20, # Treated like a function call
tvm.tir.Mul: 13,
tvm.tir.FloorDiv: 13,
tvm.tir.Div: 13, # For tvm.tir.Div if it appears
tvm.tir.FloorMod: 13,
tvm.tir.Add: 12,
tvm.tir.Sub: 12,
tvm.tir.LT: 10,
tvm.tir.LE: 10,
tvm.tir.GT: 10,
tvm.tir.GE: 10,
tvm.tir.EQ: 10,
tvm.tir.NE: 10,
tvm.tir.And: 5,
tvm.tir.Or: 4,
# Atoms (Var, IntImm) have the highest precedence implicitly
}
# By default, atomic expressions (variables, constants) have the highest precedence
ATOMIC_PRECEDENCE = 100
node_to_result_map = {} # Stores (string, precedence) for each node
def _visitor(node):
# 2. Visitor returns (str, precedence) tuple
if node in node_to_result_map:
return
if isinstance(node, tvm.tir.Var):
s = node.name
s, p = node.name, ATOMIC_PRECEDENCE
elif isinstance(node, (tvm.tir.IntImm, tvm.tir.FloatImm)):
# Integer constant: use value directly (ignore type)
s = str(node.value)
s, p = str(node.value), ATOMIC_PRECEDENCE
elif isinstance(node, tvm.tir.Cast):
# Type cast: represent as (type)value
dtype_map = {"int64": "int64_t", "int32": "int32_t", "int8": "int8_t"}
dtype = dtype_map.get(str(node.dtype), str(node.dtype))
value_str = node_to_str_map.get(node.value, str(node.value))
s = f"({dtype}){value_str}"
elif isinstance(node, tvm.tir.Mul):
# Multiplication: format as 'left * right'
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} * {b_str}"
# C-style cast has high precedence
value_str, _ = node_to_result_map[node.value]
s = f"({node.dtype}){value_str}"
p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
elif isinstance(
node,
(tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod, tvm.tir.LT,
tvm.tir.LE, tvm.tir.GT, tvm.tir.GE, tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)):
op_map = {
tvm.tir.Mul: "*",
tvm.tir.FloorDiv: "/",
tvm.tir.Add: "+",
tvm.tir.Sub: "-",
tvm.tir.FloorMod: "%",
tvm.tir.LT: "<",
tvm.tir.LE: "<=",
tvm.tir.GT: ">",
tvm.tir.GE: ">=",
tvm.tir.EQ: "==",
tvm.tir.NE: "!=",
tvm.tir.And: "and",
tvm.tir.Or: "or",
}
op_str = f" {op_map[type(node)]} "
my_precedence = PRECEDENCE[type(node)]
a_str, a_precedence = node_to_result_map[node.a]
b_str, b_precedence = node_to_result_map[node.b]
# 3. Add parentheses intelligently
# Add parentheses if the left operand's precedence is lower than the current operator
if a_precedence < my_precedence:
a_str = f"({a_str})"
# Add parentheses if the right operand's precedence is lower than or equal to the current operator
# 'Equal' is to handle non-associative operations, e.g., a - (b - c)
if b_precedence <= my_precedence:
b_str = f"({b_str})"
s = f"{a_str}{op_str}{b_str}"
p = my_precedence
elif isinstance(node, (tvm.tir.Min, tvm.tir.Max)):
op_name = "min" if isinstance(node, tvm.tir.Min) else "max"
a_str, _ = node_to_result_map[node.a]
b_str, _ = node_to_result_map[node.b]
s = f"{op_name}({a_str}, {b_str})"
# Function calls have high precedence
p = PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE)
else:
# Other nodes: use default string representation
s = str(node)
# Fallback for unhandled expression types
s, p = str(node), 0
# Store current node's string representation
node_to_str_map[node] = s
nonlocal python_str
python_str = s # Update global string (retain root node in the end)
node_to_result_map[node] = (s, p)
# Perform post-order traversal
tvm.tir.stmt_functor.post_order_visit(expr, _pythonic_visitor)
return python_str
tvm.tir.stmt_functor.post_order_visit(expr, _visitor)
return next(iter(node_to_result_map[expr]), "")
......@@ -278,15 +278,6 @@ class TLCUDASourceWrapper(object):
call_args.append(match)
return call_args
def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")
has_l2_persistent_map = False
for function_name, _ in function_informations.items():
if function_name in self.l2_persistent_map:
......@@ -312,12 +303,13 @@ class TLCUDASourceWrapper(object):
index = code.index("{", index)
block_str = "dim3({}, {}, {})".format(
legalize_c(block_info[0]),
legalize_c(block_info[1]),
legalize_c(block_info[2]),
pythonic_expr(block_info[0]),
pythonic_expr(block_info[1]),
pythonic_expr(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
pythonic_expr(grid_info[0]), pythonic_expr(grid_info[1]),
pythonic_expr(grid_info[2]))
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map
......@@ -891,15 +883,6 @@ class TLCPUSourceWrapper(object):
call_args.append(match)
return call_args
def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")
_call_str = """"""
for function_name, _ in function_informations.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment