Unverified Commit 799c4bb5 authored by Ke Bao, committed by GitHub

Fuse MLA set kv cache kernel (#5748)

parent 02723e1b
@@ -625,6 +625,7 @@ class FlashAttentionBackend(AttentionBackend):
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -639,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )

         # Use precomputed metadata across all layers
@@ -887,6 +888,7 @@ class FlashAttentionBackend(AttentionBackend):
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -901,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
                         layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
                 else:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                         layer,
                         cache_loc,
                         k,
-                        v,
+                        k_rope,
                     )

         # Use precomputed metadata across all layers
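In both forward_extend and forward_decode above, the MLA branch now receives the key in two pieces: k carries the no-RoPE latent part and k_rope carries the rotary part, so the pool can write both in one fused call instead of requiring a pre-concatenated tensor. Below is a minimal sketch of that dispatch; the real branch condition is not visible in this hunk, so keying it off k_rope is an assumption, and save_cache is a hypothetical helper name.

from typing import Optional
import torch

def save_cache(pool, layer, cache_loc: torch.Tensor, k: torch.Tensor,
               v: Optional[torch.Tensor], k_rope: Optional[torch.Tensor]) -> None:
    # Sketch only: regular layers store (k, v) as before; MLA layers hand the
    # no-RoPE latent part (k) and the rotary part (k_rope) to the fused writer.
    if k_rope is None:
        pool.set_kv_buffer(layer, cache_loc, k, v, layer.k_scale, layer.v_scale)
    else:
        pool.set_mla_kv_buffer(layer, cache_loc, k, k_rope)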
@@ -92,8 +92,11 @@ class RadixAttention(nn.Module):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
-            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            if "k_rope" not in kwargs:
+                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            else:
+                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)

         return forward_batch.attn_backend.forward(
             q,
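The reshape above changes meaning depending on whether the call carries k_rope: without it, k packs the latent and rotary parts together and is qk_head_dim wide; with it, k holds only the latent part, which in MLA has the same width as v_head_dim. A small shape check follows, assuming DeepSeek-V2-like widths (kv_lora_rank 512, rope dim 64); these numbers are illustrative, not taken from the diff.

import torch

tp_k_head_num, v_head_dim, qk_head_dim = 1, 512, 512 + 64
tokens = 8

# Without k_rope: k packs latent + rope, so it reshapes to qk_head_dim.
k_full = torch.randn(tokens, tp_k_head_num * qk_head_dim)
print(k_full.view(-1, tp_k_head_num, qk_head_dim).shape)  # torch.Size([8, 1, 576])

# With k_rope: k carries only the latent part, so it reshapes to v_head_dim,
# and the rope part travels separately via kwargs["k_rope"].
k_nope = torch.randn(tokens, tp_k_head_num * v_head_dim)
print(k_nope.view(-1, tp_k_head_num, v_head_dim).shape)   # torch.Size([8, 1, 512])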
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,

@@ -504,6 +572,25 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
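set_mla_kv_buffer_triton launches one program per cached token and one per 128-wide block of the fused row, copying the no-RoPE and RoPE parts into the same kv_buffer row in a single pass; the base + BLOCK <= nope_dim branch relies on no block straddling the boundary, which holds when nope_dim is a multiple of BLOCK (e.g. 512 with BLOCK=128). Below is a pure-PyTorch reference of the same scatter, useful for sanity-checking the kernel on small inputs; the 512/64 widths and the (slots, 1, dim) buffer layout are assumptions for illustration.

import torch

def set_mla_kv_buffer_reference(kv_buffer, loc, cache_k_nope, cache_k_rope):
    # Same effect as the Triton kernel: write the no-RoPE part into the first
    # nope_dim columns of each selected row and the RoPE part into the rest.
    nope_dim = cache_k_nope.shape[-1]
    kv_buffer[loc, ..., :nope_dim] = cache_k_nope
    kv_buffer[loc, ..., nope_dim:] = cache_k_rope

kv_buffer = torch.zeros(16, 1, 512 + 64)  # (slots, heads=1, latent + rope)
loc = torch.tensor([3, 7, 11])            # destination slots for 3 tokens
set_mla_kv_buffer_reference(
    kv_buffer, loc, torch.randn(3, 1, 512), torch.randn(3, 1, 64)
)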
@@ -757,14 +757,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        k = torch.cat([k_nope, k_pe], dim=-1)
-
         if self.attention_backend == "fa3":
             attn_output = self.attn_mqa(
-                q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
         else:
             q = torch.cat([q_nope_out, q_pe], dim=-1)
+            k = torch.cat([k_nope, k_pe], dim=-1)
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)

         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
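With the fused writer in place, the fa3 branch no longer materializes torch.cat([k_nope, k_pe], dim=-1): attn_mqa receives the latent part as k and the rotary part as k_rope=k_pe, and the concatenation effectively happens inside the KV-cache write. Other backends still build k themselves, which is why the cat moved into the else branch. The snippet below only illustrates the tensor that is no longer allocated on the fa3 path; the shapes (4 tokens, kv_lora_rank 512, rope dim 64) are illustrative assumptions.

import torch

k_nope = torch.randn(4, 1, 512)  # latent (no-RoPE) part
k_pe = torch.randn(4, 1, 64)     # rotary part

# Old path: an extra (4, 1, 576) tensor per layer just to feed the cache write.
k_cat = torch.cat([k_nope, k_pe], dim=-1)

# New fa3 path: k_nope and k_pe are passed unconcatenated, and the fused Triton
# kernel writes both halves into the same KV-buffer row.
assert torch.equal(k_cat[..., :512], k_nope) and torch.equal(k_cat[..., 512:], k_pe)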