Unverified Commit f24b2de3 authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Test] Add FP8 KV Cache Testing for MLA Backends (#34473)


Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
parent fac1507f
......@@ -19,8 +19,13 @@ from tests.v1.attention.utils import (
)
from vllm import _custom_ops as ops
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention.mla_attention import QueryLenSupport
from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport,
_DecodeConcatQuantFP8,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import CommonAttentionMetadata
......@@ -50,6 +55,7 @@ if not flash_attn_supports_mla():
if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST:
builder_cls, _ = try_get_attention_backend(backend)
......@@ -144,9 +150,8 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: Common attention metadata
randomize_blocks: Whether to randomly permute blocks
or use sequential order
kv_cache_dtype: Optional kv cache dtype string. When set to
"fp8_ds_mla" the cache is populated using the
fp8 DeepSeek MLA layout via concat_and_cache_mla.
kv_cache_dtype: Optional kv cache dtype string. For fp8 cache dtype,
the cache is populated via concat_and_cache_mla.
scale: Scaling factor forwarded to concat_and_cache_mla when the
fp8 cache layout is requested.
......@@ -163,18 +168,21 @@ def create_and_prepopulate_kv_cache(
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
fp8_attention = kv_cache_dtype and kv_cache_dtype.startswith("fp8")
use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
if fp8_attention:
if use_fp8_ds_mla:
if not kv_c_contexts:
raise ValueError(
"kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype"
)
kv_lora_rank = kv_c_contexts[0].shape[-1]
rope_dim = k_pe_contexts[0].shape[-1]
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
# 4 * 4: 4 float32 scale values for 128-element tiles
# 2 * rope_dim: 16-bit RoPE values
kv_entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
else:
kv_entry_size = head_size
kv_cache = torch.zeros(
num_blocks, block_size, entry_size, dtype=torch.uint8, device=device
num_blocks, block_size, kv_entry_size, dtype=torch.uint8, device=device
)
scale_tensor = (
scale
......@@ -201,14 +209,14 @@ def create_and_prepopulate_kv_cache(
start = start_block_idx * block_size
if use_fp8_ds_mla:
if fp8_attention:
slots = torch.arange(context_len, device=device, dtype=torch.long) + start
ops.concat_and_cache_mla(
kv_c_context,
k_pe_context.squeeze(1),
kv_cache,
slots,
kv_cache_dtype="fp8_ds_mla",
kv_cache_dtype=kv_cache_dtype,
scale=scale_tensor,
)
else:
......@@ -329,8 +337,9 @@ class MockSparseMLAAttentionLayer:
output: torch.Tensor,
) -> torch.Tensor:
"""Forward for sparse MLA - uses forward_mqa for all tokens."""
# Write to KV cache
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
# Write to KV cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
kv_c,
......@@ -426,6 +435,12 @@ class MockMLAAttentionLayer(AttentionLayerBase):
self._k_scale_float = 1.0
self._v_scale_float = 1.0
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)
def get_attn_backend(self):
raise NotImplementedError
......@@ -443,16 +458,21 @@ class MockMLAAttentionLayer(AttentionLayerBase):
) -> torch.Tensor:
"""Replicates MLAAttention.forward_impl logic for testing."""
# Write to KV cache
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
fp8_attention = kv_cache_dtype.startswith("fp8")
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
kv_c,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype="auto",
kv_cache_dtype=kv_cache_dtype,
scale=self._k_scale,
)
if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
kv_cache = kv_cache.view(current_platform.fp8_dtype())
# Determine decode vs prefill split
num_decode_tokens = attn_metadata.num_decode_tokens or 0
has_decode = (attn_metadata.num_decodes or 0) > 0
......@@ -491,7 +511,13 @@ class MockMLAAttentionLayer(AttentionLayerBase):
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
# Pass as tuple to forward_mqa
if fp8_attention and self.impl.supports_quant_query_input:
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
mqa_q = self._decode_concat_quant_fp8_op(
mqa_ql_nope, mqa_q_pe, self._q_scale
)
else:
mqa_q = (mqa_ql_nope, mqa_q_pe)
attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
......@@ -526,6 +552,7 @@ def run_attention_backend(
qk_rope_head_dim: int,
v_head_dim: int,
mock_kv_b_proj,
kv_cache_dtype: str = "auto",
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
......@@ -550,7 +577,7 @@ def run_attention_backend(
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
kv_cache_dtype=kv_cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
......@@ -630,12 +657,14 @@ def run_attention_backend(
)
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
def test_backend_correctness(
default_vllm_config,
dist_init,
batch_spec_name: str,
model: str,
tensor_parallel_size: int,
kv_cache_dtype: str,
):
"""
Test that all backends produce similar outputs to a reference implementation
......@@ -658,9 +687,18 @@ def test_backend_correctness(
head counts.
"""
# Filter backends to those that support the requested kv_cache_dtype
backends_to_test = [
b
for b in BACKENDS_TO_TEST
if kv_cache_dtype in b.get_class().supported_kv_cache_dtypes
]
if not backends_to_test:
pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")
batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES[b] for b in backends_to_test))
default_block_size = unique_block_sizes[0]
required_blocks = sum(
(seq_len + default_block_size - 1) // default_block_size
......@@ -694,6 +732,7 @@ def test_backend_correctness(
block_size=default_block_size,
hf_config_override=hf_config_override,
)
vllm_config.cache_config.cache_dtype = kv_cache_dtype
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
if is_spec_decode_test:
......@@ -751,7 +790,7 @@ def test_backend_correctness(
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i, backend in enumerate(BACKENDS_TO_TEST):
for i, backend in enumerate(backends_to_test):
all_sdpa_outputs.append([])
for i in range(batch_size):
......@@ -785,7 +824,7 @@ def test_backend_correctness(
# pipeline (MHA-style). This ensures the reference implementation
# matches each backend's actual decode/prefill pipeline path.
is_decode = []
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend in enumerate(backends_to_test):
builder_cls, _ = try_get_attention_backend(backend)
if is_spec_decode_test:
query_len_support = getattr(
......@@ -885,7 +924,7 @@ def test_backend_correctness(
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend in enumerate(backends_to_test):
if is_decode[backend_idx]:
all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
else:
......@@ -905,7 +944,7 @@ def test_backend_correctness(
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_outputs = {}
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend in enumerate(backends_to_test):
sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)
# Create mock kv_b_proj using the same weights as reference implementation
......@@ -973,12 +1012,13 @@ def test_backend_correctness(
num_blocks=num_blocks_for_size,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True,
kv_cache_dtype=kv_cache_dtype,
)
kv_cache_per_block_size[block_size] = kv_cache
# 4. Run vLLM backends and compare
failures = []
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
for backend_idx, backend_name in enumerate(backends_to_test):
# Skip backends that don't support spec decode for spec decode tests
if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
continue
......@@ -997,7 +1037,7 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype,
cache_dtype_str=kv_cache_dtype,
)
backend_output = run_attention_backend(
......@@ -1016,6 +1056,7 @@ def test_backend_correctness(
qk_rope_head_dim,
v_head_dim,
mock_kv_b_proj,
kv_cache_dtype=kv_cache_dtype,
)
# Use backend_idx to get the correct SDPA output for this backend
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment