"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "a711c80e024e3950982ac9e58e2244057db6aefb"
Unverified commit ac8c9afc authored by Zhichen Zeng, committed by GitHub

[Example] Add sparse mla bwd example for deepseek_v32 (#919)



* Add sparse mla bwd example

* add bwd into test

* Update README with bwd impl

* comment

* format fix

* lint fix

* fwd fix

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 481cae42
...@@ -6,6 +6,7 @@ deepseek_v32/
├── figures/                      # Figures and diagrams
├── inference/                    # Inference implementation folder
├── fp8_lighting_indexer.py       # FP8 lighting indexer
├── sparse_mla_bwd.py             # Sparse MLA backward implementation
├── sparse_mla_fwd.py             # Sparse MLA forward implementation
├── sparse_mla_fwd_pipelined.py   # Pipelined implementation of sparse MLA forward pass
├── topk_selector.py              # Top-k selector implementation
...@@ -21,7 +22,7 @@ The architecture diagram above highlights three key components (shown in green)
1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py`, `sparse_mla_fwd_pipelined.py`, and `sparse_mla_bwd.py`) - Core attention mechanism implementation with sparse MLA (Multi-Latent Attention) forward and backward passes
### Lightning Indexer
...@@ -166,3 +167,57 @@ for i_i in T.serial(T.ceildiv(NI, 2)):
```
Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
### Sparse MLA Backward
The Sparse MLA backward kernel (`sparse_mla_bwd.py`) computes gradients with respect to queries (dQ) and key-values (dKV) for the sparse attention mechanism. Like the forward pass, it processes only the selected top-k indices, maintaining O(seq_len * topk) complexity.
The backward pass consists of three main stages:
**1. Preprocessing**: Computes delta values (row-wise dot products of output and output gradient):
```python
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o)
T.copy(dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], do)
for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
```
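In dense PyTorch the same quantity is just a row-wise dot product. A minimal reference sketch (the function name is illustrative; the `[B, S, H, D]` layout follows the kernel):

```python
import torch

def delta_reference(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: [B, S, H, D] -> Delta: [B, S, H], Delta = sum_d O * dO (accumulated in fp32)
    return (o.float() * do.float()).sum(dim=-1)
```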
**2. Main Backward Computation**: Computes gradients through sparse attention:
```python
# Sparse MLA backward: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
# Load KV data for selected indices
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i]
# Recompute attention scores for backward
T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
# Apply softmax gradient: dP = P * (dP_raw - Delta)
for h_i, bi_i in T.Parallel(padded_H, BI):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
```
The key gradient computations are:
- **dQ = dP @ K** (query gradients)
- **dK = dP^T @ Q** (key gradients)
- **dV = P^T @ dO** (value gradients)
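In dense form, these three matmuls together with the softmax gradient look as follows. This is an illustrative PyTorch sketch for a single query position with hypothetical tensors `q`, `k`, `v`, `p`, `do`, `delta`; the actual kernel operates only on the gathered top-k rows and shares K and V in one latent `KV` buffer:

```python
import torch

def attention_bwd_reference(q, k, v, p, do, delta, sm_scale):
    # q: [H, Dqk], k: [N, Dqk], v: [N, Dv], p: [H, N] softmax probabilities,
    # do: [H, Dv], delta: [H] row-wise sum of O * dO (from the preprocess step)
    dp_raw = do @ v.T                              # [H, N]
    dp = p * (dp_raw - delta[:, None]) * sm_scale  # softmax gradient: dP = P * (dP_raw - Delta)
    dq = dp @ k                                    # dQ = dP @ K
    dk = dp.T @ q                                  # dK = dP^T @ Q
    dv = p.T @ do                                  # dV = P^T @ dO
    return dq, dk, dv
```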
**3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation:
```python
# Atomically update dKV at selected indices
for bi_i, d_i in T.Parallel(BI // split_store, D // 4):
T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4],
acc_dkv_shared[bi_i, d_i * 4])
```
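Atomics are needed because different query positions can select the same KV rows, so their contributions to dKV must be summed. In dense PyTorch this corresponds to a scatter-add into an fp32 accumulator (an illustrative sketch with hypothetical names; the kernel likewise accumulates dKV in fp32 and a separate `postprocess` kernel casts it back to bf16):

```python
import torch

def accumulate_dkv_reference(indices: torch.Tensor, dkv_contrib: torch.Tensor, s_kv: int) -> torch.Tensor:
    # indices:     [S, topk]     KV rows selected by each query position
    # dkv_contrib: [S, topk, D]  per-index dKV contribution from each query position
    S, topk, D = dkv_contrib.shape
    dkv = torch.zeros(s_kv, D, dtype=torch.float32)
    # index_add_ sums contributions that hit the same KV row, mirroring the per-element atomic adds on the GPU
    dkv.index_add_(0, indices.reshape(-1).long(), dkv_contrib.reshape(-1, D).float())
    return dkv
```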
**Performance**: The sparse MLA backward achieves excellent performance:
- **H800 SXM**: ~100 TFlops
- **H200 SXM**: ~115 TFlops
The implementation handles the irregular memory access patterns inherent in sparse attention through careful memory management and atomic update strategies while keeping compute utilization high. Note, however, that this is still a relatively naive implementation and leaves room for further optimization.
# ruff: noqa
import tilelang
from tilelang import language as T
import torch
from utils import assert_tensors_similar
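# Preprocess kernel: Delta[b, s, h] = sum_d O[b, s, h, d] * dO[b, s, h, d],
# the row-wise O·dO term consumed by the softmax backward.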
@tilelang.jit(out_idx=[-1])
def preprocess(
B,
S,
H,
D,
block_ND=32,
num_stages=5,
dtype="bfloat16",
accum_dtype="float",
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
shape = [B, S, H, D]
@T.prim_func
def preprocess_kernel(
O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype),
Delta: T.Tensor([B, S, H], accum_dtype),
):
with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz):
o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
do = T.alloc_fragment([block_ND, block_ND], accum_dtype)
delta = T.alloc_fragment([block_ND], accum_dtype)
acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
T.clear(acc)
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(
O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND],
o)
T.copy(
dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND],
do)
for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, by * block_ND:(by + 1) * block_ND, bx])
return preprocess_kernel
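# Postprocess kernel: cast the fp32 dKV accumulator back to bf16 for the final output.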
@tilelang.jit(out_idx=[-1])
def postprocess(
B,
S_kv,
D,
D_tail,
kv_group=1,
block_N=64,
threads=128,
dtype="bfloat16",
accum_dtype="float",
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
dkv_shape = [B, S_kv, kv_group, D + D_tail]
@T.prim_func
def postprocess_kernel(
dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype),
):
with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz):
T.copy(
dKV[bz, bx * block_N:(bx + 1) * block_N, by, :],
dKV_out[bz, bx * block_N:(bx + 1) * block_N, by, :],
)
return postprocess_kernel
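# Main backward kernel: recompute the attention probabilities from Q, KV, and LSE,
# apply the softmax gradient, write dQ directly, and accumulate dKV with atomic adds.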
@tilelang.jit(
out_idx=[-2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def bwd(
B,
S,
S_kv,
H,
D,
D_tail,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
block_size=32,
num_stages=0,
threads=256,
indices_dtype="int32",
dtype="bfloat16",
accum_dtype="float",
):
    assert is_causal, 'non-causal attention is not supported yet'
    assert topk % block_size == 0, 'topk must be a multiple of block_size; otherwise padded slots default to index 0 and load the wrong KV rows'
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert indices_dtype == "int32"
if sm_scale is None:
sm_scale = (D + D_tail)**(-0.5)
sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e)
H_kv = H // kv_group
q_shape = [B, S, H, D + D_tail]
k_shape = [B, S_kv, kv_group, D + D_tail]
o_shape = [B, S, H, D]
indices_shape = [B, S, kv_group, topk]
delta_shape = [B, S, H]
lse_shape = [B, S, H]
assert indices_dtype == "int32"
assert dtype == "bfloat16"
assert accum_dtype == "float"
H = H_kv
padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
BS = block_size
NS = tilelang.cdiv(topk, block_size)
split_store = 2
@T.prim_func
def sparse_mla_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype),
dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype),
):
with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz):
Q_shared = T.alloc_shared([padded_H, D], dtype)
Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
KV_shared = T.alloc_shared([BS, D], dtype)
KV_tail_shared = T.alloc_shared([BS, D_tail], dtype)
dO_shared = T.alloc_shared([padded_H, D], dtype)
mask = T.alloc_fragment([BS], "bool")
P_shared_cast = T.alloc_shared([padded_H, BS], dtype)
dP_shared_cast = T.alloc_shared([padded_H, BS], dtype)
dQ_shared = T.alloc_shared([padded_H, D], dtype)
dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
acc_p = T.alloc_fragment([padded_H, BS], accum_dtype)
acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype)
acc_dq = T.alloc_fragment([padded_H, D], accum_dtype)
acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
acc_dkv_tail_shared = T.view(
KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
max_kv_i = s_i
T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared)
T.clear(acc_dq)
T.clear(acc_dq_tail)
T.annotate_layout({
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
})
# Process each block of indices
for i_i in T.Pipelined(NS, num_stages=num_stages):
# Check which indices are valid
for bi_i in T.Parallel(BS):
mask[bi_i] = Indices[by, s_i, bz, i_i * BS + bi_i] <= max_kv_i
# Compute attention scores
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype))
# Load KV, V for this block of indices
for bi_i, d_i in T.Parallel(BS, D):
KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i]
T.gemm(
Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for bi_i, d_i in T.Parallel(BS, D_tail):
KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz,
D + d_i]
T.gemm(
Q_tail_shared,
KV_tail_shared,
acc_p,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 -
Lse[by, s_i, bz * padded_H + h_i])
T.copy(acc_p, P_shared_cast)
T.gemm(
dO_shared,
KV_shared,
acc_dp,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (
acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
T.copy(acc_dp, dP_shared_cast)
T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
dP_shared_cast,
Q_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
P_shared_cast,
dO_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
T.clear(acc_dkv_tail)
T.gemm(
dP_shared_cast,
Q_tail_shared,
acc_dkv_tail,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
for s in range(split_store):
for bi_i, d_i in T.Parallel(BS, D):
if bi_i < BS // split_store:
acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i]
for bi_i, d_i in T.Parallel(BS, D_tail):
if bi_i < BS // split_store:
acc_dkv_tail_shared[bi_i,
d_i] = acc_dkv_tail[bi_i + s * (BS // split_store),
d_i]
for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
T.atomic_addx4(
dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)],
bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4])
# Atomically update dKV, dKV_tail tensors
for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
T.atomic_addx4(
dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)],
bz, D + d_i * 4], acc_dkv_tail_shared[bi_i, d_i * 4])
# Store the accumulated dQ
T.copy(acc_dq, dQ_shared)
T.copy(acc_dq_tail, dQ_tail_shared)
T.copy(dQ_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:])
return sparse_mla_bwd_kernel
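# Host-side wrapper: preprocess (Delta) -> backward kernel (dQ, fp32 dKV) -> postprocess (bf16 dKV).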
def sparse_mla_bwd(q,
kv,
o,
do,
indices,
lse,
sm_scale=None,
is_casual=True,
return_kernel=False,
delta=None):
assert q.is_contiguous()
assert kv.is_contiguous()
assert indices.is_contiguous()
assert lse.is_contiguous()
B, S, H, dim_plus_tail_dim = q.shape
_, S_kv, kv_group, _ = kv.shape
assert kv.shape[-1] == dim_plus_tail_dim
assert kv.shape[0] == B
    # The value/latent dimension D is fixed at 512; the remaining dims form the RoPE tail (D_tail).
D = 512
D_tail = dim_plus_tail_dim - D
topk = indices.shape[-1]
assert indices.shape == (B, S, kv_group, topk)
assert lse.shape == (B, S, H)
# Get kernels
preprocess_kernel = preprocess(B, S, H, D)
bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual)
postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group)
if delta is None:
delta = preprocess_kernel(o, do)
dkv = torch.zeros_like(kv, dtype=torch.float32)
dq = bwd_kernel(q, kv, do, indices, lse, delta, dkv)
dkv = postprocess_kernel(dkv)
return dq, dkv
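# Reference implementation: run the dense reference forward under autograd and return q.grad / kv.grad.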
def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True):
from sparse_mla_fwd import ref_sparse_mla_fwd_interface
q = q.detach().clone()
kv = kv.detach().clone()
q.requires_grad = True
kv.requires_grad = True
o = ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale, is_casual)
o.backward(do)
return q.grad, kv.grad
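# Correctness and performance test: compare against the autograd reference, then report latency, bandwidth, and TFLOPS.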
def test_sparse_mla_bwd(B=1,
S=4096,
SKV=32768,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=2048,
dtype=torch.bfloat16):
# Prepare data
q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((B, S, H, DV), dtype=dtype, device='cuda')
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda')
for b in range(B):
for t in range(S):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[b, t, h, :len(i_i)] = i_i
# Forward
from sparse_mla_fwd import sparse_mla_fwd_interface
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse)
ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None)
if SKV <= 4096:
assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq")
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
print("assert_tensors_similar passed")
per_token_flop = 2 * sum([
H * DV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DV * topk,
])
from tilelang.profiler import do_bench
def fn():
return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse)
ms = do_bench(fn, rep=100, warmup=250)
print(f"Average time: {ms:.3f} ms")
    print(f"bwd io bandwidth = {(B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12}")
    print(f"bwd tflops = {per_token_flop * S / (ms * 1e-3) / 1e12}")
if __name__ == "__main__":
test_sparse_mla_bwd(
B=1, S=4096, SKV=4096, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16)
...@@ -5,6 +5,7 @@ from topk_selector import test_topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
from sparse_mla_bwd import test_sparse_mla_bwd
def test_example_topk_selector():
...@@ -29,5 +30,11 @@ def test_example_sparse_mla_fwd_pipelined():
    test_sparse_mla_fwd_pipelined(S=1024, SKV=2048, H=128, HKV=1, DQK=576, DV=512, topk=256)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
test_sparse_mla_bwd()
if __name__ == "__main__":
    tilelang.testing.main()