Unverified Commit 4145e50d authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Bugfix] Fix DSV3.2 NVFP4 (#33932)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 20f5d185
......@@ -530,7 +530,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
scale=self._k_scale,
)
if fp8_attention:
if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla":
kv_cache = kv_cache.view(current_platform.fp8_dtype())
# Sparse MLA impls only support forward_mqa (decode-style attention)
......@@ -614,7 +614,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
if fp8_attention:
if fp8_attention and self.impl.supports_quant_query_input:
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
mqa_q = self._decode_concat_quant_fp8_op(
......@@ -1885,6 +1885,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
self.supports_quant_query_input = True
# Use flashinfer's optimized concat_mla_k kernel when available.
# The kernel is optimized for DeepSeek V3 dimensions:
# num_heads=128, nope_dim=128, rope_dim=64
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment