Commit 3ca3a8af authored by Lei Wang, committed by LeiWang1999

[Refactor] Include several examples into ci (#531)

* Remove unused 2D continuous cumulative sum example and related functions from the cumsum module.

* lint fix

* fix split k example

* Disable the compile cache in the gemm_streamk example and add validation checks to the if_stmt_binding transformation

* Update the gemm_streamk example to use tilelang's cdiv helper for block-count calculations and add a copyright notice (a short usage sketch of cdiv follows below)
parent 5b9015a3
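The cdiv change in the stream-K example swaps a hand-rolled ceiling division for the helper exposed at the tilelang package level. A minimal sketch of the intended usage (hedged: it assumes tilelang.cdiv(a, b) returns the integer ceiling of a / b, which is what the stream-K diff below relies on; the concrete numbers here are illustrative only):

import tilelang

M, BLOCK_SIZE_M = 1000, 128
# ceiling division: how many BLOCK_SIZE_M-sized tiles are needed to cover M rows
num_block_m = tilelang.cdiv(M, BLOCK_SIZE_M)  # 8
assert num_block_m == (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M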
import math
from typing import Optional
import torch
import tilelang
import tilelang.language as T
from tilelang.cache import clear_cache
clear_cache()
def _is_power_of_two(n: int):
"""Check if n is a power of 2."""
return n > 0 and (n & (n - 1)) == 0
def gpu_2d_continuous_cumsum(
M: int,
N: int,
ty_len: int = 4,
tx_len: int = 32,
in_dtype: str = "int32",
out_dtype: Optional[str] = None,
):
"""Generate GPU kernel for 2D continuous cumsum, i.e. The cumsum axis is -1
Parameters
----------
M : int
The number of rows of the input tensor
N : int
The number of columns of the input tensor
    ty_len : int
        The number of threads along threadIdx.y
    tx_len : int
        The number of threads along threadIdx.x
in_dtype : str
The input data type
out_dtype : Optional[str]
The output data type, if None, it will be the same as in_dtype
Returns
-------
cumsum : PrimFunc
The generated cumsum kernel
"""
out_dtype = out_dtype or in_dtype
# Configuration for GPU kernel
TX = T.int32(tx_len) # thread.x
TY = T.int32(ty_len) # thread.y
    thread_elem = N  # number of elements processed by a single thread
    if not (_is_power_of_two(tx_len) and _is_power_of_two(ty_len) and _is_power_of_two(N)):
        raise ValueError("tx_len, ty_len and N must all be powers of 2")
    # number of elements to be processed by a single warp
    warp_elem = T.int32(tx_len * thread_elem)
    # number of elements to be processed by a single block (SM)
    block_elem = T.int32(tx_len * ty_len * thread_elem)
LOG_TX = T.int32(int(math.log2(tx_len)))
LOG_BLOCK_N = T.int32(int(math.log2(tx_len * ty_len * thread_elem)))
@T.macro
def block_inclusive_inside_block(
batch: T.int32,
cur_len: T.int32,
source: T.Tensor,
output: T.Tensor,
tmp_buf: T.Tensor,
src_offset: T.int32,
tmp_offset: T.int32,
):
local_buf = T.alloc_buffer((thread_elem,), out_dtype, scope="local")
shared_buf = T.alloc_buffer((block_elem,), out_dtype, scope="shared")
bx = T.get_block_binding(0)
by = T.get_block_binding(1)
tx = T.get_thread_binding(0)
ty = T.get_thread_binding(1)
tx_idx = bx * block_elem + ty * warp_elem + tx * thread_elem
# Load data from global memory
for i in T.vectorized(N):
local_buf[i] = T.if_then_else(
tx_idx + i < cur_len,
T.Cast(out_dtype, source[by, src_offset + tx_idx + i]),
T.Cast(out_dtype, 0),
)
# Inclusive scan inside thread
for i in T.serial(1, N):
local_buf[i] += local_buf[i - 1]
# Store data to shared memory
for i in T.vectorized(N):
shared_buf[ty * warp_elem + tx * thread_elem + i] = local_buf[i]
# Inclusive scan inside warp
for i in T.serial(LOG_TX):
for j in T.vectorized(N):
idx: T.int32 = ty * warp_elem + tx * thread_elem
if tx >= (1 << i):
shared_buf[idx + j] += shared_buf[idx - (1 << i) * thread_elem + N - 1]
# Inclusive scan inside block
for i in T.serial(1, TY):
for j in T.vectorized(N):
if ty == 0:
idx: T.int32 = i * warp_elem + tx * thread_elem
shared_buf[idx + j] += shared_buf[i * warp_elem - 1]
# Write sum of block to global memory
for i in T.vectorized(N):
idx: T.int32 = ty * warp_elem + tx * thread_elem + i
if bx * block_elem + idx < cur_len:
output[by, src_offset + bx * block_elem + idx] = shared_buf[idx]
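        # shared_buf[block_elem - 1] now holds this block's total; thread (0, 0)
        # publishes it to tmp_buf so later rounds can scan the per-block sums.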
if tx == 0 and ty == 0:
for i in T.vectorized(N): # noqa: B007
tmp_buf[by, tmp_offset + bx] = shared_buf[block_elem - 1]
@T.macro
def update_cross_block(
batch: T.int32,
cur_len: T.int32,
source: T.Tensor,
output: T.Tensor,
src_offset: T.int32,
out_offset: T.int32,
):
bx = T.get_block_binding(0)
by = T.get_block_binding(1)
tx = T.get_thread_binding(0)
ty = T.get_thread_binding(1)
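        # Each element owned by this thread gets the inclusive sum of all preceding
        # blocks (source[by, src_offset + bx - 1]) added to it; block 0 needs no fix-up.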
for i in T.serial(N):
idx: T.int32 = bx * block_elem + ty * warp_elem + i * TX + tx
if idx < cur_len:
output[by, out_offset + idx] += T.if_then_else(bx > 0,
source[by, src_offset + bx - 1], 0)
@T.prim_func
def cumsum(A: T.Tensor((M, N), dtype="int32"), Out: T.Tensor((M, N), dtype="int32"),
Tmp: T.Tensor((M, N), dtype="int32")):
ceil_log2 = T.Cast("int32", T.ceil(T.log2(T.Cast("float32", N))))
total_rounds = ceil_log2 // LOG_BLOCK_N
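        # Each extra round scans the per-block totals produced by the previous one;
        # total_rounds is how many such reductions are needed before the remaining
        # partial-sum array fits into a single block.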
with T.Kernel(T.ceildiv(N, block_elem), M, threads=[tx_len, ty_len]) as (bx, by):
block_inclusive_inside_block(
M, N, A, Out, Tmp, src_offset=T.int32(0), tmp_offset=T.int32(0))
for i in range(total_rounds):
cur_len = T.ceildiv(N, 1 << (LOG_BLOCK_N * (i + 1)))
with T.Kernel(T.ceildiv(cur_len, block_elem), M) as (bx, by):
block_inclusive_inside_block(
M,
cur_len,
Tmp,
Tmp,
Tmp,
src_offset=i * T.ceildiv(N, block_elem),
tmp_offset=(i + 1) * T.ceildiv(N, block_elem),
)
for i in range(total_rounds - 1):
real_idx = total_rounds - 1 - i - 1
cur_len = T.ceildiv(N, 1 << (LOG_BLOCK_N * (real_idx + 1)))
with T.Kernel(T.ceildiv(cur_len, block_elem), M) as (bx, by):
update_cross_block(
M,
cur_len,
Tmp,
Tmp,
src_offset=(real_idx + 1) * T.ceildiv(N, block_elem),
out_offset=real_idx * T.ceildiv(N, block_elem),
)
with T.Kernel(T.ceildiv(N, block_elem), M) as (bx, by):
update_cross_block(M, N, Tmp, Out, src_offset=0, out_offset=0)
return cumsum
def torch_cumsum(A: torch.Tensor, dim: int = -1):
return torch.cumsum(A, dim=dim)
if __name__ == "__main__":
M = 128
N = 32
program = gpu_2d_continuous_cumsum(M, N)
kernel = tilelang.compile(program, execution_backend="dlpack", out_idx=[1])
code = kernel.get_kernel_source()
A = torch.randint(0, 10, (M, N)).cuda().to(torch.int32)
tmp = torch.zeros_like(A).cuda().to(torch.int32)
tilelang_output = kernel(A, tmp)
torch_output = torch_cumsum(A).cuda().to(torch.int32)
torch.testing.assert_close(tilelang_output, torch_output, atol=1e-2, rtol=1e-2)
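For intuition, the same three-phase blocked scan can be written directly in PyTorch on the last axis. This is only an illustrative reference, not part of the kernel above; the function name and block size are ours:

import torch

def blocked_cumsum_reference(x: torch.Tensor, block: int = 128) -> torch.Tensor:
    # Phase 1: inclusive scan inside each block of the last axis.
    M, N = x.shape
    pad = (-N) % block
    xp = torch.nn.functional.pad(x, (0, pad))  # pad N up to a multiple of block
    blocks = xp.view(M, -1, block)
    local = blocks.cumsum(dim=-1)              # scan within each block
    # Phase 2: scan the per-block totals.
    totals = local[..., -1].cumsum(dim=-1)
    # Phase 3: add the sum of all preceding blocks back to each block's elements.
    offsets = torch.cat([torch.zeros_like(totals[..., :1]), totals[..., :-1]], dim=-1)
    return (local + offsets.unsqueeze(-1)).view(M, -1)[:, :N]

# For the integer inputs above this matches torch.cumsum(x, dim=-1).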
@@ -178,6 +178,10 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp
     print(f"tflops: {tflops}")
 
+
+def main():
+    assert_tl_gemm_correctness(1024, 1024, 8192, 128, "e4m3_float8", "bfloat16", "float32")
+
 if __name__ == "__main__":
     for dtype in ["e4m3_float8"]:
         for out_dtype in ["bfloat16", "float32"]:
......
import tilelang.testing
from example_deepgemm_fp8_2xAcc import main
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_deepgemm_fp8_2xAcc():
main()
if __name__ == "__main__":
tilelang.testing.main()
@@ -126,7 +126,7 @@ def native_sparse_attention(
     return native_sparse_attention
 
-if __name__ == "__main__":
+def main():
     B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16
     groups = HQ // H
     SEQ_LEN_Q = 1
@@ -170,3 +170,7 @@ if __name__ == "__main__":
     print("out", out)
     print("ref", ref)
     torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
+
+
+if __name__ == "__main__":
+    main()
@@ -125,7 +125,7 @@ def native_sparse_attention(batch,
     return native_sparse_attention
 
-if __name__ == "__main__":
+def main():
     B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1
     program = native_sparse_attention(
@@ -175,3 +175,7 @@ if __name__ == "__main__":
     print("out", out)
     print("ref", ref)
     torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2)
+
+
+if __name__ == "__main__":
+    main()
# ruff: noqa
import tilelang.testing
from example_tilelang_nsa_fwd import main as main_fwd
from example_tilelang_nsa_decode import main as main_fwd_decode
def test_example_tilelang_nsa_fwd():
main_fwd()
def test_example_tilelang_nsa_fwd_decode():
main_fwd_decode()
if __name__ == "__main__":
tilelang.testing.main()
 import tilelang
 import tilelang.language as T
-from tvm import DataType
 
-def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float"):
+def matmul(M,
+           N,
+           K,
+           block_M,
+           block_N,
+           block_K,
+           split_k,
+           dtype="float16",
+           accum_dtype="float",
+           out_dtype="float32"):
     splitK = K // split_k
@@ -11,21 +19,15 @@ def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_d
     def main(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), dtype),
+            C: T.Tensor((M, N), out_dtype),
     ):
         with T.Kernel(
                 T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
-           C_shared = T.alloc_shared((block_M, block_N), dtype)
+           C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-           if bz == 0:
-               # fuse the zero initialization kernel
-               for i, j in T.Parallel(block_M, block_N):
-                   m, n = by * block_M + i, bx * block_N + j
-                   C[m, n] = T.cast(0, dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0):
                T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
@@ -34,34 +36,45 @@ def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_d
            T.copy(C_local, C_shared)
-           if DataType(dtype).bits == 16:
-               for i, j in T.Parallel(block_M, block_N // 2):
-                   m, n = by * block_M + i, bx * block_N + j * 2
-                   # vectorized atomic
-                   T.atomic_addx2(C[m, n], C_shared[i, j * 2])
-           else:
-               for i, j in T.Parallel(block_M, block_N):
-                   T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
+           # TODO: Automatically add vectorized atomic with enhancement
+           # https://github.com/tile-ai/tilelang/issues/523
+           # if DataType(dtype).bits == 16:
+           #     for i, j in T.Parallel(block_M, block_N // 2):
+           #         m, n = by * block_M + i, bx * block_N + j * 2
+           #         # vectorized atomic
+           #         T.atomic_addx2(C[m, n], C_shared[i, j * 2])
+           for i, j in T.Parallel(block_M, block_N):
+               T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
 
     return main
 
-program = matmul(1024, 1024, 1024, 128, 128, 32, 4)
-kernel = tilelang.compile(program)
-print(kernel.get_kernel_source())
-import torch
-
-a = torch.randn(1024, 1024).cuda().half()
-b = torch.randn(1024, 1024).cuda().half()
-c = torch.zeros(1024, 1024).cuda().half()
-kernel(a, b, c)
-ref_c = a @ b
-print(c)
-print(ref_c)
-torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+def main():
+    M = 1024
+    N = 1024
+    K = 1024
+    block_M = 128
+    block_N = 128
+    block_K = 32
+    split_k = 4
+    program = matmul(M, N, K, block_M, block_N, block_K, split_k)
+    kernel = tilelang.compile(program)
+    import torch
+    torch.random.manual_seed(42)
+    a = torch.randn(M, K).cuda().half()
+    b = torch.randn(K, N).cuda().half()
+    c = torch.zeros(M, N).cuda().float()
+    kernel(a, b, c)
+
+    ref_c = a @ b
+
+    torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
+
+
+if __name__ == "__main__":
+    main()
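Split-K partitions the K dimension across split_k thread blocks, each of which accumulates its partial product into the same output tile via atomic adds; that is why the host now passes a zero-initialized float32 C instead of relying on the removed in-kernel zero fill. A rough PyTorch sketch of the same accumulation scheme (illustrative only; the function name and the (M, K) x (K, N) layout are ours, not the kernel's):

import torch

def splitk_reference(a: torch.Tensor, b: torch.Tensor, split_k: int) -> torch.Tensor:
    # a: (M, K), b: (K, N); each split adds its partial product, mirroring T.atomic_add.
    M, K = a.shape
    chunk = K // split_k
    c = torch.zeros(M, b.shape[1], dtype=torch.float32, device=a.device)  # must start at zero
    for kz in range(split_k):
        ak = a[:, kz * chunk:(kz + 1) * chunk].float()
        bk = b[kz * chunk:(kz + 1) * chunk, :].float()
        c += ak @ bk
    return c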
import tilelang.testing
from example_tilelang_gemm_splitk import main
def test_example_tilelang_gemm_splitk():
main()
if __name__ == "__main__":
tilelang.testing.main()
@@ -32,9 +32,9 @@ M, K = A.shape
 N, K = B.shape
 
 # accumulator types
 # compute grid (work to do per SM on the first wave)
-num_block_m = cdiv(M, BLOCK_SIZE_M)
-num_block_n = cdiv(N, BLOCK_SIZE_N)
-iters_per_tile = cdiv(K, BLOCK_SIZE_K)
+num_block_m = tilelang.cdiv(M, BLOCK_SIZE_M)
+num_block_n = tilelang.cdiv(N, BLOCK_SIZE_N)
+iters_per_tile = tilelang.cdiv(K, BLOCK_SIZE_K)
 total_tiles = num_block_m * num_block_n
 
 # Two-tile SK + DP
@@ -169,32 +169,37 @@ def tl_matmul_streamk(
     return main
 
-_tl_matmul_streamk = tl_matmul_streamk(
-    m,
-    n,
-    k,
-    streamk_tiles,
-    BLOCK_SIZE_M,
-    BLOCK_SIZE_N,
-    BLOCK_SIZE_K,
-    False,
-    True,
-    "float16",
-    "float16",
-    "float32",
-    2,
-    64,
-)
-
-kernel = tilelang.compile(_tl_matmul_streamk)
-print(kernel.get_kernel_source())
-
-b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16)
-
-kernel(A, B, b_c)
-
-C = torch.matmul(A, B.T)
-
-print(b_c)
-print(C)
-torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2)
+def main():
+    _tl_matmul_streamk = tl_matmul_streamk(
+        m,
+        n,
+        k,
+        streamk_tiles,
+        BLOCK_SIZE_M,
+        BLOCK_SIZE_N,
+        BLOCK_SIZE_K,
+        False,
+        True,
+        "float16",
+        "float16",
+        "float32",
+        2,
+        64,
+    )
+
+    kernel = tilelang.compile(_tl_matmul_streamk)
+    print(kernel.get_kernel_source())
+
+    b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16)
+
+    kernel(A, B, b_c)
+
+    C = torch.matmul(A, B.T)
+
+    print(b_c)
+    print(C)
+    torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2)
+
+
+if __name__ == "__main__":
+    main()
import tilelang.testing
from example_tilelang_gemm_streamk import main
def test_example_tilelang_gemm_streamk():
main()
if __name__ == "__main__":
tilelang.testing.main()
@@ -32,8 +32,10 @@ private:
     auto then_case = VisitStmt(op->then_case);
     Optional<Stmt> else_case = op->else_case;
     if (else_case.defined()) {
-      else_case = VisitStmt(else_case.value());
+      return GetRef<Stmt>(op);
     }
+    ICHECK(then_case.defined()) << "then_case must be defined";
+    ICHECK(!else_case.defined()) << "else_case must be undefined";
 
     auto bind_if_stmt = [](Optional<Stmt> body,
                            const PrimExpr condition) -> Stmt {
@@ -58,9 +60,6 @@ private:
     if (then_case.defined()) {
       new_seq.push_back(bind_if_stmt(then_case, condition));
     }
-    if (else_case.defined()) {
-      new_seq.push_back(bind_if_stmt(else_case, !condition));
-    }
     return new_seq.size() == 1 ? new_seq[0] : SeqStmt(std::move(new_seq));
   }
......
@@ -133,6 +133,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tir.transform.InferFragment()(mod)
     mod = tir.transform.LowerThreadAllreduce()(mod)
     mod = tilelang.transform.LowerHopperIntrin()(mod)
+    mod = tilelang.transform.ThreadSync("global")(mod)
     # Global Barrier Synchronization must be applied before
     # SplitHostDevice pass, as the global barrier
......