"vscode:/vscode.git/clone" did not exist on "97ba96fbe9a53416b973bbe05eb3bfefd6408a6c"
Unverified Commit c4e5bdf6 authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Bugfix] Fix the fp8_mqa_logits dim mismatch (#32652)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent 7f1bcd18
......@@ -686,7 +686,7 @@ def sparse_attn_indexer(
fp8_mqa_logits_func = rocm_fp8_mqa_logits
logits = fp8_mqa_logits_func(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
......
......@@ -249,8 +249,8 @@ def fp8_mqa_logits(
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N])
with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment