Commit 916ee60e authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Introduce wrapper util `pythonic_expr` to transform a PrimExpr...

[Enhancement] Introduce wrapper util `pythonic_expr` to transform a PrimExpr into python string (#577)

* [Feature] Add Quarter Bank Swizzle Layout and Update GEMM Layout Logic

- Introduced a new `makeQuarterBankSwizzleLayout` function for layout swizzling of 32 bytes.
- Updated `makeGemmABLayout` to include an `enable_padding` parameter, allowing for conditional layout selection between padded and quarter bank swizzle layouts.
- Adjusted layout inference in GEMM operations to utilize the new quarter bank swizzle layout when appropriate.
- Enhanced bulk copy operations to recognize and handle the new layout type, improving memory access patterns.

* lint fix

* lint fix

* rebase

* rebase

* typo

* requirement fix

* revert flash atten requirenemts
parent 67d0b677
......@@ -361,6 +361,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
auto stride = as_const_int(shared_layout->InputShape()[0]);
auto continuous = as_const_int(shared_layout->InputShape()[1]);
ICHECK(stride != nullptr && continuous != nullptr);
if (StructuralEqual()(shared_layout,
makeQuarterBankSwizzleLayout(*stride, *continuous,
dst->dtype.bits()))) {
......
......@@ -101,3 +101,35 @@ def get_annotated_mod(
}
return dispatch[model_type](mod)
def pythonic_expr(expr: tvm.tir.Expr) -> str:
python_str = ""
node_to_str_map = {} # Stores string representation for each node
def _pythonic_visitor(node):
if isinstance(node, tvm.tir.Var):
s = node.name
elif isinstance(node, tvm.tir.IntImm):
# Integer constant: use value directly (ignore type)
s = str(node.value)
elif isinstance(node, tvm.tir.Cast):
# Type cast: skip Cast and use inner value directly
s = node_to_str_map.get(node.value, str(node.value))
elif isinstance(node, tvm.tir.Mul):
# Multiplication: format as 'left * right'
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} * {b_str}"
else:
# Other nodes: use default string representation
s = str(node)
# Store current node's string representation
node_to_str_map[node] = s
nonlocal python_str
python_str = s # Update global string (retain root node in the end)
# Perform post-order traversal
tvm.tir.stmt_functor.post_order_visit(expr, _pythonic_visitor)
return python_str
......@@ -3,7 +3,8 @@ from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union, Any
from tvm import IRModule
from tvm.target import Target
from .utils import match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target, is_cpu_target, get_annotated_mod
from .utils import (match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target,
is_cpu_target, get_annotated_mod, pythonic_expr)
import re
import logging
import textwrap
......@@ -396,19 +397,10 @@ class TLCUDASourceWrapper(object):
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
def legalize_c2s(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p)
global_dim = [legalize_c2s(i) for i in global_dim]
global_stride = [legalize_c2s(i) for i in global_stride]
box_dim = [legalize_c2s(i) for i in box_dim]
element_strides = [legalize_c2s(i) for i in element_strides]
global_dim = [pythonic_expr(i) for i in global_dim]
global_stride = [pythonic_expr(i) for i in global_stride]
box_dim = [pythonic_expr(i) for i in box_dim]
element_strides = [pythonic_expr(i) for i in element_strides]
# Extract remaining parameters
try:
......@@ -647,14 +639,6 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
call_args.append((match, "None"))
return call_args
def legalize(p):
# Convert TIR expressions to legal Python expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p)
desc_name_map: Dict[str, str] = {}
device_index = 0
kernel_launch_code = """"""
......@@ -681,9 +665,10 @@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
kernel_launch_code += self.generate_tma_descriptor_args(
desc_name_map) + KERNEL_LAUNCH_FUNC_PY.format(
function_name, legalize(grid_info[0]), legalize(grid_info[1]),
legalize(grid_info[2]), legalize(block_info[0]), legalize(block_info[1]),
legalize(block_info[2]), smem_str, arg_names, arg_types, device_index)
function_name, pythonic_expr(grid_info[0]), pythonic_expr(grid_info[1]),
pythonic_expr(grid_info[2]), pythonic_expr(block_info[0]),
pythonic_expr(block_info[1]), pythonic_expr(
block_info[2]), smem_str, arg_names, arg_types, device_index)
# Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC_PY.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment