Commit f47b43c5 authored by Lei Wang, committed by LeiWang1999

[Refactor] Refactor CUDA post-processing callback registration in TileLang (#259)

* Add GPU kernel for 2D continuous cumulative sum in TileLang example

- Introduced a new example script `example_tilelang_cumsum.py` that generates a GPU kernel for 2D continuous cumulative sum.
- Implemented functions to handle kernel configuration, memory allocation, and inclusive scan operations.
- Added a main execution block to demonstrate the kernel's functionality using PyTorch for tensor operations.
- Enhanced the example with error handling for non-power-of-two configurations and validation of results against PyTorch's built-in cumulative sum function (a minimal sketch of this validation pattern follows below).
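
A minimal sketch of that validation pattern, for context. The `kernel` argument and the shapes here are hypothetical stand-ins, not the example's actual API; only the power-of-two guard and the comparison against `torch.cumsum` mirror what the example script does:

```python
import torch

def validate_cumsum(kernel, M=64, N=1024):
    # The example only supports power-of-two scan lengths.
    if N & (N - 1) != 0:
        raise ValueError(f"scan length {N} must be a power of two")
    x = torch.randn(M, N, device="cuda", dtype=torch.float16)
    out = kernel(x)  # hypothetical compiled TileLang cumsum kernel
    ref = torch.cumsum(x, dim=1)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```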

* Refactor TileLang examples and enhance kernel compilation

- Updated `example_tilelang_cumsum.py` to improve GPU kernel generation for 2D continuous cumulative sum, including better parameter handling and error checking.
- Refactored `example_mha_bwd.py` to enhance kernel compilation readability and maintainability.
- Modified `kernel_cache.py` to prevent saving kernels to disk when using the DLPack backend, ensuring proper cache management.
- Added `get_block_bindings` function to `kernel.py` for improved access to block bindings in kernel launch frames.
- Cleaned up import statements in `__init__.py` for better organization and clarity.

* Enhance GPU kernel for 2D continuous cumulative sum in TileLang example

- Added additional spacing for improved readability in `example_tilelang_cumsum.py`.
- Refined kernel structure to enhance clarity and maintainability during GPU kernel generation for cumulative sum operations.

* Refactor CUDA post-processing callback registration in TileLang

- Introduced a new decorator `register_cuda_postproc_callback` for registering CUDA post-processing functions, improving usability and flexibility (see the usage sketch after this list).
- Updated existing callback implementations to utilize the new decorator, improving code clarity and maintainability.
- Added debug prints to the CUDA code generation process for better traceability during development.
- Refactored the `OptimizeForTarget` function to streamline conditional statement handling in the pipeline transformation.
- Cleaned up the `inject_pipeline.cc` file by removing redundant code related to statement grouping and condition handling.
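
For reference, a minimal usage sketch of the new decorator, mirroring the documentation change further down; `matmul` stands in for any TileLang program defined elsewhere:

```python
import tilelang
from tilelang.engine.callback import register_cuda_postproc_callback

@register_cuda_postproc_callback
def tilelang_callback_cuda_postproc(code, target):
    # Prepend a marker comment so the rewrite is visible in the final source.
    return "// patched by postproc callback\n" + code

# `matmul` is assumed to be a TileLang program built as in the docs example.
kernel = tilelang.compile(matmul, target="cuda")
print(kernel.get_kernel_source())  # first line is the injected marker
```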

* lint fix

* Enhance BlockSparse GEMM Example with Autotuning and Configurable Parameters

- Added argument parsing to allow dynamic configuration of matrix dimensions and sparsity ratio.
- Implemented a function to generate various kernel configurations for autotuning.
- Refactored the main execution block to support both autotuned and default configurations.
- Improved the block mask generation to accommodate specified sparsity levels (a standalone sketch of the mask construction follows this list).
- Updated the kernel compilation process to utilize the new configurations and ensure accurate results verification.
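
As noted in the mask bullet above, the sparsity knob reduces to a single comparison against a uniform random tensor. A standalone sketch with hypothetical tile counts; the expected mask density is roughly `1 - sparsity`:

```python
import torch

sparsity = 0.5            # fraction of K-blocks to skip
mask_shape = (8, 8, 32)   # (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape) > sparsity
# Every False entry skips one block_M x block_N x block_K GEMM contribution.
print(block_mask.float().mean().item())  # ~0.5 density
```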
parent cd9ec62e
@@ -108,10 +108,22 @@ Hence, by registering a Python function named `tilelang_callback_cuda_postproc`,
import tilelang
import tilelang.language as T
from tilelang import tvm
from tilelang.engine.callback import register_cuda_postproc_callback
@register_cuda_postproc_callback
def tilelang_callback_cuda_postproc(code, _):
# ...existing code...
print(code) # print the final CUDA code
code = "// modified by tilelang_callback_cuda_postproc\n" + code
return code
kernel = tilelang.compile(matmul, target="cuda")
kernel_source = kernel.get_kernel_source()
print(kernel_source)
'''
// modified by tilelang_callback_cuda_postproc
#include "cuda_runtime.h"
...
'''
```
### Runtime Debug Prints with `T.print`
......
import argparse
import itertools
import tilelang
import tilelang.language as T
import torch
from tilelang.autotuner import autotune, jit

torch.random.manual_seed(0)
def get_configs(M, N, K):
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(block_M, block_N, block_K, num_stages, thread_num, enable_rasterization))
return [{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5],
} for c in _configs]
def ref_program(A, B, BlockMask, C):
    # Infer tile sizes from the mask shape so the reference matches whichever
    # config the autotuner picks (block_M/N/K are not in scope at module level).
    block_M = A.shape[0] // BlockMask.shape[0]
    block_N = B.shape[1] // BlockMask.shape[1]
    block_K = A.shape[1] // BlockMask.shape[2]
    batch_M, batch_N, batch_K = BlockMask.shape
for i in range(batch_M):
for j in range(batch_N):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
for k in range(batch_K):
if BlockMask[i, j, k]:
accu += A[i*block_M:(i+1)*block_M, k*block_K:(k+1)*block_K].to(torch.float32) @ \
B[k*block_K:(k+1)*block_K, j*block_N:(j+1)*block_N].to(torch.float32)
C[i * block_M:(i + 1) * block_M, j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
def get_best_config(M, N, K):
@autotune(
configs=get_configs(M, N, K),
keys=["block_M", "block_N", "block_K", "num_stages", "thread_num", "enable_rasteration"],
warmup=3,
rep=20,
)
@jit(out_idx=[-1], ref_prog=ref_program)
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
               enable_rasterization=None):
return blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num,
                                  enable_rasterization)
return kernel()
def blocksparse_matmul(M,
N,
K,
block_M,
block_N,
block_K,
num_stages,
thread_num,
                       enable_rasterization,
dtype="float16",
accum_dtype="float"):
block_mask_shape = (M // block_M, N // block_N, K // block_K)
@@ -16,51 +84,69 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
BlockMask: T.Buffer(block_mask_shape, "bool"),
C: T.Buffer((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
            T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if BlockMask[by, bx, k]:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark")
parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)")
parser.add_argument(
"--use_autotune", action="store_true", default=False, help="Whether to use autotune")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
# Initialize input matrices
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
    if args.use_autotune:
        best_latency, best_config, ref_latency = get_best_config(M, N, K)
        func = blocksparse_matmul(M, N, K, *best_config)
        # Keep the mask shape in sync with the tuned tile sizes
        # (assumes best_config is ordered as in get_configs).
        block_M, block_N, block_K = best_config[:3]
    else:
        block_M, block_N, block_K = 128, 128, 32
        func = blocksparse_matmul(M, N, K, block_M, block_N, block_K, 2, 128, True)

    # Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > args.sparsity
kernel = tilelang.compile(func, out_idx=-1)
c = kernel(a, b, block_mask)
# Verify result
ref_c = torch.zeros_like(c)
for i in range(M // block_M):
for j in range(N // block_N):
accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=a.device)
for k in range(K // block_K):
if block_mask[i, j, k]:
accu += (
a[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
torch.float32) @ b[k * block_K:(k + 1) * block_K,
j * block_N:(j + 1) * block_N].to(torch.float32))
ref_c[i * block_M:(i + 1) * block_M,
j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
@@ -180,8 +180,6 @@ def matmul(M,
return main
return kernel()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
......
@@ -724,122 +724,16 @@ private:
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
Stmt new_loop{nullptr};
if (stmts.empty()) {
return make_nop();
}
if (stmts.size() == 1) {
new_loop = stmts[0];
} else {
new_loop = SeqStmt(stmts);
}
if (!is_unit_loop) {
......
from tilelang import tvm as tvm
import tilelang.testing
import tilelang
from tilelang.engine.callback import register_cuda_postproc_callback
import torch
@@ -85,7 +86,7 @@ def run_gemm(
stramp = "&*(XS)"
@tvm.register_func("tilelang_callback_cuda_postproc", override=True)
@register_cuda_postproc_callback
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code
@@ -233,5 +234,4 @@ def test_gemm_jit_kernel():
if __name__ == "__main__":
tilelang.testing.main()
from typing import Callable, Optional, Union
from tvm import register_func
from tvm.target import Target
def register_cuda_postproc(func: Callable[[str, Target], str], override: bool = True):
"""Register a post-processing function for CUDA code generation.
Args:
func: A callable that takes generated code (str) and target (Target) as input,
and returns the processed code (str).
override: Whether to override existing registered function. Defaults to True.
"""
register_func("tilelang_callback_cuda_postproc", f=func, override=override)
def register_hip_postproc(func: Callable[[str, Target], str], override: bool = True):
"""Register a post-processing function for HIP code generation.
Args:
func: A callable that takes generated code (str) and target (Target) as input,
and returns the processed code (str).
override: Whether to override existing registered function. Defaults to True.
"""
register_func("tilelang_callback_hip_postproc", f=func, override=override)
def register_cuda_postproc_callback(func: Optional[Union[Callable, bool]] = None, override: bool = True):
"""Decorator for registering CUDA post-processing callback function.
Can be used with or without parentheses:
@register_cuda_postproc_callback
def func(code, target): ...
@register_cuda_postproc_callback()
def func(code, target): ...
@register_cuda_postproc_callback(override=False)
def func(code, target): ...
Args:
func: The function to be decorated or a boolean override flag
override: Whether to override existing registered function. Defaults to True.
"""
if callable(func):
register_cuda_postproc(func, override)
return func
if func is None or isinstance(func, bool):
_override = func if isinstance(func, bool) else override
def _register(fn: Callable[[str, Target], str]):
register_cuda_postproc(fn, _override)
return fn
return _register
raise TypeError("Invalid decorator usage")
def register_hip_postproc_callback(func: Optional[Union[Callable, bool]] = None, override: bool = True):
"""Decorator for registering HIP post-processing callback function.
Can be used with or without parentheses:
@register_hip_postproc_callback
def func(code, target): ...
@register_hip_postproc_callback()
def func(code, target): ...
@register_hip_postproc_callback(override=False)
def func(code, target): ...
Args:
func: The function to be decorated or a boolean override flag
override: Whether to override existing registered function. Defaults to True.
"""
if callable(func):
register_hip_postproc(func, override)
return func
if func is None or isinstance(func, bool):
_override = func if isinstance(func, bool) else override
def _register(fn: Callable[[str, Target], str]):
register_hip_postproc(fn, _override)
return fn
return _register
raise TypeError("Invalid decorator usage")
@@ -38,9 +38,11 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.RewriteWgmmaSync()(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
else:
mod = tilelang.transform.IfStmtBinding()(mod)
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
        # TODO(lei): we may need a pass to fuse the if-then-else in the
        # pipeline loop when we encounter a dynamic branch.
......