Unverified commit 57eec0bf authored by lukec, committed by GitHub

fix FlashMLA cudagraph config (#4691)


Co-authored-by: yinfan98 <1106310035@qq.com>
parent f01b0925
@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 max_seqlen_pad = triton.cdiv(
-                    forward_batch.seq_lens.max().item(), PAGE_SIZE
+                    forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
                 )
                 block_kv_indices = torch.full(
                     (bs, max_seqlen_pad),
@@ -206,8 +206,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
     ):
         if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
             seq_lens = seq_lens[:bs]
-            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
             create_flashmla_kv_indices_triton[(bs,)](
                 self.req_to_token,
                 req_pool_indices[:bs],
...
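Both hunks stop reading the device-resident sequence-length tensor and instead size the FlashMLA block table from its CPU copy (decode_seq_lens_cpu / seq_lens_cpu). The change appears motivated by the CUDA graph path: .max().item() on a GPU tensor forces a host-device synchronization and is unsafe while a graph is being captured or replayed, whereas reading the CPU copy is a plain host operation. The sketch below is not the repository code; it only illustrates the host-side padding arithmetic. PAGE_SIZE = 64, the function names, and the example inputs are assumptions made for illustration, and ceil_div stands in for triton.cdiv, which computes the same integer ceiling division.

# Minimal sketch (assumed names, not the sglang source) of the host-side
# computation the hunks above change.
import torch

PAGE_SIZE = 64  # assumed KV-cache page size for this example


def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division; matches triton.cdiv(a, b) for positive ints.
    return (a + b - 1) // b


def compute_max_seqlen_pad(seq_lens_cpu: torch.Tensor, bs: int) -> int:
    # Slice to the (possibly padded) cudagraph batch size, then take the max
    # on the CPU tensor. No device synchronization is needed, so this is safe
    # to run while a CUDA graph is being captured or replayed.
    seq_lens_cpu = seq_lens_cpu[:bs]
    return ceil_div(int(seq_lens_cpu.max().item()), PAGE_SIZE)


# Example: decode requests of 17, 130, and 64 tokens need
# ceil(130 / 64) = 3 KV pages per row of the padded block table.
print(compute_max_seqlen_pad(torch.tensor([17, 130, 64]), bs=3))  # -> 3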