Unverified Commit e58423b2 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix cutlass MLA gets almost zero accuracy (#6998)

parent 7059ae16
......@@ -157,7 +157,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
max_seqlen_pad = self.cuda_graph_kv_indices.shape[1]
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
......@@ -169,12 +169,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
self.cuda_graph_kv_indices.stride(0),
PAGE_SIZE,
)
workspace_size = cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE, bs
)
self.cuda_graph_mla_workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
)
self.forward_metadata = CutlassMLADecodeMetadata(
self.cuda_graph_mla_workspace,
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
......@@ -205,8 +199,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
if forward_mode.is_decode_or_idle():
assert seq_lens_cpu is not None
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
......@@ -217,16 +210,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
self.cuda_graph_kv_indices.stride(0),
PAGE_SIZE,
)
workspace_size = cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE, bs
)
self.cuda_graph_mla_workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
)
self.forward_metadata.workspace = self.cuda_graph_mla_workspace
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
else:
super().init_forward_metadata_replay_cuda_graph(
bs,
......
Markdown is supported
0% loaded — or drag and drop files to attach.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment