Commit 13f4b5c6 authored by Lei Wang, committed by GitHub

[Example] Update GEMM FP8 Example (#123)

* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.
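
For orientation, the core Split-K idea can be sketched in a few lines of plain PyTorch: partition the reduction dimension K into `num_split` chunks, compute one partial GEMM per chunk, and sum the partials at the end, which is the reduction the TileLang kernel performs with atomic adds. This is an illustrative sketch only, not the kernel in `example_tilelang_gemm_splitk.py`; the shapes and names below are made up for the example.

```python
import torch


def splitk_matmul(A: torch.Tensor, B: torch.Tensor, num_split: int = 4) -> torch.Tensor:
    """Compute A @ B by splitting the K dimension into `num_split` partial GEMMs."""
    M, K = A.shape
    assert K % num_split == 0, "this sketch assumes K divides evenly"
    chunk = K // num_split
    # One partial product per K-chunk; the real kernel accumulates these
    # into the output tile with atomic adds instead of stacking them.
    partials = [
        A[:, s * chunk:(s + 1) * chunk] @ B[s * chunk:(s + 1) * chunk, :]
        for s in range(num_split)
    ]
    return torch.stack(partials).sum(dim=0)


A, B = torch.randn(256, 512), torch.randn(512, 128)
torch.testing.assert_close(splitk_matmul(A, B), A @ B, rtol=1e-4, atol=1e-4)
```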

* Refactor GEMM SplitK and StreamK example implementations

Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

This commit introduces comprehensive block sparse attention benchmarks for different libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation

The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking
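
The top-k mask generation can be summarized with a small PyTorch sketch. This only illustrates the idea (mean-pool each block of queries and keys, score block pairs, keep the top-k key blocks per query block); the function name and shapes are assumptions, not the benchmark's actual API.

```python
import torch


def topk_block_mask(q: torch.Tensor, k: torch.Tensor, block_size: int, topk: int) -> torch.Tensor:
    """Return a [batch, heads, n_blocks, n_blocks] boolean mask of key blocks to keep per query block."""
    B, H, S, D = q.shape  # q, k: [batch, heads, seq_len, dim]
    nb = S // block_size
    # Mean-pool each block, then score every (q-block, k-block) pair.
    q_blk = q.reshape(B, H, nb, block_size, D).mean(dim=3)
    k_blk = k.reshape(B, H, nb, block_size, D).mean(dim=3)
    scores = torch.einsum("bhid,bhjd->bhij", q_blk, k_blk)
    idx = scores.topk(topk, dim=-1).indices
    mask = torch.zeros(B, H, nb, nb, dtype=torch.bool, device=q.device)
    mask.scatter_(-1, idx, True)  # True marks a (q-block, k-block) tile to compute
    return mask


mask = topk_block_mask(torch.randn(1, 8, 1024, 64), torch.randn(1, 8, 1024, 64), block_size=64, topk=4)
```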

* Refactor block sparse attention benchmarks with code style improvements

- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module
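
For intuition on the BFLOAT16x2 variant: two bfloat16 values share a single 32-bit word, so a packed atomic add updates both lanes with one 32-bit atomic. The snippet below is only a Python sketch of that packed layout, not the CUDA helper itself.

```python
import torch

# Reinterpret a pair of bf16 values as one 32-bit word, the unit a packed
# BFLOAT16x2 atomic add operates on.
pair = torch.tensor([1.5, -2.25], dtype=torch.bfloat16)
packed = pair.view(torch.int32)          # shape [1]: both bf16 lanes in one word
print(hex(packed.item() & 0xFFFFFFFF))   # packed bit pattern
```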

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
- Group Query Attention (GQA) implementation
- Flash Attention forward pass
- Performance benchmarking
- Configurable parameters for batch, heads, sequence length, and dimension
- Autotuning support
- Reference implementation comparison
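
As a reference point for what the example computes, here is a hedged GQA sketch in plain PyTorch (BSHD layout, non-causal, not the example's exact code): each group of `heads_q // heads_kv` query heads attends to the same shared KV head, expressed here by repeating K and V along the head axis.

```python
import torch
import torch.nn.functional as F


def gqa_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: [batch, seq_q, heads_q, dim]; k, v: [batch, seq_kv, heads_kv, dim]
    groups = q.shape[2] // k.shape[2]
    k = k.repeat_interleave(groups, dim=2)            # share each KV head across its query group
    v = v.repeat_interleave(groups, dim=2)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [batch, heads, seq, dim]
    scores = (q @ k.transpose(-1, -2)) / q.shape[-1]**0.5
    out = F.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2)                        # [batch, seq_q, heads_q, dim]
```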

* Refactor IR lowering pipeline into modular phases

This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
- `LowerAndLegalize`: Handles initial IR legalization and transformation
- `OptimizeForTarget`: Applies target-specific optimizations

The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.
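
A rough sketch of how the two phases compose (the import path and exact signatures below are assumptions for illustration, not a documented API):

```python
# Hypothetical usage of the two-phase pipeline described above.
from tilelang.engine.phase import LowerAndLegalize, OptimizeForTarget  # assumed location of phase.py


def lower_two_phase(mod, target):
    mod = LowerAndLegalize(mod, target)   # phase 1: legalize and transform the IR
    mod = OptimizeForTarget(mod, target)  # phase 2: target-specific optimizations
    return mod
```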

* lint fix

* NSA kernel

* Enhance Native Sparse Attention Examples with Code Improvements and Parameter Updates

- Updated example_tilelang_nsa.py and example_triton_nsa.py with code formatting and style improvements
- Increased default number of heads and selected blocks in TileLang NSA example
- Added Ruff linter ignore comments to reference.py
- Standardized function signatures and improved code readability across NSA implementations

* Add utility math functions for integer operations

- Implement `next_power_of_2()` to calculate the next power of 2 for an integer
- Add `cdiv()` function for ceiling division of integers

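A minimal sketch of the two helpers (an illustrative reimplementation, not necessarily the library's exact code):

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of 2 greater than or equal to n (1 for n <= 1)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def cdiv(a: int, b: int) -> int:
    """Ceiling division for integers."""
    return (a + b - 1) // b


assert next_power_of_2(20) == 32 and cdiv(20, 8) == 3
```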

* Refactor DeepSeek MLA Decode Example with Enhanced Flash Attention Implementation

- Update flash attention kernel to support positional embeddings (PE)
- Modify reference implementation to handle PE and group query attention
- Increase default batch size and adjust benchmarking parameters
- Improve kernel performance and readability
- Add einops and torch operations for more flexible tensor manipulation

* Update README.md with corrected Flash MLA Decoding example path

- Modify the example link for Flash MLA Decoding to point to the correct directory
- Ensure accurate navigation to the DeepSeek MLA decoding example
parent f1fcfe34

@@ -26,7 +26,7 @@ Although tile-lang aims to be portable across a range of Devices, it has been sp
 - [Dequantization GEMM](./examples/dequantize_gemm/)
 - [Flash Attention](./examples/flash_attention/)
 - [Flash Linear Attention](./examples/linear_attention/)
-- [Flash MLA Decoding](./examples/flash_decoding/example_mla_decode.py)
+- [Flash MLA Decoding](./examples/deepseek_mla/)
 - [Native Sparse Attention](./examples/native_sparse_attention/)
 
 Within the `examples` directory, you will also find additional complex kernels—such as convolutions, forward/backward passes for FlashAttention, more operators will continuously be added.

@@ -3,17 +3,13 @@ import torch.nn.functional as F
 import tilelang
 from tilelang.autotuner import *
 import tilelang.language as T
+from einops import rearrange, einsum
 
-num_split = 4
+num_split = 1
 
 
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
-    shape_q = [batch, heads, (dim + pe_dim)]
-    shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
-    shape_v = [batch, seqlen_kv, kv_head_num, dim]
-    shape_o = [batch, heads, dim]
-    part_shape = [batch, heads, num_split, dim]
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // kv_head_num
@@ -22,19 +18,23 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
     @T.macro
     def flash_attn_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
     ):
         with T.Kernel(
-                batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz):
-            Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype)
-            K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype)
-            V_shared = T.alloc_shared([block_N, dim], dtype)
+                batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
+            Q_shared = T.alloc_shared([block_H, dim], dtype)
+            S_shared = T.alloc_shared([block_H, block_N], dtype)
+            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
+            KV_shared = T.alloc_shared([block_N, dim], dtype)
+            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
             O_shared = T.alloc_shared([block_H, dim], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
+            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
@@ -53,20 +53,32 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
             })
 
             T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
+            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
+
             loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
-            for k in T.Pipelined(loop_range, num_stages=1):
-                T.copy(
-                    K[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], K_shared)
-                T.clear(acc_s)
-                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+            for k in T.Pipelined(loop_range, num_stages=2):
+                kv_start = (seqlen_kv // num_split) * sid + k * block_N
+                kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
+
+                T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
+                T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
+
+                T.clear(acc_s_0)
+                T.gemm(
+                    Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(
+                    Q_pe_shared,
+                    K_pe_shared,
+                    acc_s_0,
+                    transpose_B=True,
+                    policy=T.GemmWarpPolicy.FullCol)
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
+                T.copy(acc_s_0, S_shared)
+                T.copy(S_shared, acc_s)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
@@ -78,11 +90,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
                 T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
-                T.copy(
-                    V[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], V_shared)
-                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
             for i, j in T.Parallel(block_H, dim):
                 acc_o[i, j] /= logsum[i]
             for i in T.Parallel(block_H):
@@ -96,8 +104,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
     @T.macro
     def combine(
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
     ):
         with T.Kernel(heads, batch, threads=128) as (by, bz):
             po_local = T.alloc_fragment([dim], dtype)
@@ -133,50 +141,61 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
     @T.prim_func
     def main(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),  # [batch, heads, num_split, dim]
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
     ):
-        flash_attn_split(Q, K, V, glse, Output_partial)
+        flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
         combine(glse, Output_partial, Output)
 
     return main
 
 
-def ref_program(query, key, value, glse, Output_partial):
+def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     # """
     # Inputs:
-    #     - query (Tensor): [batch, heads, dim]
-    #     - key (Tensor): [batch, seqlen_kv, kv_head_num, dim]
-    #     - value (Tensor): [batch, seqlen_kv, kv_head_num, dim]
+    #     - q (Tensor): [batch, heads, dim]
+    #     - q_pe (Tensor): [batch, heads, pe_dim]
+    #     - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
+    #     - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
+    #     - glse (Tensor): [batch, heads, num_split]
+    #     - Output_partial (Tensor): [batch, heads, num_split, dim]
    # Outputs:
     #     - output (Tensor): [batch, heads, dim]
     # """
-    from einops import rearrange
-    batch_size, query_heads, dim = query.shape  # [batch_size, query_heads, dim]
-    _, seqlen_kv, kv_heads, _ = key.shape  # [batch_size, seqlen_kv, kv_heads, kv_dim]
-    dim_v = value.shape[-1]
-    assert kv_heads == 1, "kv_heads must be 1"
-
-    query_expanded = rearrange(query, 'b h d -> b h 1 d')  # [batch_size, query_heads, 1, dim]
-    key_expanded = key.expand(-1, -1, query_heads, -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    value_expanded = value.expand(-1, -1, query_heads,
-                                  -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    key_expanded = rearrange(key_expanded,
-                             'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, dim]
-    value_expanded = rearrange(value_expanded,
-                               'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim]
-    scores = torch.matmul(query_expanded,
-                          key_expanded.transpose(-1, -2))  # [batch_size, query_heads, 1, seqlen_kv]
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
-    attention_weights = F.softmax(scores, dim=-1)  # [batch_size, query_heads, 1, seqlen_kv]
-    output = torch.matmul(attention_weights, value_expanded)  # [batch_size, query_heads, 1, dim]
-    return output.view(batch_size, query_heads, dim_v)
+    dim = q.shape[-1]
+    pe_dim = q_pe.shape[-1]
+    num_head_groups = q.shape[1] // kv.shape[2]
+    scale = (dim + pe_dim)**0.5
+    q = rearrange(
+        q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+
+    q_pe = rearrange(
+        q_pe, 'b (h g) d -> b g h d',
+        g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
+
+    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+
+    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
+
+    query = torch.concat([q, q_pe], dim=-1)
+    key = torch.concat([kv, k_pe], dim=-1)
+
+    scores = einsum(
+        query, key,
+        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+
+    attention = F.softmax(
+        scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
+
+    out = einsum(attention, kv,
+                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
+    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    return out
 
 
 def flash_split_ref(Q, K, V):
@@ -251,7 +270,7 @@ def reduce_ref(Q, K, V, glse, Output_partial):
 
 
 if __name__ == "__main__":
-    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64
+    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 128, 128, 1, 8192, 512, 64
     qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
     pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
     total_flops = qk_flops + pv_flops
@@ -260,8 +279,9 @@ if __name__ == "__main__":
     program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
     mod, params = tilelang.lower(program)
-    mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
+    mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Normal)
     mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
-    latency = mod.do_bench(mod.func, warmup=500)
+    print("All close")
+    latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
\ No newline at end of file

import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum

num_split = 1


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"

    @T.macro
    def flash_attn_split(
            Q: T.Buffer([batch, heads, dim], dtype),
            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
    ):
        with T.Kernel(
                batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            KV_shared = T.alloc_shared([block_N, dim], dtype)
            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            bid = bx
            hid = by
            sid = bz
            cur_kv_head = hid // (kv_group_num // block_H)

            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
            })

            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            for k in T.Pipelined(loop_range, num_stages=2):
                kv_start = (seqlen_kv // num_split) * sid + k * block_N
                kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N

                T.copy(
                    KV[bid, kv_start:kv_end, cur_kv_head, :],
                    KV_shared
                )
                T.copy(
                    K_pe[bid, kv_start:kv_end, cur_kv_head, :],
                    K_pe_shared
                )

                T.clear(acc_s_0)
                T.gemm(Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(Q_pe_shared, K_pe_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.copy(acc_s_0, S_shared)
                T.copy(S_shared, acc_s)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

            T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
                                            sid, :])

    @T.macro
    def combine(
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
            Output: T.Buffer([batch, heads, dim], dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
            o_accum_local = T.alloc_fragment([dim], accum_dtype)
            lse_local = T.alloc_fragment([num_split, 1], dtype)
            lse_local_split = T.alloc_local([1], accum_dtype)
            lse_logsum_local = T.alloc_local([1], accum_dtype)
            lse_max_local = T.alloc_fragment([1], accum_dtype)
            scale_local = T.alloc_local([1], accum_dtype)

            T.annotate_layout({
                lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
            })

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            for k in T.Parallel(num_split):
                lse_local[k, 0] = glse[bz, by, k]
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, by, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            for k in T.serial(num_split):
                for i in T.Parallel(dim):
                    po_local[i] = Output_partial[bz, by, k, i]
                lse_local_split[0] = glse[bz, by, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dim):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dim):
                Output[bz, by, i] = o_accum_local[i]

    @T.prim_func
    def main(
            Q: T.Buffer([batch, heads, dim], dtype),
            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
            Output: T.Buffer([batch, heads, dim], dtype),
    ):
        flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
        combine(glse, Output_partial, Output)

    return main


def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    # """
    # Inputs:
    #     - q (Tensor): [batch, heads, dim]
    #     - q_pe (Tensor): [batch, heads, pe_dim]
    #     - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    #     - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    #     - glse (Tensor): [batch, heads, num_split]
    #     - Output_partial (Tensor): [batch, heads, num_split, dim]
    # Outputs:
    #     - output (Tensor): [batch, heads, dim]
    # """
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    num_head_groups = q.shape[1] // kv.shape[2]
    scale = (dim + pe_dim) ** 0.5
    q = rearrange(
        q, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
    q_pe = rearrange(
        q_pe, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
    query = torch.concat([q, q_pe], dim=-1)
    key = torch.concat([kv, k_pe], dim=-1)
    scores = einsum(
        query, key,
        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
    attention = F.softmax(
        scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
    out = einsum(attention, kv,
                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
    return out


def flash_split_ref(Q, K, V):
    dim = 512
    pe_dim = 64
    batch = Q.size(0)
    nheads = Q.size(1)
    assert Q.size(2) == dim + pe_dim, "dim must be 576=512+64"
    block_N = 32
    seqlen_kv = K.size(1)

    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float)
    scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
    glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)

    Q_ = Q * scale
    K_ = K.expand(-1, -1, nheads, -1)
    V_ = V.expand(-1, -1, nheads, -1)

    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
        scores_max.fill_(float('-inf'))
        scores_max_prev.fill_(float('-inf'))
        for i in range(int((seqlen_kv // num_split) / block_N)):
            acc_s.fill_(0)
            acc_s = torch.einsum('bhd,bkhd->bhk', Q_,
                                 K_[:, (seqlen_kv // num_split) * ks +
                                    i * block_N:(seqlen_kv // num_split) * ks +
                                    (i + 1) * block_N, :, :])  # [batch, nheads, block_N]
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [batch, nheads]
            scores_scale = torch.exp2(scores_max_prev - scores_max)  # [batch, nheads]
            acc_o *= scores_scale[:, :, None]
            acc_s = torch.exp2(acc_s - scores_max[:, :, None])
            acc_s_cast = acc_s.to(torch.float16)  # [batch, nheads, block_N]
            acc_o += torch.einsum(
                'bhk,bkhd->bhd', acc_s_cast,
                V_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
                   (i + 1) * block_N, :, :])
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o /= logsum[:, :, None]
        logsum = torch.log2(logsum) + scores_max
        gacc_o[ks, :, :, :] = acc_o
        glogsum[ks, :, :] = logsum

    return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)


def reduce_ref(Q, K, V, glse, Output_partial):
    o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
    lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0)
    lse_max = glse.max(dim=2, keepdim=False).values
    for ks in range(num_split):
        lse = glse[:, :, ks]
        lse_logsum += torch.exp2(lse - lse_max)
    lse_logsum = torch.log2(lse_logsum) + lse_max
    for ks in range(num_split):
        lse = glse[:, :, ks]
        scale = torch.exp2(lse - lse_logsum)
        o += Output_partial[:, :, ks, :] * scale[:, :, None]
    return o.to(torch.float16)


if __name__ == "__main__":
    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 128, 128, 1, 8192, 512, 64
    qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
    pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
    total_flops = qk_flops + pv_flops
    BLOCK_N = 32  # if D_HEAD <= 128 else 32
    BLOCK_H = 64

    program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
    mod, params = tilelang.lower(program)
    mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Normal)
    mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("All close")
    latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
\ No newline at end of file

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)


@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "e4m3_float8",
        "e5m2_float8",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
    if out_dtype == "int32" or is_float8:
        micro_size_k = 32

    # This is a debug config
    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32
    chunk = 32 if in_dtype == "float16" else 64
    shared_scope = "shared.dyn"

    # Pipeline Stage
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 32
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (micro_size_x * micro_size_k) // warp_size
    local_size_b = (micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mma_emitter = TensorCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):

                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):

                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mma_emitter.mma(A_local, B_local, C_local)

            # Perform STMatrix
            mma_emitter.stmatrix(
                C_local,
                C_shared,
            )

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    i // micro_size_x,
                    j // micro_size_y,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return main


def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
    mod, params = TL.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
    print(src_code)
    # src_code is the generated cuda source
    assert src_code is not None

    def map_torch_type(intype):
        typemap = {
            'e4m3_float8': torch.float8_e4m3fn,
            'e5m2_float8': torch.float8_e5m2,
        }
        if intype in typemap:
            return typemap[intype]
        else:
            return getattr(torch, intype)

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)
    accum_dtype = map_torch_type(accum_dtype)

    if in_dtype in {torch.int8, torch.int32}:
        A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
        B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
    elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
        A = torch.randn(M, K).to(in_dtype).cuda()
        B = torch.randn(N, K).to(in_dtype).cuda()
    else:
        A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
        B = torch.randn(N, K).to(in_dtype).cuda() - 0.5

    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)

    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)

    mod(A, B, C)

    latency = mod.do_bench(mod.func, warmup=25)

    # Ensure that the latency is not None
    assert latency is not None

    # Get Reference Result
    ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype)
    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
    assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")


if __name__ == "__main__":
    tilelang.testing.main()