"docs/vscode:/vscode.git/clone" did not exist on "8bde6a543ba00adf2f7e330fbfebd624c876ab4d"
Unverified Commit 9708d353 authored by Yongfei Xu's avatar Yongfei Xu Committed by GitHub
Browse files

Support MHA with chunked prefix cache for flashinfer/flashmla backend, support...


Support MHA with chunked prefix cache for flashinfer/flashmla backend, support page size > 1 for MHA chunked prefix (#8616)
Co-authored-by: default avatarxuyongfei.xyf <xuyongfei.xyf@antgroup.com>
parent 704ced1b
......@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend):
o = result
else:
if (
not global_server_args_dict["disable_chunked_prefix_cache"]
and forward_batch.attn_attend_prefix_cache is not None
forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not global_server_args_dict["disable_chunked_prefix_cache"]
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
......@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend):
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
output, lse, *rest = flash_attn_varlen_func(
assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
......@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend):
)
else:
# MHA for extend part of sequence without attending prefix kv cache
output, lse, *rest = flash_attn_varlen_func(
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
......@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
return_softmax_lse=True,
return_softmax_lse=forward_batch.mha_return_lse,
)
return output, lse
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse
return output
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
......
......@@ -59,6 +59,115 @@ class PrefillMetadata:
global_workspace_buffer = None
class FlashInferMhaChunkKVRunner:
def __init__(
self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
):
# Parse Constants
self.num_local_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.v_head_dim = model_runner.model_config.v_head_dim
self.data_type = model_runner.dtype
self.q_data_type = model_runner.dtype
# Buffers and wrappers
self.qo_indptr = attn_backend.qo_indptr
self.workspace_buffer = attn_backend.workspace_buffer
self.fmha_backend = attn_backend.fmha_backend
self.chunk_ragged_wrappers = []
self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
def update_prefix_chunks(self, num_prefix_chunks: int):
while num_prefix_chunks > len(self.chunk_ragged_wrappers):
ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD", backend=self.fmha_backend
)
self.chunk_ragged_wrappers.append(ragged_wrapper)
def update_wrapper(
self,
forward_batch: ForwardBatch,
):
assert forward_batch.num_prefix_chunks is not None
num_prefix_chunks = forward_batch.num_prefix_chunks
self.update_prefix_chunks(num_prefix_chunks)
prefix_lens = forward_batch.extend_prefix_lens
seq_lens = forward_batch.seq_lens
bs = len(seq_lens)
qo_indptr = self.qo_indptr
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
for chunk_idx in range(forward_batch.num_prefix_chunks):
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
wrapper = self.chunk_ragged_wrappers[chunk_idx]
wrapper.begin_forward(
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
num_qo_heads=self.num_local_heads,
num_kv_heads=self.num_local_heads,
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
head_dim_vo=self.v_head_dim,
q_data_type=self.q_data_type,
causal=False,
)
# ragged prefill
self.ragged_wrapper.begin_forward(
qo_indptr=qo_indptr,
kv_indptr=qo_indptr,
num_qo_heads=self.num_local_heads,
num_kv_heads=self.num_local_heads,
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
head_dim_vo=self.v_head_dim,
q_data_type=self.q_data_type,
causal=True,
)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
logits_soft_cap = layer.logit_cap
if forward_batch.attn_attend_prefix_cache:
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
wrapper = self.chunk_ragged_wrappers[chunk_idx]
o1, s1 = wrapper.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
causal=False,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
else:
o1, s1 = self.ragged_wrapper.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
return o1, s1
class FlashInferMLAAttnBackend(AttentionBackend):
"""Flashinfer attention kernels."""
......@@ -74,6 +183,12 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.skip_prefill = skip_prefill
self.enable_chunk_kv = (
not skip_prefill
and global_server_args_dict["disaggregation_mode"] != "decode"
and not global_server_args_dict["disable_chunked_prefix_cache"]
and not global_server_args_dict["flashinfer_mla_disable_ragged"]
)
self.page_size = model_runner.page_size
# Allocate buffers
......@@ -117,11 +232,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
else:
self.q_indptr_decode = q_indptr_decode_buf
fmha_backend = "auto"
self.fmha_backend = "auto"
if is_sm100_supported():
fmha_backend = "cutlass"
self.fmha_backend = "cutlass"
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD", backend=fmha_backend
self.workspace_buffer, "NHD", backend=self.fmha_backend
)
if not self.skip_prefill:
......@@ -145,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
model_runner, self
)
if self.enable_chunk_kv:
self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)
self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
model_runner, self
......@@ -373,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 1
def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
self.mha_chunk_kv_cache.update_wrapper(forward_batch)
def forward_extend(
self,
q: torch.Tensor,
......@@ -384,6 +505,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
):
if (
forward_batch.attn_attend_prefix_cache is not None
and forward_batch.mha_return_lse
): # MHA Chunk
assert self.enable_chunk_kv
assert q_rope is None
assert k_rope is None
o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
return o1, s1
cache_loc = forward_batch.out_cache_loc
logits_soft_cap = layer.logit_cap
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
......@@ -412,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
k = torch.cat([k, k_rope], dim=-1)
o = self.prefill_wrapper_ragged.forward(
qall,
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
......@@ -732,6 +863,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
head_dim_vo=self.v_head_dim,
q_data_type=self.q_data_type,
causal=True,
)
else:
# mla paged prefill
......
......@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_symm_mem",
"quantization",
"enable_custom_logit_processor",
"disaggregation_mode",
]
# Put some global args for easy access
......
......@@ -241,6 +241,9 @@ class ForwardBatch:
prefix_chunk_num_tokens: Optional[List[int]] = None
# KV Indices for each chunk
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
# For MLA chunked prefix cache used in chunked prefill
# Tell attention backend whether lse needs to be returned
mha_return_lse: Optional[bool] = None
# For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None
......
......@@ -518,9 +518,6 @@ class ModelRunner:
if not self.use_mla_backend:
server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.")
server_args.disable_chunked_prefix_cache = True
if not server_args.disable_chunked_prefix_cache:
logger.info("Chunked prefix cache is turned on.")
......
......@@ -995,29 +995,31 @@ class DeepseekV2AttentionMLA(nn.Module):
if attention_backend == "ascend":
return AttnForwardMethod.MLA
elif attention_backend == "flashinfer":
elif (
attention_backend == "flashinfer"
or attention_backend == "fa3"
or attention_backend == "flashmla"
):
# Use MHA with chunked KV cache when prefilling on long sequences.
sum_extend_prefix_lens = (
sum(forward_batch.extend_prefix_lens_cpu)
if forward_batch.extend_prefix_lens_cpu is not None
else 0
)
# Flashinfer MLA: Do not absorb when enabling ragged prefill
disable_ragged = (
attention_backend == "flashinfer" or attention_backend == "flashmla"
) and self.flashinfer_mla_disable_ragged
if (
not self.flashinfer_mla_disable_ragged
not disable_ragged
and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
):
return AttnForwardMethod.MHA
else:
return _dispatch_mla_subtype()
elif attention_backend == "fa3":
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
if forward_batch.extend_prefix_lens_cpu is not None:
sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
if (
forward_batch.forward_mode.is_extend()
and not self.disable_chunked_prefix_cache
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and (
sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
(
sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
and not self.disable_chunked_prefix_cache
)
or sum_extend_prefix_lens == 0
)
):
......@@ -1685,7 +1687,6 @@ class DeepseekV2AttentionMLA(nn.Module):
k[..., self.qk_nope_head_dim :] = k_pe
output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
lse = torch.transpose(lse, 0, 1).contiguous()
tmp_output = torch.empty_like(accum_output)
tmp_lse = torch.empty_like(accum_lse)
merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
......@@ -1707,55 +1708,26 @@ class DeepseekV2AttentionMLA(nn.Module):
# will be helpful for understanding the purpose of this function.
# First do normal mha forward to get output for extended part
if self.q_lora_rank is not None:
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a)
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope = kv[..., : self.qk_nope_head_dim]
v = kv[..., self.qk_nope_head_dim :]
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
latent_cache[:, :, self.kv_lora_rank :] = k_pe
# Save latent cache
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
return self.forward_normal_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
return q, k, v, forward_batch
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
# Only initialize the info once
if has_extend_prefix and forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
forward_batch.mha_return_lse = has_extend_prefix
# Do mha for extended part without prefix
forward_batch.set_attn_attend_prefix_cache(False)
attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
lse = torch.transpose(lse, 0, 1).contiguous()
attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
# Do mha attention with chunked prefix cache if there are any sequence with prefix
if any(forward_batch.extend_prefix_lens_cpu):
# Only initialize the info once
if forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
if has_extend_prefix:
attn_output, lse = attn_output
forward_batch.set_attn_attend_prefix_cache(True)
attn_output = self._chunked_prefix_attn_mha(
q=q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment