Unverified commit 577df69b, authored by Andy Lo and committed by GitHub
Browse files

[Bugfix] Fix KV scales inconsistency in fp8 MLA & FlashInfer kv_cache_dtype...


[Bugfix] Fix KV scales inconsistency in fp8 MLA & FlashInfer kv_cache_dtype "auto" leading to gibberish (#37054)
Signed-off-by: Andy Lo <andy@mistral.ai>
parent 04244fd0
...@@ -266,22 +266,6 @@ def create_and_prepopulate_kv_cache( ...@@ -266,22 +266,6 @@ def create_and_prepopulate_kv_cache(
return kv_cache return kv_cache
class MockAttentionLayer:
"""A mock attention layer for testing."""
# Carries only scale attributes; every scale below is fixed at 1.0.
def __init__(self, device: torch.device):
# Tensor-valued scales, each created on `device`.
self._q_scale = torch.tensor(1.0, device=device)
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
self._prob_scale = torch.tensor(1.0, device=device)
# Plain-float counterparts of the q/k/v scales; no float counterpart
# of _prob_scale is set here.
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
def forward(self, *_args, **_kwargs):
# Accepts and ignores any arguments; calling it always raises.
raise NotImplementedError
class MockSparseMLAAttentionLayer: class MockSparseMLAAttentionLayer:
"""A mock sparse MLA attention layer for testing. """A mock sparse MLA attention layer for testing.
...@@ -304,6 +288,8 @@ class MockSparseMLAAttentionLayer: ...@@ -304,6 +288,8 @@ class MockSparseMLAAttentionLayer:
device: torch.device, device: torch.device,
W_UK: torch.Tensor, W_UK: torch.Tensor,
W_UV: torch.Tensor, W_UV: torch.Tensor,
q_scale: float,
k_scale: float,
): ):
self.impl = impl self.impl = impl
self.num_heads = num_heads self.num_heads = num_heads
...@@ -319,13 +305,13 @@ class MockSparseMLAAttentionLayer: ...@@ -319,13 +305,13 @@ class MockSparseMLAAttentionLayer:
self.W_UV = W_UV.transpose(0, 1) self.W_UV = W_UV.transpose(0, 1)
# Scale attributes needed by attention backends # Scale attributes needed by attention backends
self._q_scale = torch.tensor(1.0, device=device) self._q_scale = torch.tensor(q_scale, device=device)
self._k_scale = torch.tensor(1.0, device=device) self._k_scale = torch.tensor(k_scale, device=device)
self._v_scale = torch.tensor(1.0, device=device) self._v_scale = torch.tensor(float("nan"), device=device)
self._prob_scale = torch.tensor(1.0, device=device) self._prob_scale = torch.tensor(1.0, device=device)
self._q_scale_float = 1.0 self._q_scale_float = q_scale
self._k_scale_float = 1.0 self._k_scale_float = k_scale
self._v_scale_float = 1.0 self._v_scale_float = float("nan")
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True, static=True,
...@@ -420,6 +406,8 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -420,6 +406,8 @@ class MockMLAAttentionLayer(AttentionLayerBase):
kv_lora_rank: int, kv_lora_rank: int,
device: torch.device, device: torch.device,
kv_b_proj, kv_b_proj,
q_scale: float,
k_scale: float,
): ):
self.impl = impl self.impl = impl
self.num_heads = num_heads self.num_heads = num_heads
...@@ -443,13 +431,13 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -443,13 +431,13 @@ class MockMLAAttentionLayer(AttentionLayerBase):
self.W_UK_T = W_UK.permute(1, 2, 0) self.W_UK_T = W_UK.permute(1, 2, 0)
# Scale attributes needed by attention backends # Scale attributes needed by attention backends
self._q_scale = torch.tensor(1.0, device=device) self._q_scale = torch.tensor(q_scale, device=device)
self._k_scale = torch.tensor(1.0, device=device) self._k_scale = torch.tensor(k_scale, device=device)
self._v_scale = torch.tensor(1.0, device=device) self._v_scale = torch.tensor(float("nan"), device=device)
self._prob_scale = torch.tensor(1.0, device=device) self._prob_scale = torch.tensor(1.0, device=device)
self._q_scale_float = 1.0 self._q_scale_float = q_scale
self._k_scale_float = 1.0 self._k_scale_float = k_scale
self._v_scale_float = 1.0 self._v_scale_float = float("nan")
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True, static=True,
...@@ -568,6 +556,8 @@ def run_attention_backend( ...@@ -568,6 +556,8 @@ def run_attention_backend(
qk_rope_head_dim: int, qk_rope_head_dim: int,
v_head_dim: int, v_head_dim: int,
mock_kv_b_proj, mock_kv_b_proj,
q_scale: float,
k_scale: float,
kv_cache_dtype: str = "auto", kv_cache_dtype: str = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl.""" """Run attention computation using the specified backend's AttentionImpl."""
...@@ -625,6 +615,8 @@ def run_attention_backend( ...@@ -625,6 +615,8 @@ def run_attention_backend(
kv_lora_rank=kv_lora_rank, kv_lora_rank=kv_lora_rank,
device=device, device=device,
kv_b_proj=mock_kv_b_proj, kv_b_proj=mock_kv_b_proj,
q_scale=q_scale,
k_scale=k_scale,
) )
# Populate static_forward_context with mock attention layers # Populate static_forward_context with mock attention layers
...@@ -674,6 +666,7 @@ def run_attention_backend( ...@@ -674,6 +666,7 @@ def run_attention_backend(
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"]) @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16]) @pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"]) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
@pytest.mark.parametrize(("q_scale", "k_scale"), [(1.0, 1.0), (2.0, 3.0)])
def test_backend_correctness( def test_backend_correctness(
default_vllm_config, default_vllm_config,
dist_init, dist_init,
...@@ -681,6 +674,8 @@ def test_backend_correctness( ...@@ -681,6 +674,8 @@ def test_backend_correctness(
model: str, model: str,
tensor_parallel_size: int, tensor_parallel_size: int,
kv_cache_dtype: str, kv_cache_dtype: str,
q_scale: float,
k_scale: float,
): ):
""" """
Test that all backends produce similar outputs to a reference implementation Test that all backends produce similar outputs to a reference implementation
...@@ -709,6 +704,11 @@ def test_backend_correctness( ...@@ -709,6 +704,11 @@ def test_backend_correctness(
for b in BACKENDS_TO_TEST for b in BACKENDS_TO_TEST
if kv_cache_dtype in b.get_class().supported_kv_cache_dtypes if kv_cache_dtype in b.get_class().supported_kv_cache_dtypes
] ]
if (
q_scale != 1.0 or k_scale != 1.0
) and AttentionBackendEnum.CUTLASS_MLA in backends_to_test:
# CUTLASS_MLA does not support non-1 Q/K scales
backends_to_test.remove(AttentionBackendEnum.CUTLASS_MLA)
if not backends_to_test: if not backends_to_test:
pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}") pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")
...@@ -1029,6 +1029,7 @@ def test_backend_correctness( ...@@ -1029,6 +1029,7 @@ def test_backend_correctness(
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
randomize_blocks=True, randomize_blocks=True,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
scale=k_scale,
) )
kv_cache_per_block_size[block_size] = kv_cache kv_cache_per_block_size[block_size] = kv_cache
...@@ -1072,6 +1073,8 @@ def test_backend_correctness( ...@@ -1072,6 +1073,8 @@ def test_backend_correctness(
qk_rope_head_dim, qk_rope_head_dim,
v_head_dim, v_head_dim,
mock_kv_b_proj, mock_kv_b_proj,
q_scale=q_scale,
k_scale=k_scale,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
) )
......
...@@ -178,6 +178,7 @@ def _quantize_dequantize_fp8_ds_mla( ...@@ -178,6 +178,7 @@ def _quantize_dequantize_fp8_ds_mla(
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_ds_mla"]) @pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_ds_mla"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.parametrize("block_size", [32, 64]) @pytest.mark.parametrize("block_size", [32, 64])
@pytest.mark.parametrize(("q_scale", "k_scale"), [(1.0, 1.0), (2.0, 3.0)])
def test_sparse_backend_decode_correctness( def test_sparse_backend_decode_correctness(
default_vllm_config, default_vllm_config,
dist_init, dist_init,
...@@ -187,6 +188,8 @@ def test_sparse_backend_decode_correctness( ...@@ -187,6 +188,8 @@ def test_sparse_backend_decode_correctness(
tensor_parallel_size, tensor_parallel_size,
block_size, block_size,
workspace_init, workspace_init,
q_scale: float,
k_scale: float,
): ):
if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes: if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}") pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")
...@@ -332,7 +335,7 @@ def test_sparse_backend_decode_correctness( ...@@ -332,7 +335,7 @@ def test_sparse_backend_decode_correctness(
kv_c_contexts, k_pe_contexts = [], [] kv_c_contexts, k_pe_contexts = [], []
reference_outputs = [] reference_outputs = []
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device) kv_cache_scale = torch.tensor(k_scale, dtype=torch.float32, device=device)
global_token_idx = 0 global_token_idx = 0
for i in range(batch_spec.batch_size): for i in range(batch_spec.batch_size):
...@@ -490,6 +493,8 @@ def test_sparse_backend_decode_correctness( ...@@ -490,6 +493,8 @@ def test_sparse_backend_decode_correctness(
device=device, device=device,
W_UK=W_UK, W_UK=W_UK,
W_UV=W_UV, W_UV=W_UV,
q_scale=q_scale,
k_scale=k_scale,
) )
out_buffer = torch.empty( out_buffer = torch.empty(
...@@ -513,7 +518,9 @@ def test_sparse_backend_decode_correctness( ...@@ -513,7 +518,9 @@ def test_sparse_backend_decode_correctness(
# FP8 quantization introduces some error, but should be within reasonable bounds # FP8 quantization introduces some error, but should be within reasonable bounds
# BF16 (auto) should be very accurate, FP8 allows slightly more tolerance # BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
if kv_cache_dtype.startswith("fp8"): if kv_cache_dtype.startswith("fp8"):
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05) torch.testing.assert_close(
backend_output, sdpa_reference, rtol=0.065, atol=0.05
)
else: else:
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01) torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01)
......
...@@ -43,12 +43,12 @@ class MockAttentionLayer: ...@@ -43,12 +43,12 @@ class MockAttentionLayer:
"""Minimal mock of an attention layer for testing.""" """Minimal mock of an attention layer for testing."""
def __init__(self, device: torch.device): def __init__(self, device: torch.device):
self._q_scale = torch.tensor(1.0, device=device) self._q_scale = torch.tensor(2.0, device=device)
self._k_scale = torch.tensor(1.0, device=device) self._k_scale = torch.tensor(3.0, device=device)
self._v_scale = torch.tensor(1.0, device=device) self._v_scale = torch.tensor(4.0, device=device)
self._q_scale_float = 1.0 self._q_scale_float = 2.0
self._k_scale_float = 1.0 self._k_scale_float = 3.0
self._v_scale_float = 1.0 self._v_scale_float = 4.0
self._o_scale_float = None self._o_scale_float = None
......
...@@ -1319,10 +1319,14 @@ class FlashInferImpl(AttentionImpl): ...@@ -1319,10 +1319,14 @@ class FlashInferImpl(AttentionImpl):
) )
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"):
self.bmm2_scale *= layer._v_scale_float
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill) prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode) decode_use_trtllm = isinstance(attn_metadata.decode, TRTLLMDecode)
......
...@@ -255,6 +255,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -255,6 +255,11 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if layer._q_scale_float != 1.0 or layer._k_scale_float != 1.0:
raise NotImplementedError(
"CutlassMLAImpl does not support scaling for q and kv_latent yet"
)
if type(q) is tuple: if type(q) is tuple:
q_nope, q_pe = q q_nope, q_pe = q
else: else:
......
...@@ -177,9 +177,14 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -177,9 +177,14 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1]) q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"):
self.bmm2_scale *= layer._k_scale_float
o = trtllm_batch_decode_with_kv_cache_mla( o = trtllm_batch_decode_with_kv_cache_mla(
query=q, query=q,
......
...@@ -340,9 +340,13 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata ...@@ -340,9 +340,13 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
self._workspace_buffer = _get_workspace_buffer(q.device) self._workspace_buffer = _get_workspace_buffer(q.device)
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"):
self.bmm2_scale *= layer._k_scale_float
o = trtllm_batch_decode_with_kv_cache_mla( o = trtllm_batch_decode_with_kv_cache_mla(
query=q.unsqueeze(1), query=q.unsqueeze(1),
......
...@@ -187,7 +187,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -187,7 +187,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
self.scale, self.scale,
PAGE_SIZE, PAGE_SIZE,
k_scale=layer._k_scale, k_scale=layer._k_scale,
v_scale=layer._v_scale, v_scale=layer._k_scale,
) )
return o, lse return o, lse
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment