Unverified Commit f24b2de3 authored by Wei Zhao's avatar Wei Zhao Committed by GitHub
Browse files

[Test] Add FP8 KV Cache Testing for MLA Backends (#34473)


Signed-off-by: default avatarwzhao18 <wzhao18.sz@gmail.com>
parent fac1507f
...@@ -19,8 +19,13 @@ from tests.v1.attention.utils import ( ...@@ -19,8 +19,13 @@ from tests.v1.attention.utils import (
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config.vllm import set_current_vllm_config from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention.mla_attention import QueryLenSupport from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport,
_DecodeConcatQuantFP8,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backend import CommonAttentionMetadata
...@@ -50,6 +55,7 @@ if not flash_attn_supports_mla(): ...@@ -50,6 +55,7 @@ if not flash_attn_supports_mla():
if not is_flashmla_dense_supported()[0]: if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA) BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
SPEC_DECODE_BACKENDS = [] SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST: for backend in BACKENDS_TO_TEST:
builder_cls, _ = try_get_attention_backend(backend) builder_cls, _ = try_get_attention_backend(backend)
...@@ -144,9 +150,8 @@ def create_and_prepopulate_kv_cache( ...@@ -144,9 +150,8 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: Common attention metadata common_attn_metadata: Common attention metadata
randomize_blocks: Whether to randomly permute blocks randomize_blocks: Whether to randomly permute blocks
or use sequential order or use sequential order
kv_cache_dtype: Optional kv cache dtype string. When set to kv_cache_dtype: Optional kv cache dtype string. For fp8 cache dtype,
"fp8_ds_mla" the cache is populated using the the cache is populated via concat_and_cache_mla.
fp8 DeepSeek MLA layout via concat_and_cache_mla.
scale: Scaling factor forwarded to concat_and_cache_mla when the scale: Scaling factor forwarded to concat_and_cache_mla when the
fp8 cache layout is requested. fp8 cache layout is requested.
...@@ -163,18 +168,21 @@ def create_and_prepopulate_kv_cache( ...@@ -163,18 +168,21 @@ def create_and_prepopulate_kv_cache(
block_table = common_attn_metadata.block_table_tensor block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping slot_mapping = common_attn_metadata.slot_mapping
fp8_attention = kv_cache_dtype and kv_cache_dtype.startswith("fp8")
use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla" use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
if fp8_attention:
if use_fp8_ds_mla: if use_fp8_ds_mla:
if not kv_c_contexts:
raise ValueError(
"kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype"
)
kv_lora_rank = kv_c_contexts[0].shape[-1] kv_lora_rank = kv_c_contexts[0].shape[-1]
rope_dim = k_pe_contexts[0].shape[-1] rope_dim = k_pe_contexts[0].shape[-1]
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim # 4 * 4: 4 float32 scale values for 128-element tiles
# 2 * rope_dim: 16-bit RoPE values
kv_entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
else:
kv_entry_size = head_size
kv_cache = torch.zeros( kv_cache = torch.zeros(
num_blocks, block_size, entry_size, dtype=torch.uint8, device=device num_blocks, block_size, kv_entry_size, dtype=torch.uint8, device=device
) )
scale_tensor = ( scale_tensor = (
scale scale
...@@ -201,14 +209,14 @@ def create_and_prepopulate_kv_cache( ...@@ -201,14 +209,14 @@ def create_and_prepopulate_kv_cache(
start = start_block_idx * block_size start = start_block_idx * block_size
if use_fp8_ds_mla: if fp8_attention:
slots = torch.arange(context_len, device=device, dtype=torch.long) + start slots = torch.arange(context_len, device=device, dtype=torch.long) + start
ops.concat_and_cache_mla( ops.concat_and_cache_mla(
kv_c_context, kv_c_context,
k_pe_context.squeeze(1), k_pe_context.squeeze(1),
kv_cache, kv_cache,
slots, slots,
kv_cache_dtype="fp8_ds_mla", kv_cache_dtype=kv_cache_dtype,
scale=scale_tensor, scale=scale_tensor,
) )
else: else:
...@@ -329,8 +337,9 @@ class MockSparseMLAAttentionLayer: ...@@ -329,8 +337,9 @@ class MockSparseMLAAttentionLayer:
output: torch.Tensor, output: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward for sparse MLA - uses forward_mqa for all tokens.""" """Forward for sparse MLA - uses forward_mqa for all tokens."""
# Write to KV cache
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto") kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
# Write to KV cache
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
ops.concat_and_cache_mla( ops.concat_and_cache_mla(
kv_c, kv_c,
...@@ -426,6 +435,12 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -426,6 +435,12 @@ class MockMLAAttentionLayer(AttentionLayerBase):
self._k_scale_float = 1.0 self._k_scale_float = 1.0
self._v_scale_float = 1.0 self._v_scale_float = 1.0
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)
def get_attn_backend(self): def get_attn_backend(self):
raise NotImplementedError raise NotImplementedError
...@@ -443,16 +458,21 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -443,16 +458,21 @@ class MockMLAAttentionLayer(AttentionLayerBase):
) -> torch.Tensor: ) -> torch.Tensor:
"""Replicates MLAAttention.forward_impl logic for testing.""" """Replicates MLAAttention.forward_impl logic for testing."""
# Write to KV cache # Write to KV cache
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
fp8_attention = kv_cache_dtype.startswith("fp8")
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
ops.concat_and_cache_mla( ops.concat_and_cache_mla(
kv_c, kv_c,
k_pe.squeeze(1), k_pe.squeeze(1),
kv_cache, kv_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
kv_cache_dtype="auto", kv_cache_dtype=kv_cache_dtype,
scale=self._k_scale, scale=self._k_scale,
) )
if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
kv_cache = kv_cache.view(current_platform.fp8_dtype())
# Determine decode vs prefill split # Determine decode vs prefill split
num_decode_tokens = attn_metadata.num_decode_tokens or 0 num_decode_tokens = attn_metadata.num_decode_tokens or 0
has_decode = (attn_metadata.num_decodes or 0) > 0 has_decode = (attn_metadata.num_decodes or 0) > 0
...@@ -491,7 +511,13 @@ class MockMLAAttentionLayer(AttentionLayerBase): ...@@ -491,7 +511,13 @@ class MockMLAAttentionLayer(AttentionLayerBase):
# Convert from (N, B, L) to (B, N, L) # Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1) mqa_ql_nope = mqa_ql_nope.transpose(0, 1)
# Pass as tuple to forward_mqa if fp8_attention and self.impl.supports_quant_query_input:
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
mqa_q = self._decode_concat_quant_fp8_op(
mqa_ql_nope, mqa_q_pe, self._q_scale
)
else:
mqa_q = (mqa_ql_nope, mqa_q_pe) mqa_q = (mqa_ql_nope, mqa_q_pe)
attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
...@@ -526,6 +552,7 @@ def run_attention_backend( ...@@ -526,6 +552,7 @@ def run_attention_backend(
qk_rope_head_dim: int, qk_rope_head_dim: int,
v_head_dim: int, v_head_dim: int,
mock_kv_b_proj, mock_kv_b_proj,
kv_cache_dtype: str = "auto",
) -> torch.Tensor: ) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl.""" """Run attention computation using the specified backend's AttentionImpl."""
...@@ -550,7 +577,7 @@ def run_attention_backend( ...@@ -550,7 +577,7 @@ def run_attention_backend(
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
alibi_slopes=None, alibi_slopes=None,
sliding_window=None, sliding_window=None,
kv_cache_dtype="auto", kv_cache_dtype=kv_cache_dtype,
logits_soft_cap=None, logits_soft_cap=None,
attn_type="decoder", attn_type="decoder",
kv_sharing_target_layer_name=None, kv_sharing_target_layer_name=None,
...@@ -630,12 +657,14 @@ def run_attention_backend( ...@@ -630,12 +657,14 @@ def run_attention_backend(
) )
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"]) @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16]) @pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
def test_backend_correctness( def test_backend_correctness(
default_vllm_config, default_vllm_config,
dist_init, dist_init,
batch_spec_name: str, batch_spec_name: str,
model: str, model: str,
tensor_parallel_size: int, tensor_parallel_size: int,
kv_cache_dtype: str,
): ):
""" """
Test that all backends produce similar outputs to a reference implementation Test that all backends produce similar outputs to a reference implementation
...@@ -658,9 +687,18 @@ def test_backend_correctness( ...@@ -658,9 +687,18 @@ def test_backend_correctness(
head counts. head counts.
""" """
# Filter backends to those that support the requested kv_cache_dtype
backends_to_test = [
b
for b in BACKENDS_TO_TEST
if kv_cache_dtype in b.get_class().supported_kv_cache_dtypes
]
if not backends_to_test:
pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")
batch_spec = BATCH_SPECS[batch_spec_name] batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode") is_spec_decode_test = batch_spec_name.startswith("spec_decode")
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values())) unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES[b] for b in backends_to_test))
default_block_size = unique_block_sizes[0] default_block_size = unique_block_sizes[0]
required_blocks = sum( required_blocks = sum(
(seq_len + default_block_size - 1) // default_block_size (seq_len + default_block_size - 1) // default_block_size
...@@ -694,6 +732,7 @@ def test_backend_correctness( ...@@ -694,6 +732,7 @@ def test_backend_correctness(
block_size=default_block_size, block_size=default_block_size,
hf_config_override=hf_config_override, hf_config_override=hf_config_override,
) )
vllm_config.cache_config.cache_dtype = kv_cache_dtype
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
if is_spec_decode_test: if is_spec_decode_test:
...@@ -751,7 +790,7 @@ def test_backend_correctness( ...@@ -751,7 +790,7 @@ def test_backend_correctness(
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i, backend in enumerate(BACKENDS_TO_TEST): for i, backend in enumerate(backends_to_test):
all_sdpa_outputs.append([]) all_sdpa_outputs.append([])
for i in range(batch_size): for i in range(batch_size):
...@@ -785,7 +824,7 @@ def test_backend_correctness( ...@@ -785,7 +824,7 @@ def test_backend_correctness(
# pipeline (MHA-style). This ensures the reference implementation # pipeline (MHA-style). This ensures the reference implementation
# matches each backend's actual decode/prefill pipeline path. # matches each backend's actual decode/prefill pipeline path.
is_decode = [] is_decode = []
for backend_idx, backend in enumerate(BACKENDS_TO_TEST): for backend_idx, backend in enumerate(backends_to_test):
builder_cls, _ = try_get_attention_backend(backend) builder_cls, _ = try_get_attention_backend(backend)
if is_spec_decode_test: if is_spec_decode_test:
query_len_support = getattr( query_len_support = getattr(
...@@ -885,7 +924,7 @@ def test_backend_correctness( ...@@ -885,7 +924,7 @@ def test_backend_correctness(
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
for backend_idx, backend in enumerate(BACKENDS_TO_TEST): for backend_idx, backend in enumerate(backends_to_test):
if is_decode[backend_idx]: if is_decode[backend_idx]:
all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode) all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
else: else:
...@@ -905,7 +944,7 @@ def test_backend_correctness( ...@@ -905,7 +944,7 @@ def test_backend_correctness(
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_outputs = {} sdpa_outputs = {}
for backend_idx, backend in enumerate(BACKENDS_TO_TEST): for backend_idx, backend in enumerate(backends_to_test):
sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0) sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)
# Create mock kv_b_proj using the same weights as reference implementation # Create mock kv_b_proj using the same weights as reference implementation
...@@ -973,12 +1012,13 @@ def test_backend_correctness( ...@@ -973,12 +1012,13 @@ def test_backend_correctness(
num_blocks=num_blocks_for_size, num_blocks=num_blocks_for_size,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
randomize_blocks=True, randomize_blocks=True,
kv_cache_dtype=kv_cache_dtype,
) )
kv_cache_per_block_size[block_size] = kv_cache kv_cache_per_block_size[block_size] = kv_cache
# 4. Run vLLM backends and compare # 4. Run vLLM backends and compare
failures = [] failures = []
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST): for backend_idx, backend_name in enumerate(backends_to_test):
# Skip backends that don't support spec decode for spec decode tests # Skip backends that don't support spec decode for spec decode tests
if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS: if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
continue continue
...@@ -997,7 +1037,7 @@ def test_backend_correctness( ...@@ -997,7 +1037,7 @@ def test_backend_correctness(
head_size=vllm_config.model_config.get_head_size(), head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype, dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(), sliding_window=vllm_config.model_config.get_sliding_window(),
cache_dtype_str=vllm_config.cache_config.cache_dtype, cache_dtype_str=kv_cache_dtype,
) )
backend_output = run_attention_backend( backend_output = run_attention_backend(
...@@ -1016,6 +1056,7 @@ def test_backend_correctness( ...@@ -1016,6 +1056,7 @@ def test_backend_correctness(
qk_rope_head_dim, qk_rope_head_dim,
v_head_dim, v_head_dim,
mock_kv_b_proj, mock_kv_b_proj,
kv_cache_dtype=kv_cache_dtype,
) )
# Use backend_idx to get the correct SDPA output for this backend # Use backend_idx to get the correct SDPA output for this backend
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment