Unverified commit 799c4bb5, authored by Ke Bao, committed by GitHub

Fuse MLA set kv cache kernel (#5748)

parent 02723e1b
@@ -625,6 +625,7 @@ class FlashAttentionBackend(AttentionBackend):
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -639,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )

         # Use precomputed metadata across all layers
@@ -887,6 +888,7 @@ class FlashAttentionBackend(AttentionBackend):
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -901,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )

         # Use precomputed metadata across all layers
......
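For context: the MLA branch previously received the already concatenated key (compressed "nope" part plus rotary part) and wrote it with set_kv_buffer; with the new k_rope argument the backend forwards the two pieces separately and the pool fuses the write. Below is a rough pure-PyTorch picture of what the fused call ends up storing, using toy shapes (kv_lora_rank=512, qk_rope_head_dim=64 as in DeepSeek-V2); the buffer here is a stand-in for illustration, not the real pool object.

import torch

# Toy stand-in, not the real SGLang object: one layer's MLA KV cache of
# shape [num_slots, 1, kv_lora_rank + qk_rope_head_dim].
num_slots, nope_dim, rope_dim = 32, 512, 64
kv_buffer = torch.zeros(num_slots, 1, nope_dim + rope_dim)

num_tokens = 8
cache_loc = torch.randperm(num_slots)[:num_tokens]
k_nope = torch.randn(num_tokens, 1, nope_dim)
k_rope = torch.randn(num_tokens, 1, rope_dim)

# Unfused path (before): materialize the concatenated key, then copy it in.
kv_buffer[cache_loc] = torch.cat([k_nope, k_rope], dim=-1)

# Fused path (this commit, modulo the Triton kernel): write both halves
# straight into their slices of each cache line, with no temporary concat.
kv_buffer[cache_loc, :, :nope_dim] = k_nope
kv_buffer[cache_loc, :, nope_dim:] = k_rope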
@@ -92,8 +92,11 @@ class RadixAttention(nn.Module):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
-            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            if "k_rope" not in kwargs:
+                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            else:
+                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)

         return forward_batch.attn_backend.forward(
             q,
......
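With the extra k_rope kwarg, RadixAttention can no longer assume that k carries the full qk_head_dim: on the MLA path k is only the compressed (nope) part, whose width equals v_head_dim (the kv_lora_rank), while the rotary part travels separately and v is left as-is. A small shape illustration with assumed DeepSeek-V2-style dimensions (qk_head_dim = 576, v_head_dim = 512, one KV head); the numbers are for illustration only.

import torch

tp_k_head_num, qk_head_dim, v_head_dim = 1, 576, 512  # assumed MLA dims
num_tokens = 8

# Old contract: k already holds nope + rope, so it is viewed with qk_head_dim.
k_full = torch.randn(num_tokens, tp_k_head_num * qk_head_dim)
k_full = k_full.view(-1, tp_k_head_num, qk_head_dim)   # -> [8, 1, 576]

# New contract when k_rope is passed: k is only the nope part, so it is
# viewed with v_head_dim (== kv_lora_rank) instead.
k_nope = torch.randn(num_tokens, tp_k_head_num * v_head_dim)
k_nope = k_nope.view(-1, tp_k_head_num, v_head_dim)    # -> [8, 1, 512]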
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -504,6 +572,25 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
......
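A minimal way to sanity-check the new kernel against a pure-PyTorch reference, assuming a CUDA device with Triton available and that set_mla_kv_buffer_triton is importable from sglang.srt.mem_cache.memory_pool (shapes mirror DeepSeek-V2: kv_lora_rank=512, qk_rope_head_dim=64):

import torch

from sglang.srt.mem_cache.memory_pool import set_mla_kv_buffer_triton

if torch.cuda.is_available():
    num_slots, num_tokens, nope_dim, rope_dim = 64, 8, 512, 64
    kv_buffer = torch.zeros(
        num_slots, 1, nope_dim + rope_dim, device="cuda", dtype=torch.bfloat16
    )
    loc = torch.randperm(num_slots, device="cuda")[:num_tokens]
    k_nope = torch.randn(num_tokens, 1, nope_dim, device="cuda", dtype=torch.bfloat16)
    k_rope = torch.randn(num_tokens, 1, rope_dim, device="cuda", dtype=torch.bfloat16)

    # One fused launch: grid = (num_tokens, ceil(576 / 128)) programs, each
    # copying a 128-wide chunk of a cache line from the nope or rope source.
    set_mla_kv_buffer_triton(kv_buffer, loc, k_nope, k_rope)

    reference = torch.cat([k_nope, k_rope], dim=-1)
    assert torch.equal(kv_buffer[loc], reference)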
@@ -757,14 +757,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        k = torch.cat([k_nope, k_pe], dim=-1)

         if self.attention_backend == "fa3":
             attn_output = self.attn_mqa(
-                q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
+            k = torch.cat([k_nope, k_pe], dim=-1)
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)

         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
......
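Putting the four files together, the fa3 MLA save path after this commit looks roughly like the trace below (names abbreviated; see the diffs above for the exact code). The only remaining torch.cat of the full key sits in the non-fa3 fallback branch.

# Approximate call flow for the fa3 branch after this commit:
#
# DeepseekV2AttentionMLA (fa3 branch):
#     attn_mqa(q_nope_out, k_nope, k_nope, forward_batch,
#              q_rope=q_pe, k_rope=k_pe)          # no torch.cat of the key
#       -> RadixAttention.forward:
#              "k_rope" in kwargs, so k is viewed as [-1, heads, v_head_dim]
#         -> FlashAttentionBackend.forward_extend / forward_decode:
#                token_to_kv_pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope)
#           -> MLATokenToKVPool.set_mla_kv_buffer
#             -> set_mla_kv_buffer_triton -> set_mla_kv_buffer_kernel
#                (single fused write of the nope and rope halves into the cache)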