Commit 916ee60e authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Introduce wrapper util `pythonic_expr` to transform a PrimExpr...

[Enhancement] Introduce wrapper util `pythonic_expr` to transform a PrimExpr into python string (#577)

* [Feature] Add Quarter Bank Swizzle Layout and Update GEMM Layout Logic

- Introduced a new `makeQuarterBankSwizzleLayout` function for layout swizzling of 32 bytes.
- Updated `makeGemmABLayout` to include an `enable_padding` parameter, allowing for conditional layout selection between padded and quarter bank swizzle layouts.
- Adjusted layout inference in GEMM operations to utilize the new quarter bank swizzle layout when appropriate.
- Enhanced bulk copy operations to recognize and handle the new layout type, improving memory access patterns.

* lint fix

* lint fix

* rebase

* rebase

* typo

* requirement fix

* revert flash attention requirements
parent 67d0b677
...@@ -361,6 +361,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T, ...@@ -361,6 +361,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
auto stride = as_const_int(shared_layout->InputShape()[0]); auto stride = as_const_int(shared_layout->InputShape()[0]);
auto continuous = as_const_int(shared_layout->InputShape()[1]); auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr); ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout, if (StructuralEqual()(shared_layout,
makeQuarterBankSwizzleLayout(*stride, *continuous, makeQuarterBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) { dst->dtype.bits()))) {
......
...@@ -101,3 +101,35 @@ def get_annotated_mod( ...@@ -101,3 +101,35 @@ def get_annotated_mod(
} }
return dispatch[model_type](mod) return dispatch[model_type](mod)
def pythonic_expr(expr: tvm.tir.Expr) -> str:
    """Render a TIR expression as a Python-style expression string.

    The expression is walked with ``tvm.tir.stmt_functor.post_order_visit``
    and a string is recorded for every visited node:

    * ``Var``    -> its ``name``
    * ``IntImm`` -> ``str(value)`` (the dtype is dropped)
    * ``Cast``   -> the string of the wrapped value (the cast is dropped)
    * ``Mul``    -> ``"<a> * <b>"``
    * any other node -> ``str(node)``

    Operands of a ``Mul`` that are not a plain ``Var``/``IntImm`` (looking
    through casts) are wrapped in parentheses, so that e.g. ``(a + b) * c``
    is not flattened into ``a + b * c``. A ``Mul`` on the left-hand side is
    left bare because ``*`` is left-associative.

    Args:
        expr: The expression to render.

    Returns:
        The string recorded for the last node visited, i.e. the root of
        ``expr`` in a post-order walk.
    """
    rendered = {}  # node -> string built for that node
    last = [""]  # string of the most recently visited node

    def _strip_casts(node):
        # Casts are dropped from the output, so look through them when
        # deciding whether an operand is atomic.
        while isinstance(node, tvm.tir.Cast):
            node = node.value
        return node

    def _operand(child, bare_types):
        # Children are visited before their parent, so the lookup normally
        # hits; fall back to the default representation otherwise.
        text = rendered.get(child, str(child))
        if isinstance(_strip_casts(child), bare_types):
            return text
        return f"({text})"

    def _pythonic_visitor(node):
        if isinstance(node, tvm.tir.Var):
            s = node.name
        elif isinstance(node, tvm.tir.IntImm):
            # Integer constant: use the value directly (ignore the type).
            s = str(node.value)
        elif isinstance(node, tvm.tir.Cast):
            # Type cast: skip the cast and reuse the inner value's string.
            s = rendered.get(node.value, str(node.value))
        elif isinstance(node, tvm.tir.Mul):
            # Multiplication: 'left * right', parenthesizing compound
            # operands so the original grouping survives in the string.
            a_str = _operand(node.a, (tvm.tir.Var, tvm.tir.IntImm, tvm.tir.Mul))
            b_str = _operand(node.b, (tvm.tir.Var, tvm.tir.IntImm))
            s = f"{a_str} * {b_str}"
        else:
            # Other nodes: use the default string representation.
            s = str(node)
        rendered[node] = s
        last[0] = s  # the root is visited last, so its string wins

    tvm.tir.stmt_functor.post_order_visit(expr, _pythonic_visitor)
    return last[0]
...@@ -3,7 +3,8 @@ from tilelang import tvm as tvm ...@@ -3,7 +3,8 @@ from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union, Any from typing import Optional, List, Dict, Union, Any
from tvm import IRModule from tvm import IRModule
from tvm.target import Target from tvm.target import Target
from .utils import match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target, is_cpu_target, get_annotated_mod from .utils import (match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target,
is_cpu_target, get_annotated_mod, pythonic_expr)
import re import re
import logging import logging
import textwrap import textwrap
...@@ -396,19 +397,10 @@ class TLCUDASourceWrapper(object): ...@@ -396,19 +397,10 @@ class TLCUDASourceWrapper(object):
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank] box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank] element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
def legalize_c2s(p): global_dim = [pythonic_expr(i) for i in global_dim]
# Convert TIR expressions to legal C expressions global_stride = [pythonic_expr(i) for i in global_stride]
# Directly convert to string since the special case handling box_dim = [pythonic_expr(i) for i in box_dim]
# does not alter the string representation for `tvm.tir.Var` and `IntImm`. element_strides = [pythonic_expr(i) for i in element_strides]
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p)
global_dim = [legalize_c2s(i) for i in global_dim]
global_stride = [legalize_c2s(i) for i in global_stride]
box_dim = [legalize_c2s(i) for i in box_dim]
element_strides = [legalize_c2s(i) for i in element_strides]
# Extract remaining parameters # Extract remaining parameters
try: try:
...@@ -647,14 +639,6 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -647,14 +639,6 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
call_args.append((match, "None")) call_args.append((match, "None"))
return call_args return call_args
def legalize(p):
# Convert TIR expressions to legal Python expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p)
desc_name_map: Dict[str, str] = {} desc_name_map: Dict[str, str] = {}
device_index = 0 device_index = 0
kernel_launch_code = """""" kernel_launch_code = """"""
...@@ -681,9 +665,10 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper): ...@@ -681,9 +665,10 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
kernel_launch_code += self.generate_tma_descriptor_args( kernel_launch_code += self.generate_tma_descriptor_args(
desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format( desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format(
function_name, legalize(grid_info[0]), legalize(grid_info[1]), function_name, pythonic_expr(grid_info[0]), pythonic_expr(grid_info[1]),
legalize(grid_info[2]), legalize(block_info[0]), legalize(block_info[1]), pythonic_expr(grid_info[2]), pythonic_expr(block_info[0]),
legalize(block_info[2]), smem_str, arg_names, arg_types, device_index) pythonic_expr(block_info[1]), pythonic_expr(
block_info[2]), smem_str, arg_names, arg_types, device_index)
# Wrap the kernel dispatch logic in an external C function # Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC_PY.format( host_func = PREDEF_HOST_FUNC_PY.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment