Unverified commit decf7f79, authored by Lucas Wilkinson and committed by GitHub
Browse files

[BugFix] Fix FI accuracy issue when used for MLA prefill (#26063)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent d00d6529
...@@ -1211,13 +1211,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1211,13 +1211,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k, v, return_softmax_lse): k, v, return_softmax_lse):
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None assert prefill.prefill_main is not None
return prefill.prefill_main.run( ret = prefill.prefill_main.run(
q=q, q=q,
k=k, k=k,
v=v, v=v,
return_lse=return_softmax_lse, return_lse=return_softmax_lse,
) )
if isinstance(ret, tuple):
# Convert from (q_len, num_heads) to (num_heads, q_len)
return ret[0], ret[1].transpose(0, 1).contiguous()
return ret
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
q, k, v, return_softmax_lse): q, k, v, return_softmax_lse):
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
...@@ -1260,12 +1265,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1260,12 +1265,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v): chunk_idx: int, q, k, v):
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
return prefill.prefill_chunks[chunk_idx].run( attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q, q=q,
k=k, k=k,
v=v, v=v,
return_lse=True, return_lse=True,
) )
# Convert from (q_len, num_heads) to (num_heads, q_len)
return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn(self, def _run_prefill_context_chunk_cudnn(self,
prefill: MLACommonPrefillMetadata, prefill: MLACommonPrefillMetadata,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment