Unverified commit 8d6ab1cb authored by bmac3, committed by GitHub

fix seqlen bug for trtllm_mla's draft_extend (#12295)

parent 84a9d0ea
@@ -944,8 +944,16 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 metadata.max_seq_len_k + forward_batch.spec_info.draft_token_num
             )
         else:
-            seq_lens = forward_batch.seq_lens.to(torch.int32)
-            max_seq_len = metadata.max_seq_len_k
+            # forward_batch.seq_lens holds prev_context + verified tokens per request.
+            # To account for pad_draft_extend_query, we need seq_lens = prev_context + max_draft_tokens.
+            # This ensures queries align with KVs correctly when calling
+            # flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla.
+            seq_lens = (
+                forward_batch.seq_lens
+                - metadata.seq_lens_q
+                + metadata.max_seq_len_q
+            ).to(torch.int32)
+            max_seq_len = metadata.max_seq_len_k + metadata.max_seq_len_q
         # Check if we're in CUDA graph mode (buffers are pre-allocated)
         if self.padded_q_buffer is not None:
             # Use pre-allocated buffer for CUDA graph compatibility
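To make the seq_lens arithmetic concrete, here is a small standalone sketch (not part of the patch) with made-up toy values; `prev_context`, `verified`, and `max_draft_tokens` are hypothetical stand-ins for what `forward_batch.seq_lens`, `metadata.seq_lens_q`, and `metadata.max_seq_len_q` carry in the real backend.

```python
import torch

# Hypothetical toy values, not taken from the PR: two requests in a
# draft-extend batch.
prev_context = torch.tensor([10, 7])   # cached KV tokens per request
verified = torch.tensor([3, 1])        # verified draft tokens per request
max_draft_tokens = 4                   # queries are padded to this length

# What forward_batch.seq_lens holds: previous context plus verified tokens.
seq_lens_batch = prev_context + verified

# Old behavior: pass the batch seq_lens through unchanged, so padded query
# rows beyond each request's verified count have no matching KV positions.
seq_lens_old = seq_lens_batch.to(torch.int32)

# Fixed behavior: swap each request's actual query length (metadata.seq_lens_q)
# for the padded length (metadata.max_seq_len_q) so every padded query row
# lines up with a KV slot.
seq_lens_new = (seq_lens_batch - verified + max_draft_tokens).to(torch.int32)

print(seq_lens_old.tolist())  # [13, 8]
print(seq_lens_new.tolist())  # [14, 11]
```

Consistent with this, the patch also grows `max_seq_len` to `metadata.max_seq_len_k + metadata.max_seq_len_q`, since the per-request KV lengths now include the query padding.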