Unverified Commit 9708d353 authored by Yongfei Xu's avatar Yongfei Xu Committed by GitHub
Browse files

Support MHA with chunked prefix cache for flashinfer/flashmla backend, support...


Support MHA with chunked prefix cache for flashinfer/flashmla backend, support page size > 1 for MHA chunked prefix (#8616)
Co-authored-by: default avatarxuyongfei.xyf <xuyongfei.xyf@antgroup.com>
parent 704ced1b
...@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -776,14 +776,13 @@ class FlashAttentionBackend(AttentionBackend):
o = result o = result
else: else:
if ( if (
not global_server_args_dict["disable_chunked_prefix_cache"] forward_batch.attn_attend_prefix_cache is not None
and forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend() and not forward_batch.forward_mode.is_draft_extend()
): ):
# Do multi-head attention with chunked prefix cache # Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache: if forward_batch.attn_attend_prefix_cache:
assert not global_server_args_dict["disable_chunked_prefix_cache"]
# MHA for chunked prefix kv cache when running model with MLA # MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None assert forward_batch.prefix_chunk_cu_seq_lens is not None
...@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -792,7 +791,8 @@ class FlashAttentionBackend(AttentionBackend):
chunk_idx = forward_batch.prefix_chunk_idx chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0 assert chunk_idx >= 0
output, lse, *rest = flash_attn_varlen_func( assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim), q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
...@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -806,7 +806,7 @@ class FlashAttentionBackend(AttentionBackend):
) )
else: else:
# MHA for extend part of sequence without attending prefix kv cache # MHA for extend part of sequence without attending prefix kv cache
output, lse, *rest = flash_attn_varlen_func( output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim), q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype), k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype), v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
...@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -816,9 +816,13 @@ class FlashAttentionBackend(AttentionBackend):
max_seqlen_k=metadata.max_seq_len_q, max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling, softmax_scale=layer.scaling,
causal=True, causal=True,
return_softmax_lse=True, return_softmax_lse=forward_batch.mha_return_lse,
) )
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse return output, lse
return output
else: else:
# Do absorbed multi-latent attention # Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer( kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
......
...@@ -59,6 +59,115 @@ class PrefillMetadata: ...@@ -59,6 +59,115 @@ class PrefillMetadata:
global_workspace_buffer = None global_workspace_buffer = None
class FlashInferMhaChunkKVRunner:
def __init__(
self, model_runner: ModelRunner, attn_backend: "FlashInferMlaAttnBackend"
):
# Parse Constants
self.num_local_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.v_head_dim = model_runner.model_config.v_head_dim
self.data_type = model_runner.dtype
self.q_data_type = model_runner.dtype
# Buffers and wrappers
self.qo_indptr = attn_backend.qo_indptr
self.workspace_buffer = attn_backend.workspace_buffer
self.fmha_backend = attn_backend.fmha_backend
self.chunk_ragged_wrappers = []
self.ragged_wrapper = attn_backend.prefill_wrapper_ragged
def update_prefix_chunks(self, num_prefix_chunks: int):
while num_prefix_chunks > len(self.chunk_ragged_wrappers):
ragged_wrapper = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD", backend=self.fmha_backend
)
self.chunk_ragged_wrappers.append(ragged_wrapper)
def update_wrapper(
self,
forward_batch: ForwardBatch,
):
assert forward_batch.num_prefix_chunks is not None
num_prefix_chunks = forward_batch.num_prefix_chunks
self.update_prefix_chunks(num_prefix_chunks)
prefix_lens = forward_batch.extend_prefix_lens
seq_lens = forward_batch.seq_lens
bs = len(seq_lens)
qo_indptr = self.qo_indptr
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
for chunk_idx in range(forward_batch.num_prefix_chunks):
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
kv_indptr = forward_batch.prefix_chunk_cu_seq_lens[chunk_idx]
wrapper = self.chunk_ragged_wrappers[chunk_idx]
wrapper.begin_forward(
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
num_qo_heads=self.num_local_heads,
num_kv_heads=self.num_local_heads,
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
head_dim_vo=self.v_head_dim,
q_data_type=self.q_data_type,
causal=False,
)
# ragged prefill
self.ragged_wrapper.begin_forward(
qo_indptr=qo_indptr,
kv_indptr=qo_indptr,
num_qo_heads=self.num_local_heads,
num_kv_heads=self.num_local_heads,
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
head_dim_vo=self.v_head_dim,
q_data_type=self.q_data_type,
causal=True,
)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
logits_soft_cap = layer.logit_cap
if forward_batch.attn_attend_prefix_cache:
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
wrapper = self.chunk_ragged_wrappers[chunk_idx]
o1, s1 = wrapper.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
causal=False,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
else:
o1, s1 = self.ragged_wrapper.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
return o1, s1
class FlashInferMLAAttnBackend(AttentionBackend): class FlashInferMLAAttnBackend(AttentionBackend):
"""Flashinfer attention kernels.""" """Flashinfer attention kernels."""
...@@ -74,6 +183,12 @@ class FlashInferMLAAttnBackend(AttentionBackend): ...@@ -74,6 +183,12 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.max_context_len = model_runner.model_config.context_len self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device self.device = model_runner.device
self.skip_prefill = skip_prefill self.skip_prefill = skip_prefill
self.enable_chunk_kv = (
not skip_prefill
and global_server_args_dict["disaggregation_mode"] != "decode"
and not global_server_args_dict["disable_chunked_prefix_cache"]
and not global_server_args_dict["flashinfer_mla_disable_ragged"]
)
self.page_size = model_runner.page_size self.page_size = model_runner.page_size
# Allocate buffers # Allocate buffers
...@@ -117,11 +232,11 @@ class FlashInferMLAAttnBackend(AttentionBackend): ...@@ -117,11 +232,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
else: else:
self.q_indptr_decode = q_indptr_decode_buf self.q_indptr_decode = q_indptr_decode_buf
fmha_backend = "auto" self.fmha_backend = "auto"
if is_sm100_supported(): if is_sm100_supported():
fmha_backend = "cutlass" self.fmha_backend = "cutlass"
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD", backend=fmha_backend self.workspace_buffer, "NHD", backend=self.fmha_backend
) )
if not self.skip_prefill: if not self.skip_prefill:
...@@ -145,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend): ...@@ -145,6 +260,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill( self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
model_runner, self model_runner, self
) )
if self.enable_chunk_kv:
self.mha_chunk_kv_cache = FlashInferMhaChunkKVRunner(model_runner, self)
self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode( self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
model_runner, self model_runner, self
...@@ -373,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend): ...@@ -373,6 +490,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 1 return 1
def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
self.mha_chunk_kv_cache.update_wrapper(forward_batch)
def forward_extend( def forward_extend(
self, self,
q: torch.Tensor, q: torch.Tensor,
...@@ -384,6 +505,16 @@ class FlashInferMLAAttnBackend(AttentionBackend): ...@@ -384,6 +505,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
q_rope: Optional[torch.Tensor] = None, q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None,
): ):
if (
forward_batch.attn_attend_prefix_cache is not None
and forward_batch.mha_return_lse
): # MHA Chunk
assert self.enable_chunk_kv
assert q_rope is None
assert k_rope is None
o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
return o1, s1
cache_loc = forward_batch.out_cache_loc cache_loc = forward_batch.out_cache_loc
logits_soft_cap = layer.logit_cap logits_soft_cap = layer.logit_cap
prefill_wrapper_paged = self.forward_metadata.prefill_wrapper prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
...@@ -412,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend): ...@@ -412,8 +543,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
k = torch.cat([k, k_rope], dim=-1) k = torch.cat([k, k_rope], dim=-1)
o = self.prefill_wrapper_ragged.forward( o = self.prefill_wrapper_ragged.forward(
qall, qall,
k.view(-1, layer.tp_k_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v.view(-1, layer.tp_k_head_num, layer.v_head_dim), v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
causal=True, causal=True,
sm_scale=layer.scaling, sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap, logits_soft_cap=logits_soft_cap,
...@@ -732,6 +863,7 @@ class FlashInferMLAIndicesUpdaterPrefill: ...@@ -732,6 +863,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim, head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
head_dim_vo=self.v_head_dim, head_dim_vo=self.v_head_dim,
q_data_type=self.q_data_type, q_data_type=self.q_data_type,
causal=True,
) )
else: else:
# mla paged prefill # mla paged prefill
......
...@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ ...@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_symm_mem", "enable_symm_mem",
"quantization", "quantization",
"enable_custom_logit_processor", "enable_custom_logit_processor",
"disaggregation_mode",
] ]
# Put some global args for easy access # Put some global args for easy access
......
...@@ -241,6 +241,9 @@ class ForwardBatch: ...@@ -241,6 +241,9 @@ class ForwardBatch:
prefix_chunk_num_tokens: Optional[List[int]] = None prefix_chunk_num_tokens: Optional[List[int]] = None
# KV Indices for each chunk # KV Indices for each chunk
prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
# For MLA chunked prefix cache used in chunked prefill
# Tell attention backend whether lse needs to be returned
mha_return_lse: Optional[bool] = None
# For multimodal # For multimodal
mm_inputs: Optional[List[MultimodalInputs]] = None mm_inputs: Optional[List[MultimodalInputs]] = None
......
...@@ -518,9 +518,6 @@ class ModelRunner: ...@@ -518,9 +518,6 @@ class ModelRunner:
if not self.use_mla_backend: if not self.use_mla_backend:
server_args.disable_chunked_prefix_cache = True server_args.disable_chunked_prefix_cache = True
elif self.page_size > 1:
logger.info("Disable chunked prefix cache when page size > 1.")
server_args.disable_chunked_prefix_cache = True
if not server_args.disable_chunked_prefix_cache: if not server_args.disable_chunked_prefix_cache:
logger.info("Chunked prefix cache is turned on.") logger.info("Chunked prefix cache is turned on.")
......
...@@ -995,29 +995,31 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -995,29 +995,31 @@ class DeepseekV2AttentionMLA(nn.Module):
if attention_backend == "ascend": if attention_backend == "ascend":
return AttnForwardMethod.MLA return AttnForwardMethod.MLA
elif attention_backend == "flashinfer": elif (
attention_backend == "flashinfer"
or attention_backend == "fa3"
or attention_backend == "flashmla"
):
# Use MHA with chunked KV cache when prefilling on long sequences.
sum_extend_prefix_lens = (
sum(forward_batch.extend_prefix_lens_cpu)
if forward_batch.extend_prefix_lens_cpu is not None
else 0
)
# Flashinfer MLA: Do not absorb when enabling ragged prefill # Flashinfer MLA: Do not absorb when enabling ragged prefill
disable_ragged = (
attention_backend == "flashinfer" or attention_backend == "flashmla"
) and self.flashinfer_mla_disable_ragged
if ( if (
not self.flashinfer_mla_disable_ragged not disable_ragged
and forward_batch.forward_mode.is_extend() and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify() and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend() and not forward_batch.forward_mode.is_draft_extend()
and sum(forward_batch.extend_prefix_lens_cpu) == 0
):
return AttnForwardMethod.MHA
else:
return _dispatch_mla_subtype()
elif attention_backend == "fa3":
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
if forward_batch.extend_prefix_lens_cpu is not None:
sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
if (
forward_batch.forward_mode.is_extend()
and not self.disable_chunked_prefix_cache
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and ( and (
(
sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
and not self.disable_chunked_prefix_cache
)
or sum_extend_prefix_lens == 0 or sum_extend_prefix_lens == 0
) )
): ):
...@@ -1685,7 +1687,6 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1685,7 +1687,6 @@ class DeepseekV2AttentionMLA(nn.Module):
k[..., self.qk_nope_head_dim :] = k_pe k[..., self.qk_nope_head_dim :] = k_pe
output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
lse = torch.transpose(lse, 0, 1).contiguous()
tmp_output = torch.empty_like(accum_output) tmp_output = torch.empty_like(accum_output)
tmp_lse = torch.empty_like(accum_lse) tmp_lse = torch.empty_like(accum_lse)
merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse) merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
...@@ -1707,55 +1708,26 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1707,55 +1708,26 @@ class DeepseekV2AttentionMLA(nn.Module):
# will be helpful for understanding the purpose of this function. # will be helpful for understanding the purpose of this function.
# First do normal mha forward to get output for extended part # First do normal mha forward to get output for extended part
if self.q_lora_rank is not None: return self.forward_normal_prepare(
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split( positions, hidden_states, forward_batch, zero_allocator
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a)
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope = kv[..., : self.qk_nope_head_dim]
v = kv[..., self.qk_nope_head_dim :]
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
latent_cache[:, :, self.kv_lora_rank :] = k_pe
# Save latent cache
forward_batch.token_to_kv_pool.set_kv_buffer(
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
) )
return q, k, v, forward_batch
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch): def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
# Only initialize the info once
if has_extend_prefix and forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
forward_batch.mha_return_lse = has_extend_prefix
# Do mha for extended part without prefix # Do mha for extended part without prefix
forward_batch.set_attn_attend_prefix_cache(False) forward_batch.set_attn_attend_prefix_cache(False)
attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False) attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
lse = torch.transpose(lse, 0, 1).contiguous()
# Do mha attention with chunked prefix cache if there are any sequence with prefix # Do mha attention with chunked prefix cache if there are any sequence with prefix
if any(forward_batch.extend_prefix_lens_cpu): if has_extend_prefix:
# Only initialize the info once attn_output, lse = attn_output
if forward_batch.num_prefix_chunks is None:
forward_batch.prepare_chunked_prefix_cache_info(q.device)
forward_batch.set_attn_attend_prefix_cache(True) forward_batch.set_attn_attend_prefix_cache(True)
attn_output = self._chunked_prefix_attn_mha( attn_output = self._chunked_prefix_attn_mha(
q=q, q=q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment