Commit cd9ec62e authored by Lei Wang, committed by LeiWang1999

[Example] Implement Kernel Example cumsum (#258)

* Add GPU kernel for 2D continuous cumulative sum in TileLang example

- Introduced a new example script `example_tilelang_cumsum.py` that generates a GPU kernel for 2D continuous cumulative sum.
- Implemented functions to handle kernel configuration, memory allocation, and inclusive scan operations.
- Added a main execution block to demonstrate the kernel's functionality using PyTorch for tensor operations.
- Enhanced the example with error handling for power-of-two configurations and validation of results against PyTorch's built-in cumulative sum function.

* Refactor TileLang examples and enhance kernel compilation

- Updated `example_tilelang_cumsum.py` to improve GPU kernel generation for 2D continuous cumulative sum, including better parameter handling and error checking.
- Refactored `example_mha_bwd.py` to enhance kernel compilation readability and maintainability.
- Modified `kernel_cache.py` to prevent saving kernels to disk when using the DLPack backend, ensuring proper cache management.
- Added `get_block_bindings` function to `kernel.py` for improved access to block bindings in kernel launch frames.
- Cleaned up import statements in `__init__.py` for better organization and clarity.

* Enhance GPU kernel for 2D continuous cumulative sum in TileLang example

- Added additional spacing for improved readability in `example_tilelang_cumsum.py`.
- Refined kernel structure to enhance clarity and maintainability during GPU kernel generation for cumulative sum operations.
parent c770a58f
import math
from typing import Optional
import torch
import tilelang
import tilelang.language as T
from tilelang.cache import clear_cache
clear_cache()
def _is_power_of_two(n: int):
"""Check if n is a power of 2."""
return n > 0 and (n & (n - 1)) == 0
def gpu_2d_continuous_cumsum(
M: int,
N: int,
ty_len: int = 4,
tx_len: int = 32,
in_dtype: str = "int32",
out_dtype: Optional[str] = None,
):
"""Generate GPU kernel for 2D continuous cumsum, i.e. The cumsum axis is -1
Parameters
----------
M : int
The number of rows of the input tensor
N : int
The number of columns of the input tensor
ty_len : int
The length of thread.y
tx_len : int
The length of thread.x
in_dtype : str
The input data type
out_dtype : Optional[str]
The output data type, if None, it will be the same as in_dtype
Returns
-------
cumsum : PrimFunc
The generated cumsum kernel
"""
out_dtype = out_dtype or in_dtype
# Configuration for GPU kernel
TX = T.int32(tx_len) # thread.x
TY = T.int32(ty_len) # thread.y
thread_elem = N # number of elements in single thread
if not _is_power_of_two(TX) or not _is_power_of_two(TY) or not _is_power_of_two(N):
raise ValueError("Configuration of TX, TY, N must be power of 2")
# number of elements to be processed by single warp
warp_elem = T.int32(tx_len * thread_elem)
# number of elements to be processed by single block(SM)
block_elem = T.int32(tx_len * ty_len * thread_elem)
LOG_TX = T.int32(int(math.log2(tx_len)))
LOG_BLOCK_N = T.int32(int(math.log2(tx_len * ty_len * thread_elem)))
@T.macro
def block_inclusive_inside_block(
batch: T.int32,
cur_len: T.int32,
source: T.Buffer,
output: T.Buffer,
tmp_buf: T.Buffer,
src_offset: T.int32,
tmp_offset: T.int32,
):
local_buf = T.alloc_buffer((thread_elem,), out_dtype, scope="local")
shared_buf = T.alloc_buffer((block_elem,), out_dtype, scope="shared")
bx = T.get_block_binding(0)
by = T.get_block_binding(1)
tx = T.get_thread_binding(0)
ty = T.get_thread_binding(1)
tx_idx = bx * block_elem + ty * warp_elem + tx * thread_elem
# Load data from global memory
for i in T.vectorized(N):
local_buf[i] = T.if_then_else(
tx_idx + i < cur_len,
T.Cast(out_dtype, source[by, src_offset + tx_idx + i]),
T.Cast(out_dtype, 0),
)
# Inclusive scan inside thread
for i in T.serial(1, N):
local_buf[i] += local_buf[i - 1]
# Store data to shared memory
for i in T.vectorized(N):
shared_buf[ty * warp_elem + tx * thread_elem + i] = local_buf[i]
# Inclusive scan inside warp
for i in T.serial(LOG_TX):
for j in T.vectorized(N):
idx: T.int32 = ty * warp_elem + tx * thread_elem
if tx >= (1 << i):
shared_buf[idx + j] += shared_buf[idx - (1 << i) * thread_elem + N - 1]
# Inclusive scan inside block
for i in T.serial(1, TY):
for j in T.vectorized(N):
if ty == 0:
idx: T.int32 = i * warp_elem + tx * thread_elem
shared_buf[idx + j] += shared_buf[i * warp_elem - 1]
# Write sum of block to global memory
for i in T.vectorized(N):
idx: T.int32 = ty * warp_elem + tx * thread_elem + i
if bx * block_elem + idx < cur_len:
output[by, src_offset + bx * block_elem + idx] = shared_buf[idx]
if tx == 0 and ty == 0:
for i in T.vectorized(N): # noqa: B007
tmp_buf[by, tmp_offset + bx] = shared_buf[block_elem - 1]
@T.macro
def update_cross_block(
batch: T.int32,
cur_len: T.int32,
source: T.Buffer,
output: T.Buffer,
src_offset: T.int32,
out_offset: T.int32,
):
bx = T.get_block_binding(0)
by = T.get_block_binding(1)
tx = T.get_thread_binding(0)
ty = T.get_thread_binding(1)
for i in T.serial(N):
idx: T.int32 = bx * block_elem + ty * warp_elem + i * TX + tx
if idx < cur_len:
output[by, out_offset + idx] += T.if_then_else(bx > 0,
source[by, src_offset + bx - 1], 0)
@T.prim_func
def cumsum(A: T.Buffer((M, N), dtype="int32"), Out: T.Buffer((M, N), dtype="int32"),
Tmp: T.Buffer((M, N), dtype="int32")):
ceil_log2 = T.Cast("int32", T.ceil(T.log2(T.Cast("float32", N))))
total_rounds = ceil_log2 // LOG_BLOCK_N
with T.Kernel(T.ceildiv(N, block_elem), M, threads=[tx_len, ty_len]) as (bx, by):
block_inclusive_inside_block(
M, N, A, Out, Tmp, src_offset=T.int32(0), tmp_offset=T.int32(0))
for i in range(total_rounds):
cur_len = T.ceildiv(N, 1 << (LOG_BLOCK_N * (i + 1)))
with T.Kernel(T.ceildiv(cur_len, block_elem), M) as (bx, by):
block_inclusive_inside_block(
M,
cur_len,
Tmp,
Tmp,
Tmp,
src_offset=i * T.ceildiv(N, block_elem),
tmp_offset=(i + 1) * T.ceildiv(N, block_elem),
)
for i in range(total_rounds - 1):
real_idx = total_rounds - 1 - i - 1
cur_len = T.ceildiv(N, 1 << (LOG_BLOCK_N * (real_idx + 1)))
with T.Kernel(T.ceildiv(cur_len, block_elem), M) as (bx, by):
update_cross_block(
M,
cur_len,
Tmp,
Tmp,
src_offset=(real_idx + 1) * T.ceildiv(N, block_elem),
out_offset=real_idx * T.ceildiv(N, block_elem),
)
with T.Kernel(T.ceildiv(N, block_elem), M) as (bx, by):
update_cross_block(M, N, Tmp, Out, src_offset=0, out_offset=0)
return cumsum
def torch_cumsum(A: torch.Tensor, dim: int = -1):
return torch.cumsum(A, dim=dim)
if __name__ == "__main__":
M = 128
N = 32
program = gpu_2d_continuous_cumsum(M, N)
kernel = tilelang.compile(program, execution_backend="dlpack", out_idx=[1])
code = kernel.get_kernel_source()
A = torch.randint(0, 10, (M, N)).cuda().to(torch.int32)
tmp = torch.zeros_like(A).cuda().to(torch.int32)
tilelang_output = kernel(A, tmp)
torch_output = torch_cumsum(A).cuda().to(torch.int32)
torch.testing.assert_close(tilelang_output, torch_output, atol=1e-2, rtol=1e-2)
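For intuition, the kernel above is a hierarchical inclusive scan: each thread scans its own chunk of a row, the per-chunk totals are themselves scanned to produce offsets, and every chunk then adds the total of everything before it; when there are more chunk totals than one block can scan, the same procedure is applied recursively through `Tmp` before `update_cross_block` propagates the offsets back. A rough host-side sketch of that decomposition (illustrative only; `hierarchical_cumsum` and `chunk` are made-up names standing in for the kernel's `block_elem` partitioning, not part of the example):

import numpy as np

def hierarchical_cumsum(row: np.ndarray, chunk: int = 8) -> np.ndarray:
    n = len(row)
    out = np.empty_like(row)
    block_totals = []
    # Phase 1: inclusive scan inside each chunk
    # (what block_inclusive_inside_block does per thread/warp/block).
    for start in range(0, n, chunk):
        seg = np.cumsum(row[start:start + chunk])
        out[start:start + chunk] = seg
        block_totals.append(seg[-1])
    # Phase 2: scan the chunk totals, then add the sum of all preceding
    # chunks to every element of each chunk (what update_cross_block does).
    offsets = np.cumsum(block_totals)
    for i, start in enumerate(range(0, n, chunk)):
        if i > 0:
            out[start:start + chunk] += offsets[i - 1]
    return out

row = np.random.randint(0, 10, size=32)
assert np.array_equal(hierarchical_cumsum(row), np.cumsum(row))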
import torch
import torch.nn.functional as F
import tilelang
from tilelang import cached
from tilelang.autotuner import *
import tilelang.language as T
import argparse
@@ -233,8 +232,12 @@ class _attention(torch.autograd.Function):
         BATCH, N_CTX, H, D_HEAD = q.shape
         block_M = 64
         block_N = 64 if D_HEAD <= 128 else 32
-        mod = cached(flashattn_fwd, [3, 4], BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
-        o, lse = mod(q, k, v)
+        kernel = tilelang.compile(
+            flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N),
+            out_idx=[3, 4],
+            target="cuda",
+            execution_backend="cython")
+        o, lse = kernel(q, k, v)
         ctx.save_for_backward(q, k, v, o, lse)
         ctx.causal = causal
         return o
@@ -251,13 +254,24 @@ class _attention(torch.autograd.Function):
         do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
         block_M = 128
         block_N = 128 if D_HEAD <= 64 else 32
-        mod_prep = cached(flashattn_bwd_preprocess, [2], BATCH, H, N_CTX, D_HEAD)
-        mod_post = cached(flashattn_bwd_postprocess, [1], BATCH, H, N_CTX, D_HEAD)
-        delta = mod_prep(o, do)
-        mod = cached(flashattn_bwd, [6, 7, 8], BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M,
-                     block_N)
-        dq, dk, dv = mod(q, k, v, do, lse, delta)
-        dq = mod_post(dq)
+        kernel_prep = tilelang.compile(
+            flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD),
+            out_idx=[2],
+            target="cuda",
+            execution_backend="cython")
+        kernel_post = tilelang.compile(
+            flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD),
+            out_idx=[1],
+            target="cuda",
+            execution_backend="cython")
+        delta = kernel_prep(o, do)
+        kernel = tilelang.compile(
+            flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N),
+            out_idx=[6, 7, 8],
+            target="cuda",
+            execution_backend="cython")
+        dq, dk, dv = kernel(q, k, v, do, lse, delta)
+        dq = kernel_post(dq)
         return dq, dk, dv, None
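Both the forward and backward paths now follow the same compile-then-call pattern: build the TIR program by calling the kernel factory with concrete shapes, compile it once with `tilelang.compile`, and invoke the resulting kernel like a function; `out_idx` names the buffer parameters the runtime allocates and returns instead of expecting them as arguments. A minimal sketch of the pattern, with a hypothetical `make_kernel` factory standing in for `flashattn_fwd`/`flashattn_bwd` (the shape and tensor names below are placeholders, not runnable as-is):

# Hypothetical sketch of the new call pattern; make_kernel stands in for a
# factory such as flashattn_fwd, and q/k/v are assumed input tensors.
import tilelang

program = make_kernel(BATCH, HEADS, SEQ_LEN, DIM)   # TIR program with shapes baked in
kernel = tilelang.compile(
    program,
    out_idx=[3],                  # buffer index 3 is allocated by the runtime and returned
    target="cuda",
    execution_backend="cython",
)
out = kernel(q, k, v)             # pass inputs only; outputs come back as return values

Compared with the previous `cached(...)` helper, this keeps the target, execution backend, and output indices visible at the call site, and leaves caching to the `KernelCache` layer touched in a later hunk.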
......
@@ -181,22 +181,23 @@ private:
   // T.address_of(C_shared))
   Stmt VisitStmt_(const EvaluateNode *op) final {
     auto evaluate = Downcast<Evaluate>(StmtExprMutator::VisitStmt_(op));
-    auto call = Downcast<Call>(evaluate->value);
-    if (call.defined() && call->op == builtin::call_extern()) {
-      GlobalMemChecker checker(analyzer_);
-      checker(call);
-      Array<PrimExpr> conditions = checker.GetConditions();
-      if (conditions.size() == 0) {
-        return evaluate;
-      }
-      Stmt evaluate_with_conditions = evaluate;
-      for (auto cond : conditions) {
-        evaluate_with_conditions = IfThenElse(cond, evaluate_with_conditions);
-      }
-      return evaluate_with_conditions;
+    if (const CallNode *call_op = op->value.as<CallNode>()) {
+      auto call = Downcast<Call>(evaluate->value);
+      if (call->op == builtin::call_extern()) {
+        GlobalMemChecker checker(analyzer_);
+        checker(call);
+        Array<PrimExpr> conditions = checker.GetConditions();
+        if (conditions.size() == 0) {
+          return evaluate;
+        }
+        Stmt evaluate_with_conditions = evaluate;
+        for (auto cond : conditions) {
+          evaluate_with_conditions = IfThenElse(cond, evaluate_with_conditions);
+        }
+        return evaluate_with_conditions;
+      }
     }
     return evaluate;
......
@@ -113,7 +113,10 @@ class KernelCache:
                 pass_configs=pass_configs,
             )
             self._cache[key] = kernel  # Store in in-memory cache
-            self._save_kernel_to_disk(key, kernel, func)
+            if execution_backend == "dlpack":
+                self.logger.warning("DLPack backend does not support cache saving to disk.")
+            else:
+                self._save_kernel_to_disk(key, kernel, func)
             return kernel

     def clear_cache(self):
......
@@ -8,7 +8,14 @@ from tvm.script.parser.tir import *
 from tilelang.layout import Layout, Fragment  # noqa: F401
 from .parallel import Parallel  # noqa: F401
 from .pipeline import Pipelined  # noqa: F401
-from .kernel import Kernel, KernelLaunchFrame, get_thread_binding  # noqa: F401
+from .kernel import (
+    Kernel,  # noqa: F401
+    KernelLaunchFrame,  # noqa: F401
+    get_thread_binding,  # noqa: F401
+    get_thread_bindings,  # noqa: F401
+    get_block_binding,  # noqa: F401
+    get_block_bindings,  # noqa: F401
+)
 from .allocate import (
     alloc_local,  # noqa: F401
     alloc_shared,  # noqa: F401
......
@@ -146,6 +146,19 @@ class KernelLaunchFrame(TIRFrame):
             num_threads *= self.get_thread_extent(thread_dim)
         return num_threads

+    def get_block_binding(self, dim: int = 0) -> Var:
+        """
+        Returns the block binding for the given dimension.
+        dim=0 corresponds to blockIdx.x, dim=1 to blockIdx.y, and dim=2 to blockIdx.z.
+        """
+        return self.frames[dim].iter_var.var
+
+    def get_block_bindings(self) -> List[Var]:
+        """
+        Returns all three block bindings.
+        """
+        return [frame.iter_var.var for frame in self.frames[0:-4]]
+
     @property
     def blocks(self) -> List[Var]:
         """
@@ -230,3 +243,15 @@ def get_thread_bindings() -> List[Var]:
     """Returns all three thread bindings.
     """
     return KernelLaunchFrame.Current().get_thread_bindings()
+
+
+def get_block_binding(dim: int = 0) -> Var:
+    """Returns the block binding for the given dimension.
+    """
+    return KernelLaunchFrame.Current().get_block_binding(dim)
+
+
+def get_block_bindings() -> List[Var]:
+    """Returns all three block bindings.
+    """
+    return KernelLaunchFrame.Current().get_block_bindings()
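Together with the `__init__.py` re-exports above, these module-level helpers mirror the existing `get_thread_binding`/`get_thread_bindings` pair and delegate to the innermost `KernelLaunchFrame`, so a `T.macro` that cannot see the `with T.Kernel(...) as (bx, by)` variables can still look up its grid indices directly, exactly as the cumsum macros at the top of this page do. A small usage sketch (a fragment that is only valid inside a `T.Kernel` launch frame, mirroring `block_inclusive_inside_block` above):

# Fragment: usable only inside a T.Kernel launch frame, e.g. in a T.macro
# expanded within the kernel body (as in the cumsum example above).
bx = T.get_block_binding(0)          # blockIdx.x of the enclosing launch
by = T.get_block_binding(1)          # blockIdx.y
tx = T.get_thread_binding(0)         # threadIdx.x
ty = T.get_thread_binding(1)         # threadIdx.y
block_vars = T.get_block_bindings()  # block bindings of the current frame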