Commit 7676d0c9 authored by wangmin6
Browse files

Merge branch 'v0.15.1-mqa_fp8' into 'v0.15.1-dev'

feat:支持mqa的fp8实现

See merge request dcutoolkit/deeplearing/vllm!514
parents 80f0794e b5323d90
......@@ -112,7 +112,7 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
else:
elif torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
......@@ -121,7 +121,21 @@ def sparse_attn_indexer(
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
)
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k_fp8,
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k_fp8.shape[0],
q_fp8.shape[1],
q_fp8.shape[2],
k_scale.view(torch.float32).flatten(),
True
)
else:
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment