Unverified Commit e58423b2 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix cutlass MLA gets almost zero accuracy (#6998)

parent 7059ae16
...@@ -157,7 +157,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): ...@@ -157,7 +157,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
): ):
if forward_mode.is_decode_or_idle(): if forward_mode.is_decode_or_idle():
if spec_info is None: if spec_info is None:
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) max_seqlen_pad = self.cuda_graph_kv_indices.shape[1]
create_flashmla_kv_indices_triton[(bs,)]( create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token, self.req_to_token,
...@@ -169,12 +169,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): ...@@ -169,12 +169,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
self.cuda_graph_kv_indices.stride(0), self.cuda_graph_kv_indices.stride(0),
PAGE_SIZE, PAGE_SIZE,
) )
workspace_size = cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE, bs
)
self.cuda_graph_mla_workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
)
self.forward_metadata = CutlassMLADecodeMetadata( self.forward_metadata = CutlassMLADecodeMetadata(
self.cuda_graph_mla_workspace, self.cuda_graph_mla_workspace,
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad], self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
...@@ -205,8 +199,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): ...@@ -205,8 +199,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
if forward_mode.is_decode_or_idle(): if forward_mode.is_decode_or_idle():
assert seq_lens_cpu is not None assert seq_lens_cpu is not None
seq_lens = seq_lens[:bs] seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)]( create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token, self.req_to_token,
req_pool_indices[:bs], req_pool_indices[:bs],
...@@ -217,16 +210,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): ...@@ -217,16 +210,6 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
self.cuda_graph_kv_indices.stride(0), self.cuda_graph_kv_indices.stride(0),
PAGE_SIZE, PAGE_SIZE,
) )
workspace_size = cutlass_mla_get_workspace_size(
max_seqlen_pad * PAGE_SIZE, bs
)
self.cuda_graph_mla_workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
)
self.forward_metadata.workspace = self.cuda_graph_mla_workspace
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
else: else:
super().init_forward_metadata_replay_cuda_graph( super().init_forward_metadata_replay_cuda_graph(
bs, bs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment