Commit 7b74bb01 authored by Lei Wang, committed by GitHub

[JIT] Enhance cython/ctypes wrapper for tma descriptor (#126)



* refactor code

* enhance tutorial

* Enhance error handling and code generation in CUDA and TileLang components

This commit introduces several improvements across multiple files:
- Added more informative error messages in GEMM layout checks
- Updated CUDA codegen to support more flexible function signature generation
- Improved TMA descriptor initialization and kernel dispatch logic
- Refined library generation and source code parsing utilities
- Enhanced error handling in various adapter and wrapper classes

* Add thread tag validation for warp specialization

Introduce a ThreadTagChecker to validate that a PrimFunc only uses threadIdx.x before applying warp specialization. This prevents unintended transformations on kernels with complex thread binding and provides a clear warning to users about potential issues with warp specialization.
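For illustration, the check amounts to rejecting any thread binding other than threadIdx.x before the pass runs. A minimal Python sketch of that idea follows; the committed ThreadTagChecker lives inside the warp-specialization pass itself, so the standalone helper name and the visitor-based approach here are assumptions, not the actual implementation:

```python
# Sketch only: approximates the ThreadTagChecker idea described above.
# The real check is part of the warp-specialization pass; this helper and
# its name are illustrative assumptions.
from tvm import tir


def uses_only_threadidx_x(prim_func: tir.PrimFunc) -> bool:
    """Return True if the kernel binds no threadIdx.y / threadIdx.z,
    i.e. warp specialization can be applied without surprises."""
    ok = True

    def visit(stmt):
        nonlocal ok
        if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "thread_extent":
            tag = stmt.node.thread_tag  # e.g. "threadIdx.x", "blockIdx.y"
            if tag.startswith("threadIdx") and tag != "threadIdx.x":
                ok = False

    tir.stmt_functor.post_order_visit(prim_func.body, visit)
    return ok
```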

* Update TileLang Profiling and Compilation in Flash Decoding Examples

Refactor the profiling and compilation workflow in two flash decoding example scripts:
- Replace `tilelang.lower()` and `tilelang.Profiler()` with `tilelang.compile()`
- Simplify profiler initialization using `get_profiler()`
- Update method calls to use the new profiler and compiled kernel objects
- Maintain existing performance benchmarking and validation logic
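The updated examples roughly follow the pattern sketched below. This is not copied from the scripts: the program name and compile arguments are placeholders, and the benchmarking call name is the one used by TileLang profilers elsewhere (assumed here):

```python
# Sketch of the new workflow, assuming `flash_decode_program` is the T.prim_func
# defined in the example script and that default compile options suffice.
import tilelang

# Previously: mod = tilelang.lower(program); profiler = tilelang.Profiler(mod, ...)
kernel = tilelang.compile(flash_decode_program)   # compiled kernel object
profiler = kernel.get_profiler()                  # replaces manual Profiler setup

latency = profiler.do_bench()                     # benchmarking kept as before
print(f"latency: {latency} ms")
```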

* Refactor and clean up code formatting in TileLang testing and adapter modules

This commit includes several code style and formatting improvements:
- Adjust whitespace and line breaks in test files
- Improve code formatting in CUDA source wrapper and adapter utilities
- Enhance readability of function calls and argument handling
- Remove unnecessary whitespace and standardize indentation
- Simplify function signatures and argument parsing

* Refactor CUDA codegen and improve code formatting

This commit includes several improvements to CUDA code generation and formatting:
- Enhance function signature generation in CodeGenTileLangCUDA
- Improve code formatting and readability in CUDA-related files
- Simplify parameter handling and type annotations
- Clean up whitespace and line breaks in codegen and layout files

---------
Co-authored-by: Ubuntu <dlisuser@h100testl730RPS.xu5snccwrbtejcqqalluoku5hb.xx.internal.cloudapp.net>
parent ba311311
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 import re
-from typing import Union, Optional
+from typing import Union, Optional, Literal
 from tilelang import tvm as tvm
 from tvm import IRModule, tir
 from tvm.target import Target
 from tilelang.engine.lower import (
-    is_device_call,
+    get_device_call,
+    get_host_call,
     determine_target,
     canon_target_host,
+    is_cpu_device_backend,
 )
 from tilelang.engine.phase import (
     LowerAndLegalize,
@@ -16,11 +18,24 @@ from tilelang.engine.phase import (
 )


-def match_global_kernel(source: str) -> int:
+def match_global_kernel(source: str, annotation: str = "__global__") -> int:
     pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+"
-    matched = re.findall(pattern, source)
-    assert len(matched) >= 1  # may have statement before kernel
-    return source.index(matched[0])
+    for line in source.split("\n"):
+        if annotation in line:
+            matched = re.findall(pattern, line)
+            if len(matched) >= 1:
+                return source.index(matched[0])
+    raise ValueError("No global kernel found in the source code")
+
+
+def match_declare_kernel(source: str, annotation: str = "__global__") -> int:
+    pattern = r"__global__\s+void\s+\w+"
+    for line in source.split("\n"):
+        if annotation in line:
+            matched = re.findall(pattern, line)
+            if len(matched) >= 1:
+                return source.index(matched[0] + "(")
+    raise ValueError("No global kernel found in the source code")


 def is_cuda_target(target: Target) -> bool:
@@ -31,28 +46,44 @@ def is_hip_target(target: Target) -> bool:
     return target.kind.name == "hip"


-def get_annotated_device_mod(
+def get_annotated_mod(
     func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
     target: Union[str, Target] = "auto",
     target_host: Optional[Union[str, Target]] = None,
-) -> "IRModule":
+    model_type: Literal["device", "host", "all"] = "all",
+) -> Union[IRModule, tuple[IRModule, IRModule]]:
+    # Validate model_type early
+    if model_type not in {"device", "host", "all"}:
+        raise ValueError(f"Invalid model type: {model_type}")
+
+    # Convert PrimFunc to IRModule if needed
     mod = func_or_mod
     if isinstance(func_or_mod, tir.PrimFunc):
-        func = func_or_mod
-        mod = tvm.IRModule({func.attrs["global_symbol"]: func})
+        mod = tvm.IRModule({func_or_mod.attrs["global_symbol"]: func_or_mod})

+    # Handle target and target_host
     if isinstance(target, str):
         target = determine_target(target)
-    target_host = canon_target_host(target, target_host)
-    target_host = tvm.target.Target.canon_target(target_host)
+    target_host = tvm.target.Target.canon_target(canon_target_host(target, target_host))
     target = tvm.target.Target(target, target_host)

+    _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
+    _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))
+
+    # Apply transformations
     mod = LowerAndLegalize(mod, target)
     mod = OptimizeForTarget(mod, target)

-    device_mod = tir.transform.Filter(is_device_call)(mod)
-
-    return device_mod
+    # Define dispatch dictionary for different model types
+    dispatch = {
+        "device": lambda m: tir.transform.Filter(_is_device_call)(m),
+        "host": lambda m: tir.transform.Filter(_is_host_call)(m),
+        "all": lambda m: (tir.transform.Filter(_is_device_call)(m),
+                          tir.transform.Filter(_is_host_call)(m)),
+    }
+    return dispatch[model_type](mod)
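A hedged usage sketch of the reworked helper above; the import path is an assumption (file paths are not shown in this capture), and `scheduled_func` stands for an already-scheduled TIR PrimFunc:

```python
# Illustrative only: exercises get_annotated_mod as defined above.
# The module path below is assumed, not taken from the diff.
from tilelang.jit.adapter.utils import get_annotated_mod

# Default model_type="all" returns the pair (device_mod, host_mod)
device_mod, host_mod = get_annotated_mod(scheduled_func, target="cuda")

# A single filtered IRModule can be requested instead
device_only = get_annotated_mod(scheduled_func, target="cuda", model_type="device")
host_only = get_annotated_mod(scheduled_func, target="cuda", model_type="host")
```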
@@ -6,7 +6,7 @@ from tilelang import tvm as tvm
 from typing import Optional, List, Dict, Union
 from tvm import IRModule
 from tvm.target import Target
-from .utils import match_global_kernel, is_cuda_target, is_hip_target, get_annotated_device_mod
+from .utils import match_declare_kernel, is_cuda_target, is_hip_target, get_annotated_mod
 import re
 import logging
@@ -26,6 +26,27 @@ extern "C" void call({}) {{
 }}
 """

+TMA_DESC_INIT_FUNC = """
+\tCUtensorMap {0};
+\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
+\tcuuint32_t {0}_tensorRank= {2};
+\tvoid *{0}_globalAddress= {3};
+\tcuuint64_t {0}_globalDim[{2}]= {{{4}}};
+\tcuuint64_t {0}_globalStride[{2}]= {{{5}}};
+\tcuuint32_t {0}_boxDim[{2}]= {{{6}}};
+\tcuuint32_t {0}_elementStrides[{2}]= {{{7}}};
+\tCUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){8};
+\tCUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){9};
+\tCUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){10};
+\tCUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){11};
+\tCUresult {0}_result = cuTensorMapEncodeTiled(
+    &{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1, {0}_boxDim, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
+\tif ({0}_result != CUDA_SUCCESS) {{
+\t\tprintf("Failed to initialize the TMA descriptor {0} with error code %d\\n", {0}_result);
+\t\texit(-1);
+\t}}
+"""


 class BaseWrapper(ABC):
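To show how the template above gets instantiated (see `generate_tma_descriptor_args` further down), here is a hypothetical formatting for a rank-2 descriptor; all enum integers and sizes are made-up placeholder values:

```python
# Hypothetical example: instantiate TMA_DESC_INIT_FUNC (defined above) for a
# rank-2 tensor. Every numeric value here is a placeholder, not from a real kernel.
snippet = TMA_DESC_INIT_FUNC.format(
    "A_desc",       # {0}: descriptor variable name
    6,              # {1}: CUtensorMapDataType as an int (placeholder)
    2,              # {2}: tensorRank
    "(void*)A",     # {3}: global address expression
    "1024,1024",    # {4}: globalDim, tensorRank entries
    "1,1024",       # {5}: globalStride; the encode call skips entry 0 via `+ 1`
    "64,64",        # {6}: boxDim
    "1,1",          # {7}: elementStrides
    0, 0, 0, 0,     # {8}-{11}: interleave, swizzle, l2Promotion, oobFill
)
# `snippet` is the C fragment that declares and encodes the CUtensorMap A_desc.
```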
@@ -61,31 +82,203 @@ class TLCUDASourceWrapper(object):
         self.mod = scheduled_ir_module
         self.target = target
         self.source = source
-        self.function_name: Optional[str] = None
+        self.function_names: Optional[str] = None
         self.dynamic_smem_buf: Optional[int] = None
         self.block_info: Union[List[int], Dict] = [1, 1, 1]
         self.grid_info: Union[List[int], Dict] = [1, 1, 1]
+        self.tma_descriptor_args: Optional[Dict] = None
         self.parse_source_information()
         self.srcpath: Optional[str] = None
         self.libpath: Optional[str] = None
         self.lib_code: Optional[str] = self.update_lib_code(source)

+    def is_tma_descriptor_arg(self, arg_name: str) -> bool:
+        return arg_name in self.prim_func.buffer_map
+
+    def create_dispatch_func(self, code, function_informations):
+        # Extract the set of dynamic symbolic names used in the primary function
+        dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
+
+        function_args = []
+        # Collect function arguments based on primary function's parameters and buffer mappings
+        for param in self.prim_func.params:
+            buffer = self.prim_func.buffer_map[param]
+            function_args.append({
+                "name": buffer.name,
+                "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
+            })
+        # Add dynamic symbols as integer arguments
+        for dyn_sym in dynamic_symbolic_set:
+            function_args.append({"name": dyn_sym, "type": "int"})
+
+        function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
+        # Format the function arguments for declaration
+        def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
+
+        def func_call_args(s, function_args):
+            # Extract the function call arguments matching the function definition
+            def maybe_desc(name: str, matches: List[str], i: int):
+                match = matches[i]
+                if match != name + "_desc":
+                    return False
+                desc_decls = []
+                if i > 0:
+                    desc_decls.append(matches[i - 1])
+                if i < len(matches) - 1:
+                    desc_decls.append(matches[i + 1])
+                return any([decl == "CUtensorMap" for decl in desc_decls])
+
+            pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
+            matches = re.findall(pattern, s)
+            call_args = []
+            for i, match in enumerate(matches):
+                for arg in function_args:
+                    if arg["name"] == match or maybe_desc(arg["name"], matches, i):
+                        call_args.append(match)
+            return call_args
+
+        def legalize_c(p):
+            # Convert TIR expressions to legal C expressions
+            # Directly convert to string since the special case handling
+            # does not alter the string representation for `tvm.tir.Var` and `IntImm`.
+            # Replace Python's floor division operator with C's division operator
+            if isinstance(p, tvm.tir.IntImm):
+                p = int(p)
+            return str(p).replace("//", "/")
+
+        _call_str = """"""
+        _call_str += self.generate_tma_descriptor_args()
+        for function_name, function_info in function_informations.items():
+            block_info = function_info["block_info"]
+            grid_info = function_info["grid_info"]
+            dynamic_smem_buf = function_info["dynamic_smem_buf"]
+            # Find the location of the global kernel function in the code
+            index = match_declare_kernel(code, function_name + "(")
+            # Analyze the function declaration to prepare for argument extraction
+            declaration = code[index:].split(";")[0]
+            # Identify the start of the function body to insert arguments
+            index = code.index("{", index)
+            call_args = ", ".join(func_call_args(declaration, function_args))
+            block_str = "dim3({}, {}, {})".format(
+                legalize_c(block_info[0]),
+                legalize_c(block_info[1]),
+                legalize_c(block_info[2]),
+            )
+            grid_str = "dim3({}, {}, {})".format(
+                legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
+            smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
+            _call_str += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
+                function_name, grid_str, block_str, smem_str, call_args)
+
+        # Wrap the kernel dispatch logic in an external C function
+        host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
+        return host_func
+
+    def generate_tma_descriptor_args(self) -> str:
+        tma_descripter_init = ""
+        if self.tma_descriptor_args is None:
+            return tma_descripter_init
+
+        for _, args in self.tma_descriptor_args.items():
+            # Skip __tvm_tensormap_create_tiled
+            if len(args) < 3:
+                raise ValueError(
+                    f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
+            desc_name, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
+            tensor_rank = int(tensor_rank)
+            # Validate tensor_rank
+            if not isinstance(tensor_rank, int) or tensor_rank <= 0:
+                raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")
+
+            # Calculate required length for remaining_args
+            expected_args_len = 4 * tensor_rank + 4  # 4 groups of tensor_rank size + 4 parameters
+            if len(remaining_args) < expected_args_len:
+                raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
+                                 f"expected {expected_args_len} for tensor_rank {tensor_rank}")
+
+            # Extract dimensions and strides using list slicing
+            global_dim = remaining_args[:tensor_rank]
+            global_stride = remaining_args[tensor_rank:2 * tensor_rank]
+            box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
+            element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
+
+            global_dim = [str(i) for i in global_dim]
+            global_stride = [str(i) for i in global_stride]
+            box_dim = [str(i) for i in box_dim]
+            element_strides = [str(i) for i in element_strides]
+
+            # Extract remaining parameters
+            try:
+                interleave, swizzle, l2Promotion, oobFill = remaining_args[
+                    4 * tensor_rank:4 * tensor_rank + 4]
+            except ValueError as e:
+                raise ValueError(
+                    "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
+                ) from e
+
+            tma_descripter_init += TMA_DESC_INIT_FUNC.format(
+                desc_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
+                ",".join(global_stride), ",".join(box_dim), ",".join(element_strides),
+                interleave, swizzle, l2Promotion, oobFill)
+        return tma_descripter_init
+
     def parse_source_information(self):
-        device_mod = get_annotated_device_mod(self.mod, self.target)
-        assert (len(device_mod.functions) == 1
-                ), "Only support one function in the module for static shape kernel."
+        device_mod, host_mod = get_annotated_mod(self.mod, self.target)
+        assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
+        assert (len(host_mod.functions) == 1), "Only support one function in host module."
+
+        block_info_map = {}
+        grid_info_map = {}
+        dynamic_smem_buf_map = {}
+        function_names = []
         for g_var, func in device_mod.functions.items():
-            self.function_name = g_var.name_hint
+            # Default block and grid configurations
+            block_info = [1, 1, 1]
+            grid_info = [1, 1, 1]
+            function_name = g_var.name_hint
             attrs = func.attrs
+            dynamic_smem_buf = None
             if "dyn_shared_memory_buf" in attrs:
-                self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
+                dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
             if "thread_extent" in attrs:
+                # Extract block and grid sizes from thread extents
                 thread_extent = attrs["thread_extent"]
                 for tag, extent in thread_extent.items():
                     if "threadIdx" in tag:
-                        self.block_info["xyz".index(tag[-1])] = extent
+                        block_info["xyz".index(tag[-1])] = extent
                     elif "blockIdx" in tag:
-                        self.grid_info["xyz".index(tag[-1])] = extent
+                        grid_info["xyz".index(tag[-1])] = extent
+            # Map the extracted configurations to each function
+            block_info_map[function_name] = block_info
+            grid_info_map[function_name] = grid_info
+            dynamic_smem_buf_map[function_name] = dynamic_smem_buf
+            function_names.append(function_name)
+
+        # Store the mappings for use in code generation
+        self.block_info = block_info_map
+        self.grid_info = grid_info_map
+        self.dynamic_smem_buf = dynamic_smem_buf_map
+
+        function_names_index = {}
+        for _, func in host_mod.functions.items():
+            if "tma_descriptor_args" in func.attrs:
+                self.tma_descriptor_args = func.attrs["tma_descriptor_args"]
+            host_code = str(func)
+            for function_name in function_names:
+                index = host_code.index(f'T.call_packed("{function_name}"')
+                function_names_index[function_name] = index
+
+        # sort function_names
+        function_names = sorted(function_names, key=lambda x: function_names_index[x])
+        self.function_names = function_names

     def get_dynamic_symbolic_set(self, prim_func):
         # Determine the set of dynamic symbols used in the function
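For orientation, one entry of `tma_descriptor_args` is expected to look roughly like the list below, matching the unpacking and slicing in `generate_tma_descriptor_args` above; the concrete values are invented for illustration:

```python
# Hypothetical rank-2 entry consistent with the parsing above (values invented).
tma_args = [
    "__tvm_tensormap_create_tiled",  # args[0], dropped by args[1:]
    "A_desc",                        # desc_name
    6,                               # dtype (CUtensorMapDataType as int)
    2,                               # tensor_rank
    "A_handle",                      # globalAddress
    1024, 1024,                      # global_dim      (tensor_rank values)
    1, 1024,                         # global_stride   (tensor_rank values)
    64, 64,                          # box_dim         (tensor_rank values)
    1, 1,                            # element_strides (tensor_rank values)
    0, 0, 0, 0,                      # interleave, swizzle, l2Promotion, oobFill
]
```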
@@ -101,10 +294,11 @@ class TLCUDASourceWrapper(object):
         # Initialize an empty string for the CUDA function call
         call_str = """"""
         # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
-        if self.dynamic_smem_buf is not None:
-            call_str = (
-                PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
-                                                           self.dynamic_smem_buf))
+        for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
+            if dynamic_smem_buf is not None:
+                # Format the cudaFuncSetAttribute call for dynamic shared memory
+                call_str += PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(
+                    function_name, dynamic_smem_buf)
         # Format the initialization function using the call_str
         init_funcs = PREDEF_INIT_FUNC.format(call_str)
         return init_funcs
@@ -112,78 +306,29 @@ class TLCUDASourceWrapper(object):
     def update_lib_code(self, code: str):
         # Update the library code with the given code string
        self.lib_code = code
-        # Find the index of the global kernel function in the code
-        index = match_global_kernel(code)
-        # Extract the declaration of the function starting from the found index
-        declaration = code[index:].split(";")[0]
-        function_name = self.function_name
+        # Get the function names
+        function_names = self.function_names
         # Get the CUDA initialization function
         init_func = self.get_cuda_init_func()
-        # Locate the opening brace of the function to insert arguments
-        index = code.index("{", index)
-        function_args = []
-        # Populate the function arguments from the primary function's parameters and buffers
-        for param in self.prim_func.params:
-            buffer = self.prim_func.buffer_map[param]
-            function_args.append({
-                "name": buffer.name,
-                "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
-            })
-
-        dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
-        # Add dynamic symbolic parameters as integers to the function arguments
-        for dyn_sym in dynamic_symbolic_set:
-            function_args.append({"name": dyn_sym, "type": "int"})
-
-        function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
-        # Format the function arguments for declaration
-        def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
-
-        def func_call_args(s, function_args):
-            # Extract the function call arguments matching the function definition
-            pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
-            matches = re.findall(pattern, s)
-            call_args = []
-            for match in matches:
-                for arg in function_args:
-                    if arg["name"] == match:
-                        call_args.append(match)
-            return call_args
-
-        call_args = ", ".join(func_call_args(declaration, function_args))
-
-        block_info, grid_info = self.block_info, self.grid_info
-
-        def legalize_c(p):
-            # Convert TIR expressions to legal C expressions
-            # Directly convert to string since the special case handling
-            # does not alter the string representation for `tvm.tir.Var` and `IntImm`.
-            # Replace Python's floor division operator with C's division operator
-            if isinstance(p, tvm.tir.IntImm):
-                p = int(p)
-            return str(p).replace("//", "/")
-
-        # Prepare the block and grid dimensions for the CUDA kernel launch
-        block_str = "dim3({}, {}, {})".format(
-            legalize_c(block_info[0]),
-            legalize_c(block_info[1]),
-            legalize_c(block_info[2]),
-        )
-        grid_str = "dim3({}, {}, {})".format(
-            legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
-        # Determine the shared memory size, defaulting to 0 if not specified
-        smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
-        # Format the CUDA kernel launch string
-        call_str = ""
-        if len(dynamic_symbolic_set) != 0:
-            call_str += "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0])
-        else:
-            call_str += ""
-        call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
-                                                             smem_str, call_args)
+        # Organize function information for code generation
+        function_informations = {}
+        for function_name in function_names:
+            # Do not update function with dispatch host function
+            if (function_name not in self.block_info) or (function_name not in self.grid_info):
+                continue
+            function_informations[function_name] = {
+                "function_name": function_name,
+                "block_info": self.block_info[function_name],
+                "grid_info": self.grid_info[function_name],
+                "dynamic_smem_buf": self.dynamic_smem_buf[function_name],
+            }
+
+        # TODO(Lei): Sort function_informations by invoke order
+
         # Create the host function wrapper for the CUDA kernel
-        host_func = PREDEF_HOST_FUNC.format(def_args, call_str)
+        host_func = self.create_dispatch_func(code, function_informations)
         # Combine the source, initialization function, and host function to form the complete library code
         lib_code = self.source + init_func + host_func
         return lib_code
@@ -127,7 +127,8 @@ class Profiler(TorchDLPackKernelAdapter):
         elif profiler == "tvm":
             if func is None:
                 func = self.mod
-            assert isinstance(func, tvm.runtime.Module), "func should be a TVM module"
+            assert isinstance(
+                func, tvm.runtime.Module), f"func should be a TVM module, but got {type(func)}"
             ins = (self._get_inputs(with_output=True) if input_tensors is None else input_tensors)
             target = "cuda"