Unverified Commit 82af928c authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention][Spec Decode] FlashMLA spec decode support (#26541)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 87efc681
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 MLA backends without GPUModelRunner dependency.""" """Tests for v1 MLA backends without GPUModelRunner dependency.
Known Issues:
- FLASH_ATTN_MLA backend occasionally produces NaN values in
test_backend_correctness[mixed_small] when run after
test_backend_correctness[small_prefill], but passes when run alone.
"""
import pytest import pytest
import torch import torch
...@@ -14,6 +20,8 @@ from tests.v1.attention.utils import ( ...@@ -14,6 +20,8 @@ from tests.v1.attention.utils import (
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.config.vllm import set_current_vllm_config
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
...@@ -29,6 +37,10 @@ BACKENDS_TO_TEST = [ ...@@ -29,6 +37,10 @@ BACKENDS_TO_TEST = [
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10: if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
# Remove FLASHMLA from the list if not supported
if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
torch.manual_seed(42) torch.manual_seed(42)
...@@ -66,6 +78,12 @@ BATCH_SPECS = { ...@@ -66,6 +78,12 @@ BATCH_SPECS = {
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
"spec_decode_small": BatchSpec(
seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4]
),
"spec_decode_medium": BatchSpec(
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8]
),
} }
...@@ -239,61 +257,64 @@ def run_attention_backend( ...@@ -239,61 +257,64 @@ def run_attention_backend(
builder_cls, impl_cls = try_get_attention_backend(backend) builder_cls, impl_cls = try_get_attention_backend(backend)
# Build metadata # Set the current vllm config so that get_current_vllm_config() works
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) # in the backend implementations
attn_metadata = builder.build( with set_current_vllm_config(vllm_config):
common_prefix_len=0, # Build metadata
common_attn_metadata=common_attn_metadata, builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
) attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Instantiate MLA implementation # Instantiate MLA implementation
num_heads = vllm_config.model_config.get_num_attention_heads( num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config vllm_config.parallel_config
) )
num_kv_heads = vllm_config.model_config.get_num_kv_heads( num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config vllm_config.parallel_config
) )
head_size = vllm_config.model_config.get_head_size() head_size = vllm_config.model_config.get_head_size()
scale = 1.0 / (head_size**0.5) scale = 1.0 / (head_size**0.5)
impl = impl_cls( impl = impl_cls(
num_heads=num_heads, num_heads=num_heads,
head_size=head_size, head_size=head_size,
scale=scale, scale=scale,
num_kv_heads=num_kv_heads, num_kv_heads=num_kv_heads,
alibi_slopes=None, alibi_slopes=None,
sliding_window=None, sliding_window=None,
kv_cache_dtype="auto", kv_cache_dtype="auto",
logits_soft_cap=None, logits_soft_cap=None,
attn_type="decoder", attn_type="decoder",
kv_sharing_target_layer_name=None, kv_sharing_target_layer_name=None,
q_lora_rank=None, q_lora_rank=None,
kv_lora_rank=kv_lora_rank, kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim, qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim, qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim, v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj, kv_b_proj=mock_kv_b_proj,
) )
# Process weights to create W_UK_T and W_UV attributes needed by MLA # Process weights to create W_UK_T and W_UV attributes needed by MLA
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype) impl.process_weights_after_loading(act_dtype)
# Create mock layer and output buffer # Create mock layer and output buffer
mock_layer = MockAttentionLayer(device) mock_layer = MockAttentionLayer(device)
num_tokens = query.shape[0] num_tokens = query.shape[0]
output = torch.empty( output = torch.empty(
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
) )
# Run forward pass # Run forward pass
# NOTE: The query, key, and value are already shaped correctly # NOTE: The query, key, and value are already shaped correctly
# in the calling test function. # in the calling test function.
output = impl.forward( output = impl.forward(
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
) )
return output return output
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -309,6 +330,8 @@ def run_attention_backend( ...@@ -309,6 +330,8 @@ def run_attention_backend(
"large_prefill", "large_prefill",
"single_decode", "single_decode",
"single_prefill", "single_prefill",
"spec_decode_small",
"spec_decode_medium",
], ],
) )
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) @pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
...@@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
simulated paged KV cache. simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output. 5. Comparing the vLLM backend's output to the ground-truth SDPA output.
""" """
from vllm.v1.attention.backends.mla.common import QueryLenSupport
batch_spec = BATCH_SPECS[batch_spec_name] batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
block_size = 16
required_blocks = sum(
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
)
# Add 1 for null block at index 0, and some buffer
num_gpu_blocks = required_blocks + 1 + 100
vllm_config = create_vllm_config( vllm_config = create_vllm_config(
model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048 model_name=model,
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=num_gpu_blocks,
block_size=block_size,
) )
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
if is_spec_decode_test:
from vllm.config import SpeculativeConfig
# Get the query length from the batch spec (they should all be uniform)
query_len = batch_spec.query_lens[0]
# Set num_speculative_tokens to query_len - 1
# (since threshold is 1 + num_spec_tokens)
# Use ngram method which doesn't require a draft model
vllm_config.speculative_config = SpeculativeConfig(
method="ngram", num_speculative_tokens=query_len - 1
)
device = torch.device("cuda:0") device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config) kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
...@@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# K_PE (rope component): [s_len, 1, qk_rope_head_dim] # K_PE (rope component): [s_len, 1, qk_rope_head_dim]
k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device) k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
# Determine if this is decode or prefill # Determine if this sequence uses the decode pipeline or prefill
# pipeline for each backend
# NOTE: For spec decode tests with uniform query_len > 1, backends that
# support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
# uniform support) will use the decode pipeline (MQA-style), while
# backends that only support single-token queries will use the prefill
# pipeline (MHA-style). This ensures the reference implementation
# matches each backend's actual decode/prefill pipeline path.
is_decode = [] is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST): for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = try_get_attention_backend(backend) builder_cls, _ = try_get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold) if is_spec_decode_test:
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
is_decode.append(supports_spec)
else:
threshold = getattr(builder_cls, "reorder_batch_threshold", None)
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
within_threshold = q_len <= threshold if threshold else False
if (
within_threshold
and query_len_support == QueryLenSupport.UNIFORM
and i > 0
):
first_q_len = query_lens[0]
within_threshold = q_len == first_q_len
is_decode.append(within_threshold)
# Split q into nope and rope components # Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
...@@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
for i, backend in enumerate(BACKENDS_TO_TEST): for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[i]: if is_decode[backend_idx]:
all_sdpa_outputs[i].append(sdpa_out_i_decode) all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
else: else:
all_sdpa_outputs[i].append(sdpa_out_i_prefill) all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill)
# Inputs for vLLM MLA backends are just the new tokens # Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c) all_q_vllm.append(q_c)
...@@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm = torch.cat(all_q_vllm, dim=0) query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_outputs = [] sdpa_outputs = {}
for i, backend in enumerate(BACKENDS_TO_TEST): for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0)) sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)
# Create mock kv_b_proj using the same weights as reference implementation # Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
...@@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_b_proj_weight = kv_b_proj_weight.view( kv_b_proj_weight = kv_b_proj_weight.view(
kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim) kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
) )
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
# Create metadata using original batch spec # Create metadata using original batch spec
common_attn_metadata = create_common_attn_metadata( common_attn_metadata = create_common_attn_metadata(
...@@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
) )
# 4. Run vLLM backends and compare # 4. Run vLLM backends and compare
for i, backend_name in enumerate(BACKENDS_TO_TEST): for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
# Skip backends that don't support spec decode for spec decode tests
if is_spec_decode_test and backend_name not in spec_decode_backends:
continue
backend_output = run_attention_backend( backend_output = run_attention_backend(
backend_name, backend_name,
kv_cache_spec, kv_cache_spec,
...@@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj, mock_kv_b_proj,
) )
# Use backend_idx to get the correct SDPA output for this backend
expected_output = sdpa_outputs[backend_name]
# Check shape and dtype consistency # Check shape and dtype consistency
assert backend_output.shape == sdpa_outputs[i].shape, ( assert backend_output.shape == expected_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != " f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_outputs[i].shape}" f"SDPA shape {expected_output.shape}"
) )
assert backend_output.dtype == sdpa_outputs[i].dtype, ( assert backend_output.dtype == expected_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != " f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_outputs[i].dtype}" f"SDPA dtype {expected_output.dtype}"
) )
assert torch.isfinite(backend_output).all(), ( assert torch.isfinite(backend_output).all(), (
...@@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol = 1e-2 rtol = 1e-2
atol = 5e-1 atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item() max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
max_rel_diff = torch.max( max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i]) torch.abs(backend_output - expected_output) / torch.abs(expected_output)
).item() ).item()
all_close = torch.allclose( all_close = torch.allclose(
backend_output, sdpa_outputs[i], rtol=rtol, atol=atol backend_output, expected_output, rtol=rtol, atol=atol
) )
assert all_close, ( assert all_close, (
......
...@@ -190,6 +190,7 @@ return curr_o @ W_O ...@@ -190,6 +190,7 @@ return curr_o @ W_O
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum
from typing import ClassVar, Generic, TypeVar from typing import ClassVar, Generic, TypeVar
import torch import torch
...@@ -227,6 +228,24 @@ from vllm.v1.attention.backends.utils import ( ...@@ -227,6 +228,24 @@ from vllm.v1.attention.backends.utils import (
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
class QueryLenSupport(Enum):
"""Defines the level of query length support for an attention backend's
decode pipeline.
- SINGLE_ONLY: Decode pipeline only supports single-token queries
(query_len=1)
- UNIFORM: Decode pipeline supports uniform multi-token queries
(all requests must have same query_len > 1)
- VARLEN: Decode pipeline supports variable-length queries
(mixed query lengths in same batch)
"""
SINGLE_ONLY = "single_only"
UNIFORM = "uniform"
VARLEN = "varlen"
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -460,19 +479,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -460,19 +479,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
understand this class understand this class
""" """
# Whether the backend supports reordering the batch such that # Defines the level of query length support for this backend.
# short sequences (i.e. verification for speculative decoding) are # - SINGLE_ONLY: Only single-token queries (no spec decode support)
# classified as decode requests. # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
# If True, this will increase `reorder_batch_threshold` (below) when # - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
# speculative decoding is enabled, and set `require_uniform=True` when # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
# when reordering the batch. Non-uniform decode requests will # speculative decoding is enabled.
# fall back to prefill in this case. query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY
supports_uniform_spec_as_decode: ClassVar[bool] = False
# The threshold for reordering the batch into decode and prefill requests. # The threshold for reordering the batch into decode and prefill requests.
# If > 1, the batch will be reordered such that requests with # If > 1, the batch will be reordered such that requests with
# query length <= threshold are classified as decode requests. # query length <= threshold are classified as decode requests.
# Use `supports_uniform_spec_as_decode` (above) to set this automatically # Use `query_len_support` (above) to set this automatically
# when speculative decoding is enabled. # when speculative decoding is enabled.
reorder_batch_threshold: int = 1 reorder_batch_threshold: int = 1
...@@ -599,11 +617,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -599,11 +617,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device=device, device=device,
) )
supports_spec_as_decode = self.supports_uniform_spec_as_decode supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
self._init_reorder_batch_threshold( self._init_reorder_batch_threshold(
self.reorder_batch_threshold, supports_spec_as_decode self.reorder_batch_threshold, supports_spec_decode
) )
# Validate consistency between query_len_support and reorder_batch_threshold
if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
assert self.reorder_batch_threshold == 1, (
f"reorder_batch_threshold must be 1 when query_len_support is "
f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
)
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc qo_indptr = prefill.query_start_loc
...@@ -745,7 +770,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -745,7 +770,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
split_decodes_and_prefills( split_decodes_and_prefills(
common_attn_metadata, common_attn_metadata,
decode_threshold=self.reorder_batch_threshold, decode_threshold=self.reorder_batch_threshold,
require_uniform=self.supports_uniform_spec_as_decode, require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
) )
) )
......
...@@ -24,6 +24,7 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -24,6 +24,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -66,8 +67,8 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): ...@@ -66,8 +67,8 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]): class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
reorder_batch_threshold: int = 512 reorder_batch_threshold: int = 512 # process small prefills with decode pathway
def __init__( def __init__(
self, self,
......
...@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -13,6 +13,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.attention.backends.utils import AttentionCGSupport
...@@ -22,11 +23,8 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 ...@@ -22,11 +23,8 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]): class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable spec-as-decode optimization
supports_uniform_spec_as_decode: ClassVar[bool] = True
# enable full CUDA Graph support for decode-only capture
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
class FlashInferMLABackend(MLACommonBackend): class FlashInferMLABackend(MLACommonBackend):
......
...@@ -20,8 +20,13 @@ from vllm.v1.attention.backends.mla.common import ( ...@@ -20,8 +20,13 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
MLACommonMetadataBuilder, MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,
) )
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -62,6 +67,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): ...@@ -62,6 +67,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
# ^ TODO(matt): tune this
def __init__( def __init__(
self, self,
...@@ -216,8 +224,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -216,8 +224,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat(q, dim=-1) q = torch.cat(q, dim=-1)
assert isinstance(q, torch.Tensor) assert isinstance(q, torch.Tensor)
num_decodes = attn_metadata.num_decodes
q = reshape_query_for_spec_decode(q, num_decodes)
o, lse = flash_mla_with_kvcache( o, lse = flash_mla_with_kvcache(
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode) q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table, block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
...@@ -230,4 +242,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -230,4 +242,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
descale_k=layer._k_scale.reshape(1), descale_k=layer._k_scale.reshape(1),
) )
o = reshape_attn_output_for_spec_decode(o)
return o, lse return o, lse
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment