Unverified Commit 1ba137e9, authored by Shu Wang and committed by GitHub

Enable trtllm mla prefix extend (#10526)

parent de28f8e7
@@ -553,7 +553,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ):
         if (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
@@ -591,10 +591,45 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 return_lse=forward_batch.mha_return_lse,
             )
         else:
-            # replace with trtllm ragged attention once accuracy is resolved.
-            output = super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
+            if not (
+                forward_batch.attn_attend_prefix_cache is not None
+                and forward_batch.mha_return_lse
+            ):
+                output = super().forward_extend(
+                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+                )
+            else:
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert q_rope is None
+                assert k_rope is None
+                chunk_idx = forward_batch.prefix_chunk_idx
+                q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
+                v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
+                output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+                output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                    query=q,
+                    key=k,
+                    value=v,
+                    workspace_buffer=self.workspace_buffer,
+                    seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+                    max_q_len=self.forward_prefill_metadata.max_seq_len,
+                    max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    bmm1_scale=layer.scaling,
+                    bmm2_scale=1.0,
+                    o_sf_scale=-1.0,
+                    batch_size=forward_batch.batch_size,
+                    window_left=-1,
+                    cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                    cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    enable_pdl=False,
+                    is_causal=False,
+                    return_lse=True,
+                    out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+                )
         return output
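Note on the new chunked-prefix branch: `trtllm_ragged_attention_deepseek` is called with `return_lse=True` because each prefix chunk covers only part of the KV cache, so the per-chunk partial outputs must later be combined using their log-sum-exp (LSE) statistics; this is presumably why the branch is gated on `forward_batch.mha_return_lse`. As a rough sketch of that combination step (not the fused merge kernel SGLang/FlashInfer actually use; `merge_attention_states` is a hypothetical helper name), a numerically stable merge of two partial attention states looks like this:

```python
import torch


def merge_attention_states(
    o1: torch.Tensor,   # [tokens, heads, v_dim] partial output over KV set 1
    lse1: torch.Tensor, # [tokens, heads] log-sum-exp of logits over KV set 1
    o2: torch.Tensor,   # [tokens, heads, v_dim] partial output over KV set 2
    lse2: torch.Tensor, # [tokens, heads] log-sum-exp of logits over KV set 2
) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative-only merge of two partial softmax-attention results."""
    # Subtract the running maximum for numerical stability before exponentiating.
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m)
    w2 = torch.exp(lse2 - m)
    denom = w1 + w2
    # Weighted average of the partial outputs, weighted by each side's softmax mass.
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom.unsqueeze(-1)
    # LSE of the union of both KV sets, usable for further chained merges.
    lse = m + torch.log(denom)
    return o, lse
```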