Unverified Commit dbfbf9f3 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Fix FlashMLA metadata builder arguments for q_len > 1 (#27368)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent ca76486a
......@@ -120,9 +120,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
num_decode_tokens: int,
dcp_tot_seq_lens_device: torch.Tensor | None,
) -> FlashMLADecodeMetadata:
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
# we use the max but all should be the same due to uniform length requirement
max_query_len = query_lens_cpu.max().item()
num_q_tokens_per_head_k = max_query_len * self.num_q_heads // 1
tile_scheduler_metadata, num_splits = get_mla_metadata(
seq_lens_device,
self.num_q_heads,
num_q_tokens_per_head_k,
1, # MQA for the decode path
is_fp8_kvcache=self.is_fp8_kvcache,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment