Unverified commit c4e5bdf6, authored by Chauncey, committed by GitHub

[Bugfix] Fix the fp8_mqa_logits dim mismatch (#32652)


Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
parent 7f1bcd18
@@ -686,7 +686,7 @@ def sparse_attn_indexer(
     fp8_mqa_logits_func = rocm_fp8_mqa_logits
     logits = fp8_mqa_logits_func(
         q_fp8[chunk.token_start : chunk.token_end],
-        (k_fp8, k_scale.view(torch.float32)),
+        (k_fp8, k_scale.view(torch.float32).flatten()),
         weights[chunk.token_start : chunk.token_end],
         chunk.cu_seqlen_ks,
         chunk.cu_seqlen_ke,
...
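The effect of the added `.flatten()` can be seen on a small stand-in tensor. This is a minimal sketch, assuming the scale buffer holds one float32 per key row stored as four raw bytes, so that `view(torch.float32)` yields shape `[N, 1]`. That storage layout and the toy size `N = 5` are assumptions for illustration and are not shown in this diff.

```python
import torch

N = 5  # hypothetical number of key rows

# Assumed layout: one float32 scale per row, kept as 4 raw bytes.
k_scale = torch.zeros(N, 4, dtype=torch.uint8)

# Reinterpreting the bytes as float32 shrinks the last dim from 4 to 1.
as_f32 = k_scale.view(torch.float32)   # shape [N, 1]

# flatten() drops the trailing singleton dim, giving the 1-D [N] form.
flat = as_f32.flatten()                # shape [N]

print(tuple(as_f32.shape), tuple(flat.shape))
```

Under these assumptions the call site now passes a 1-D scale tensor, which matches the `[N]` shape the updated docstring describes.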
@@ -249,8 +249,8 @@ def fp8_mqa_logits(
         q: Query tensor of shape [M, H, D]. Casted to
             `torch.float8_e4m3fn` by caller.
         kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
-            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
-            [N, 1]) with dtype `torch.float32`.
+            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N])
+            with dtype `torch.float32`.
         weights: weights of shape [M, H], dtype `torch.float32`.
         cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
             shape [M], dtype int32.
...
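The shapes listed in the docstring can be written down as assertions. The helper below is a hypothetical sketch, not part of the patched code: `check_mqa_logits_inputs` is an invented name, it covers only the arguments visible in this hunk, and it uses float32 stand-ins for the FP8 tensors, so the `torch.float8_e4m3fn` dtypes are deliberately not checked.

```python
import torch

def check_mqa_logits_inputs(q, kv, weights, cu_seqlen_ks):
    """Hypothetical helper: assert the shapes listed in the docstring."""
    k, k_scales = kv
    assert q.dim() == 3, "q must be [M, H, D]"
    M, H, D = q.shape
    assert k.dim() == 2 and k.shape[1] == D, "k_fp8 must be [N, D]"
    N = k.shape[0]
    # After this change only a 1-D scale tensor is documented.
    assert k_scales.shape == (N,), (
        f"k_scales must be [N], got {tuple(k_scales.shape)}"
    )
    assert k_scales.dtype == torch.float32
    assert weights.shape == (M, H) and weights.dtype == torch.float32
    assert cu_seqlen_ks.shape == (M,) and cu_seqlen_ks.dtype == torch.int32
    return M, H, D, N

M, H, D, N = 3, 2, 8, 5                  # toy sizes
q = torch.randn(M, H, D)                 # float32 stand-in for the FP8 query
k = torch.randn(N, D)                    # float32 stand-in for k_fp8
scales = torch.ones(N, 1)                # the [N, 1] form no longer documented
weights = torch.rand(M, H)
ks = torch.zeros(M, dtype=torch.int32)

# Passing the flattened [N] scales satisfies every check.
dims = check_mqa_logits_inputs(q, (k, scales.flatten()), weights, ks)
```

Passing `scales` without `.flatten()` trips the `k_scales` assertion, which is the same shape constraint the call-site change in `sparse_attn_indexer` satisfies.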