Commit 66979358 authored by liuchy5's avatar liuchy5
Browse files

feat:修复dsa的mqa接口兼容glm5

parent f7461a96
......@@ -112,19 +112,20 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
else:
else:
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
q_fp8.shape[1],
q_fp8.shape[2],
None,
True
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment