Commit b5323d90 authored by lixh6
Browse files

feat:支持mqa的fp8实现

parent 80f0794e
...@@ -112,7 +112,7 @@ def sparse_attn_indexer( ...@@ -112,7 +112,7 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
) )
else: elif torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache( ops.cp_gather_indexer_k_quant_cache(
...@@ -121,7 +121,21 @@ def sparse_attn_indexer( ...@@ -121,7 +121,21 @@ def sparse_attn_indexer(
k_scale, k_scale,
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k_fp8,
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k_fp8.shape[0],
q_fp8.shape[1],
q_fp8.shape[2],
k_scale.view(torch.float32).flatten(),
True
)
else:
logits = op.mqa_logits( logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end], q_fp8[chunk.token_start:chunk.token_end],
k, k,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment