Unverified Commit 28ef9ba3 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent fb7fdc49
...@@ -476,12 +476,12 @@ def test_set_inputs_first_pass_draft_model(): ...@@ -476,12 +476,12 @@ def test_set_inputs_first_pass_draft_model():
proposer.max_num_tokens, dtype=torch.bool, device=device proposer.max_num_tokens, dtype=torch.bool, device=device
) )
# Mock the attn_metadata_builder to avoid needing the full model setup # Mock draft_attn_groups to avoid needing the full model setup
mock_kv_cache_spec = mock.MagicMock() mock_kv_cache_spec = mock.MagicMock()
mock_kv_cache_spec.block_size = block_size mock_kv_cache_spec.block_size = block_size
mock_builder = mock.MagicMock() mock_attn_group = mock.MagicMock()
mock_builder.kv_cache_spec = mock_kv_cache_spec mock_attn_group.kv_cache_spec = mock_kv_cache_spec
proposer.attn_metadata_builder = mock_builder proposer.draft_attn_groups = [mock_attn_group]
# Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2 # Request 0: query_len=3 (but 1 rejected), Request 1: query_len=2
batch_spec = BatchSpec( batch_spec = BatchSpec(
...@@ -616,12 +616,12 @@ def test_set_inputs_first_pass_parallel_drafting(): ...@@ -616,12 +616,12 @@ def test_set_inputs_first_pass_parallel_drafting():
proposer.max_num_tokens, dtype=torch.bool, device=device proposer.max_num_tokens, dtype=torch.bool, device=device
) )
# Mock the attn_metadata_builder # Mock draft_attn_groups
mock_kv_cache_spec = mock.MagicMock() mock_kv_cache_spec = mock.MagicMock()
mock_kv_cache_spec.block_size = block_size mock_kv_cache_spec.block_size = block_size
mock_builder = mock.MagicMock() mock_attn_group = mock.MagicMock()
mock_builder.kv_cache_spec = mock_kv_cache_spec mock_attn_group.kv_cache_spec = mock_kv_cache_spec
proposer.attn_metadata_builder = mock_builder proposer.draft_attn_groups = [mock_attn_group]
# Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid) # Request 0: query_len=4 (1 rejected), Request 1: query_len=4 (all valid)
batch_spec = BatchSpec( batch_spec = BatchSpec(
...@@ -916,7 +916,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -916,7 +916,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
proposer.model = model_mock proposer.model = model_mock
# Assign draft attn_layer_names since load_model is not invoked # Assign draft attn_layer_names since load_model is not invoked
proposer.attn_layer_names = ["layer.0"] proposer._draft_attn_layer_names = {"layer.0"}
# Create input tensors # Create input tensors
batch_spec = BatchSpec( batch_spec = BatchSpec(
...@@ -961,20 +961,18 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): ...@@ -961,20 +961,18 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
attn_metadata_builder = attn_metadata_builder_cls( attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names, layer_names=proposer._draft_attn_layer_names,
vllm_config=proposer.vllm_config, vllm_config=proposer.vllm_config,
device=device, device=device,
) )
# Mock runner for attention metadata building # Mock runner and draft_attn_groups for attention metadata building
proposer.runner = mock.MagicMock() proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()]) mock_attn_group = mock.MagicMock()
proposer.runner.attn_groups[0][ mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
0 mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
].get_metadata_builder.return_value = attn_metadata_builder mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
proposer._get_attention_metadata_builder = mock.MagicMock( proposer.draft_attn_groups = [mock_attn_group]
return_value=attn_metadata_builder
)
result = proposer.propose( result = proposer.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
...@@ -1089,7 +1087,7 @@ def test_propose_tree(spec_token_tree): ...@@ -1089,7 +1087,7 @@ def test_propose_tree(spec_token_tree):
proposer.model = model_mock proposer.model = model_mock
# Assign draft attn_layer_names since load_model is not invoked # Assign draft attn_layer_names since load_model is not invoked
proposer.attn_layer_names = ["layer.0"] proposer._draft_attn_layer_names = {"layer.0"}
# Get the tree attention metadata builder. # Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = try_get_attention_backend( attn_metadata_builder_cls, _ = try_get_attention_backend(
...@@ -1097,21 +1095,18 @@ def test_propose_tree(spec_token_tree): ...@@ -1097,21 +1095,18 @@ def test_propose_tree(spec_token_tree):
) )
attn_metadata_builder = attn_metadata_builder_cls( attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names, layer_names=proposer._draft_attn_layer_names,
vllm_config=proposer.vllm_config, vllm_config=proposer.vllm_config,
device=device, device=device,
) )
# Mock runner for attention metadata building. # Mock runner and draft_attn_groups for attention metadata building.
proposer.runner = mock.MagicMock() proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()]) mock_attn_group = mock.MagicMock()
proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder] mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
proposer.runner.attn_groups[0][ mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
0 mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
].get_metadata_builder.return_value = attn_metadata_builder proposer.draft_attn_groups = [mock_attn_group]
proposer._get_attention_metadata_builder = mock.MagicMock(
return_value=attn_metadata_builder
)
# Setup inputs for the proposer. # Setup inputs for the proposer.
target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device)
......
...@@ -162,7 +162,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch): ...@@ -162,7 +162,7 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
model_mock.compute_logits.side_effect = logits_returns model_mock.compute_logits.side_effect = logits_returns
proposer.model = model_mock proposer.model = model_mock
proposer.attn_layer_names = ["layer.0"] proposer._draft_attn_layer_names = {"layer.0"}
# Prepare inputs # Prepare inputs
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
...@@ -190,13 +190,17 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch): ...@@ -190,13 +190,17 @@ def test_mtp_propose(num_speculative_tokens, monkeypatch):
attn_metadata_builder = attn_metadata_builder_cls( attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names, layer_names=list(proposer._draft_attn_layer_names),
vllm_config=proposer.vllm_config, vllm_config=proposer.vllm_config,
device=device, device=device,
) )
proposer.runner = mock.MagicMock() proposer.runner = mock.MagicMock()
proposer.attn_metadata_builder = attn_metadata_builder mock_attn_group = mock.MagicMock()
mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder
mock_attn_group.layer_names = list(proposer._draft_attn_layer_names)
mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec
proposer.draft_attn_groups = [mock_attn_group]
# Run propose # Run propose
result = proposer.propose( result = proposer.propose(
......
...@@ -79,6 +79,12 @@ def sparse_attn_indexer( ...@@ -79,6 +79,12 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
# out-of-bounds reads in the kernel.
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]
ops.indexer_k_quant_and_cache( ops.indexer_k_quant_and_cache(
k, k,
kv_cache, kv_cache,
......
...@@ -12,6 +12,7 @@ from vllm.utils.deep_gemm import ( ...@@ -12,6 +12,7 @@ from vllm.utils.deep_gemm import (
get_paged_mqa_logits_metadata, get_paged_mqa_logits_metadata,
is_deep_gemm_supported, is_deep_gemm_supported,
) )
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
...@@ -24,6 +25,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -24,6 +25,7 @@ from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills, split_decodes_and_prefills,
split_prefill_chunks, split_prefill_chunks,
) )
from vllm.v1.worker.cp_utils import get_total_cp_world_size
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -214,20 +216,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -214,20 +216,39 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
if self.vllm_config.speculative_config if self.vllm_config.speculative_config
else 0 else 0
) )
if self.num_speculative_tokens > 1:
raise ValueError(
"Sparse MLA only supports "
"num_speculative_tokens <= 1 because the DeepGEMM "
"fp8_paged_mqa_logits kernel does not support next_n > 2. "
f"Got num_speculative_tokens={self.num_speculative_tokens}."
)
self.reorder_batch_threshold += self.num_speculative_tokens self.reorder_batch_threshold += self.num_speculative_tokens
sm_count = num_compute_units(self.device.index) sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count self.num_sms = sm_count
self.decode_lens_buffer = torch.empty( self.decode_lens_buffer = torch.empty(
(scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device (scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
# Pre-allocated buffers for flattening (spec decode).
self.arange_buffer = torch.arange(
scheduler_config.max_num_seqs * (1 + self.num_speculative_tokens),
dtype=torch.int32,
device=self.device,
)
self.expanded_seq_lens_buffer = torch.zeros(
(scheduler_config.max_num_batched_tokens,),
dtype=torch.int32,
device=self.device,
)
max_num_blocks_per_req = cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size * get_total_cp_world_size(),
)
self.expanded_block_table_buffer = torch.zeros(
(
scheduler_config.max_num_batched_tokens,
max_num_blocks_per_req,
),
dtype=torch.int32,
device=self.device,
) )
# See: DeepGMM/csrc/apis/attention.hpp # See: DeepGMM/csrc/apis/attention.hpp
...@@ -326,42 +347,97 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -326,42 +347,97 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1] common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
) )
# Use CPU to avoid GPU sync; breaking async scheduling seq_lens = common_attn_metadata.seq_lens[:num_decodes]
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Decide which top-k kernel to use based on batch size and sequence length # Padded CUDA graph requests have block_table entries of -1.
batch_size = num_decodes # Clamp to 0 to prevent OOB access in the DeepGEMM kernel.
_is_large_context = common_attn_metadata.max_seq_len > 8192 # This is safe because padded requests have seq_lens=0, so the
# kernel produces no meaningful output for those rows.
block_table.clamp_(min=0)
# Decision logic based on micro-benchmark results: max_decode_len = int(decode_lens_cpu.max().item())
# - large_context_topk wins for batch <= 128 and seq_len > 8K if max_decode_len > 1:
# - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K # Flatten multi-token decode requests into single-token
use_large_context_topk = batch_size <= 128 and _is_large_context # batch entries, expanding seq_lens and block tables so
# the kernel always sees next_n=1.
# Assume 4 requests with seq_lens [10, 7, 12, 0] (the final req is
# padding) and decode_lens [3, 1, 4, 0] in the below example comments.
# The context lengths are therefore
# [10-3, 7-1, 12-4, 0-0] = [7, 6, 8, 0].
# 3 + 1 + 4 + 0 = 8
actual_expanded = int(decode_lens_cpu.sum().item())
# [7, 6, 8, 0] -> [7, 7, 7, 6, 8, 8, 8, 8]
expanded_base = torch.repeat_interleave(
seq_lens - decode_lens, decode_lens
)
# [0, 3, 4, 8] -> [0, 0, 0, 3, 4, 4, 4, 4]
expanded_starts = torch.repeat_interleave(
common_attn_metadata.query_start_loc[:num_decodes], decode_lens
)
# [0, 1, 2, 0, 0, 1, 2, 3]
positions_within = (
self.arange_buffer[:actual_expanded] - expanded_starts
)
# [8, 9, 10, 7, 9, 10, 11, 12, ...] where ... is unused buffer space
self.expanded_seq_lens_buffer[:actual_expanded] = (
expanded_base + positions_within + 1
)
self.expanded_seq_lens_buffer[actual_expanded:] = 0
seq_lens = self.expanded_seq_lens_buffer[:num_decode_tokens]
# Give each of the flattened entries the same block table row as the
# original request.
self.expanded_block_table_buffer[:actual_expanded] = (
torch.repeat_interleave(block_table, decode_lens, dim=0)
)
if actual_expanded < num_decode_tokens:
self.expanded_block_table_buffer[
actual_expanded:num_decode_tokens, 0
] = 0
block_table = self.expanded_block_table_buffer[:num_decode_tokens]
# All reqs now have decode_len=1
self.decode_lens_buffer[:num_decode_tokens] = 1
decode_lens = self.decode_lens_buffer[:num_decode_tokens]
offsets = None
batch_size = num_decode_tokens
else:
next_n = 1 + self.num_speculative_tokens next_n = 1 + self.num_speculative_tokens
if next_n > 1: if next_n > 1:
offsets = torch.arange(next_n, device=self.device, dtype=torch.int32) offsets = torch.arange(
next_n, device=self.device, dtype=torch.int32
)
else: else:
offsets = None offsets = None
batch_size = num_decodes
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
# DeepGEMM is required for the paged MQA logits on CUDA devices # DeepGEMM is required for the paged MQA logits on CUDA devices
if current_platform.is_cuda() and is_deep_gemm_supported(): if current_platform.is_cuda() and is_deep_gemm_supported():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms seq_lens,
self.kv_cache_spec.block_size,
self.num_sms,
) )
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]
# Padded CUDA graph requests have block_table entries of -1. # Decide which top-k kernel to use based on batch size and sequence length
# Clamp to 0 to prevent OOB access in the DeepGEMM kernel. # Decision logic based on micro-benchmark results:
# This is safe because padded requests have seq_lens=0, so the # - large_context_topk wins for batch <= 128 and seq_len > 8K
# kernel produces no meaningful output for those rows. # - top_k_per_row_decode wins for batch > 128 or seq_len <= 8K
block_table.clamp_(min=0) _is_large_context = common_attn_metadata.max_seq_len > 8192
use_large_context_topk = batch_size <= 128 and _is_large_context
decode_metadata = DeepSeekV32IndexerDecodeMetadata( decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table, block_table=block_table,
seq_lens=common_attn_metadata.seq_lens[:num_decodes], seq_lens=seq_lens,
decode_lens=decode_lens, decode_lens=decode_lens,
requires_padding=requires_padding, requires_padding=False,
schedule_metadata=self.scheduler_metadata_buffer, schedule_metadata=self.scheduler_metadata_buffer,
use_large_context_topk=use_large_context_topk, use_large_context_topk=use_large_context_topk,
offsets=offsets, offsets=offsets,
......
...@@ -20,17 +20,13 @@ from vllm.logger import init_logger ...@@ -20,17 +20,13 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import CommonAttentionMetadata
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import ( from vllm.v1.attention.backends.tree_attn import (
TreeAttentionMetadata, TreeAttentionMetadata,
...@@ -38,7 +34,7 @@ from vllm.v1.attention.backends.tree_attn import ( ...@@ -38,7 +34,7 @@ from vllm.v1.attention.backends.tree_attn import (
) )
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
...@@ -53,6 +49,7 @@ from vllm.v1.spec_decode.utils import ( ...@@ -53,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -113,10 +110,8 @@ class SpecDecodeBaseProposer: ...@@ -113,10 +110,8 @@ class SpecDecodeBaseProposer:
vllm_config.model_config vllm_config.model_config
) )
self.attn_metadata_builder: AttentionMetadataBuilder | None = None self.draft_attn_groups: list[AttentionGroup] = []
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None self.kv_cache_gid: int = -1
self.attn_layer_names: list[str] = []
self.indexer_layer_names: list[str] = []
self.eagle3_use_aux_hidden_state: bool = ( self.eagle3_use_aux_hidden_state: bool = (
self._get_eagle3_use_aux_hidden_state_from_config() self._get_eagle3_use_aux_hidden_state_from_config()
) )
...@@ -353,7 +348,7 @@ class SpecDecodeBaseProposer: ...@@ -353,7 +348,7 @@ class SpecDecodeBaseProposer:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID) self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens] view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names + self.indexer_layer_names} return {name: view for name in self._draft_attn_layer_names}
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle. """Initialize cudagraph dispatcher keys for eagle.
...@@ -420,34 +415,14 @@ class SpecDecodeBaseProposer: ...@@ -420,34 +415,14 @@ class SpecDecodeBaseProposer:
assert self.runner is not None assert self.runner is not None
if self.attn_metadata_builder is None: per_layer_attn_metadata: dict[str, object] = {}
attn_metadata_builder = self._get_attention_metadata_builder() for attn_group in self.draft_attn_groups:
else: attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
attn_metadata_builder = self.attn_metadata_builder
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0 common_attn_metadata=common_attn_metadata, draft_index=0
) )
# FIXME: support hybrid kv for draft model (remove separate indexer) for layer_name in attn_group.layer_names:
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
)
)
else:
draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens) self._determine_batch_execution_and_padding(num_tokens)
) )
...@@ -503,12 +478,7 @@ class SpecDecodeBaseProposer: ...@@ -503,12 +478,7 @@ class SpecDecodeBaseProposer:
positions = self.mrope_positions[:, token_indices_to_sample] positions = self.mrope_positions[:, token_indices_to_sample]
else: else:
positions = self.positions[token_indices_to_sample] positions = self.positions[token_indices_to_sample]
if self.method in ( if self.method == "mtp":
"deepseek_mtp",
"ernie_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
):
hidden_states = self.hidden_states[token_indices_to_sample] hidden_states = self.hidden_states[token_indices_to_sample]
else: else:
hidden_states = hidden_states[token_indices_to_sample] hidden_states = hidden_states[token_indices_to_sample]
...@@ -613,7 +583,8 @@ class SpecDecodeBaseProposer: ...@@ -613,7 +583,8 @@ class SpecDecodeBaseProposer:
common_attn_metadata._num_computed_tokens_cpu += 1 common_attn_metadata._num_computed_tokens_cpu += 1
# Compute the slot mapping. # Compute the slot mapping.
block_size = attn_metadata_builder.kv_cache_spec.block_size # Use the first draft attention group's kv_cache_spec for block_size
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
if self.uses_mrope: if self.uses_mrope:
# all dimensions of positions are the same # all dimensions of positions are the same
block_numbers = clamped_positions[0] // block_size block_numbers = clamped_positions[0] // block_size
...@@ -639,10 +610,12 @@ class SpecDecodeBaseProposer: ...@@ -639,10 +610,12 @@ class SpecDecodeBaseProposer:
) )
# Rebuild attention metadata # Rebuild attention metadata
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore for attn_group in self.draft_attn_groups:
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
) )
for layer_name in self.attn_layer_names: for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
...@@ -805,18 +778,17 @@ class SpecDecodeBaseProposer: ...@@ -805,18 +778,17 @@ class SpecDecodeBaseProposer:
# 2. # 2.
# Recompute the slot mapping based on the new positions and # Recompute the slot mapping based on the new positions and
# rejection mask. # rejection mask.
builder = ( # Use the first draft attention group's kv_cache_spec for block_size
self._get_attention_metadata_builder() # (all draft layers share the same kv-cache group)
if self.attn_metadata_builder is None assert len(self.draft_attn_groups) > 0
else self.attn_metadata_builder block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
)
new_slot_mapping = compute_new_slot_mapping( new_slot_mapping = compute_new_slot_mapping(
cad=cad, cad=cad,
new_positions=self.positions[:total_num_output_tokens], new_positions=self.positions[:total_num_output_tokens],
is_rejected_token_mask=self.is_rejected_token_mask[ is_rejected_token_mask=self.is_rejected_token_mask[
:total_num_output_tokens :total_num_output_tokens
], ],
block_size=builder.kv_cache_spec.block_size, block_size=block_size,
num_new_tokens=self.net_num_new_slots_per_request, num_new_tokens=self.net_num_new_slots_per_request,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
) )
...@@ -1000,9 +972,7 @@ class SpecDecodeBaseProposer: ...@@ -1000,9 +972,7 @@ class SpecDecodeBaseProposer:
| list[dict[str, torch.Tensor]] | list[dict[str, torch.Tensor]]
| None = None, | None = None,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_groups[0][ tree_attn_metadata_builder = self.draft_attn_groups[0].get_metadata_builder()
0
].get_metadata_builder()
assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder)
total_num_drafts = self.cu_drafts_per_level[0] total_num_drafts = self.cu_drafts_per_level[0]
...@@ -1078,9 +1048,10 @@ class SpecDecodeBaseProposer: ...@@ -1078,9 +1048,10 @@ class SpecDecodeBaseProposer:
common_attn_metadata=common_attn_metadata, draft_index=level + 1 common_attn_metadata=common_attn_metadata, draft_index=level + 1
) )
# Apply new attention metadata to all layers. # Apply new attention metadata to all draft layers.
per_layer_attn_metadata = {} per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names: for attn_group in self.draft_attn_groups:
for layer_name in attn_group.layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
# Consider max model length. # Consider max model length.
...@@ -1288,43 +1259,17 @@ class SpecDecodeBaseProposer: ...@@ -1288,43 +1259,17 @@ class SpecDecodeBaseProposer:
AttentionLayerBase, # type: ignore[type-abstract] AttentionLayerBase, # type: ignore[type-abstract]
).keys() ).keys()
) )
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
).keys()
)
self.model = self._get_model() self.model = self._get_model()
draft_attn_layer_names = ( # Find draft layers (attention layers added by draft model)
get_layers_from_vllm_config( all_attn_layers = get_layers_from_vllm_config(
self.vllm_config, self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract] AttentionLayerBase, # type: ignore[type-abstract]
).keys()
- target_attn_layer_names
)
indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
)
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
self.indexer_layer_names = list(draft_indexer_layer_names)
if self.indexer_layer_names:
first_layer = self.indexer_layer_names[0]
self.draft_indexer_metadata_builder = (
indexer_layers[first_layer]
.get_attn_backend()
.get_builder_cls()(
indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
self.indexer_layer_names,
self.vllm_config,
self.device,
) )
self._draft_attn_layer_names = (
set(all_attn_layers.keys()) - target_attn_layer_names
) )
else:
self.draft_indexer_metadata_builder = None
if self.supports_mm_inputs: if self.supports_mm_inputs:
# Even if the target model is multimodal, we can also use # Even if the target model is multimodal, we can also use
...@@ -1562,9 +1507,9 @@ class SpecDecodeBaseProposer: ...@@ -1562,9 +1507,9 @@ class SpecDecodeBaseProposer:
# Make sure to use EAGLE's own buffer during cudagraph capture. # Make sure to use EAGLE's own buffer during cudagraph capture.
if ( if (
self.attn_layer_names self._draft_attn_layer_names
and slot_mappings is not None and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings and next(iter(self._draft_attn_layer_names)) in slot_mappings
): ):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens) slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else: else:
...@@ -1594,31 +1539,6 @@ class SpecDecodeBaseProposer: ...@@ -1594,31 +1539,6 @@ class SpecDecodeBaseProposer:
kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
self.model(**kwargs) self.model(**kwargs)
def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
"""Find and return the attention metadata builders for EAGLE layers.
Returns:
The metadata builders for EAGLE layers.
Raises:
AssertionError: If no metadata builders are found for EAGLE layers.
"""
builder = None
chosen_layer = self.attn_layer_names[0]
for kv_cache_group in self.runner.attn_groups:
for attn_group in kv_cache_group:
if chosen_layer in attn_group.layer_names:
builder = attn_group.get_metadata_builder()
break
if builder is not None:
break
assert builder is not None, (
"Failed to find attention metadata builder for EAGLE layers."
)
return builder
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool: def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
""" """
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
...@@ -1651,13 +1571,71 @@ class SpecDecodeBaseProposer: ...@@ -1651,13 +1571,71 @@ class SpecDecodeBaseProposer:
set( set(
[ [
kv_cache_groups[layer_name] kv_cache_groups[layer_name]
for layer_name in self.attn_layer_names for layer_name in self._draft_attn_layer_names
] ]
) )
) )
== 1 == 1
), "All drafting layers should belong to the same kv cache group" ), "All drafting layers should belong to the same kv cache group"
def initialize_attn_backend(
self,
kv_cache_config: KVCacheConfig,
kernel_block_sizes: list[int] | None = None,
) -> None:
"""
Initialize AttentionGroups for draft layers using kv_cache_config.
Called from the model runner's initialize_metadata_builders.
"""
all_attn_layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
# Find which kv_cache_group the draft layers belong to
self.validate_same_kv_cache_group(kv_cache_config)
kv_cache_spec = None
for gid, group in enumerate(kv_cache_config.kv_cache_groups):
if self._draft_attn_layer_names & set(group.layer_names):
self.kv_cache_gid = gid
kv_cache_spec = group.kv_cache_spec
break
attention_groups: dict[tuple[str, str], AttentionGroup] = {}
if kv_cache_spec is not None:
for layer_name in self._draft_attn_layer_names:
attn_backend = all_attn_layers[layer_name].get_attn_backend()
backend_key = attn_backend.full_cls_name()
if backend_key not in attention_groups:
layer_kv_cache_spec = kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
layer_name
]
kernel_block_size = (
kernel_block_sizes[self.kv_cache_gid]
if kernel_block_sizes is not None
and self.kv_cache_gid < len(kernel_block_sizes)
else None
)
attn_group = AttentionGroup(
backend=attn_backend,
layer_names=[layer_name],
kv_cache_spec=layer_kv_cache_spec,
kv_cache_group_id=self.kv_cache_gid,
)
attn_group.create_metadata_builders(
self.vllm_config,
self.device,
kernel_block_size=kernel_block_size,
)
attention_groups[backend_key] = attn_group
else:
attention_groups[backend_key].layer_names.append(layer_name)
self.draft_attn_groups = list(attention_groups.values())
def _determine_batch_execution_and_padding( def _determine_batch_execution_and_padding(
self, self,
num_tokens: int, num_tokens: int,
......
...@@ -1936,7 +1936,7 @@ class GPUModelRunner( ...@@ -1936,7 +1936,7 @@ class GPUModelRunner(
if self.speculative_config and spec_decode_common_attn_metadata is None: if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: if self.drafter.kv_cache_gid == kv_cache_gid:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
else: else:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
...@@ -5559,6 +5559,14 @@ class GPUModelRunner( ...@@ -5559,6 +5559,14 @@ class GPUModelRunner(
# because some of them change the threshold at init time. # because some of them change the threshold at init time.
self.calculate_reorder_batch_threshold() self.calculate_reorder_batch_threshold()
# Initialize drafter attention backend
if self.speculative_config and (
self.speculative_config.use_eagle()
or self.speculative_config.uses_draft_model()
):
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes)
def _check_and_update_cudagraph_mode( def _check_and_update_cudagraph_mode(
self, self,
attention_backends: list[set[type[AttentionBackend]]], attention_backends: list[set[type[AttentionBackend]]],
...@@ -6079,15 +6087,11 @@ class GPUModelRunner( ...@@ -6079,15 +6087,11 @@ class GPUModelRunner(
kv_cache_config, kernel_block_sizes kv_cache_config, kernel_block_sizes
) )
if self.speculative_config and ( if (
self.speculative_config.use_eagle() self.speculative_config
or self.speculative_config.uses_draft_model() and self.speculative_config.uses_extract_hidden_states()
or self.speculative_config.uses_extract_hidden_states()
): ):
assert isinstance( assert isinstance(self.drafter, ExtractHiddenStatesProposer)
self.drafter,
EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer,
)
# validate all draft model layers belong to the same kv cache # validate all draft model layers belong to the same kv cache
# group # group
self.drafter.validate_same_kv_cache_group(kv_cache_config) self.drafter.validate_same_kv_cache_group(kv_cache_config)
......
...@@ -48,7 +48,7 @@ class AttentionGroup: ...@@ -48,7 +48,7 @@ class AttentionGroup:
self, self,
vllm_config, vllm_config,
device, device,
kernel_block_size: int | None, kernel_block_size: int | None = None,
num_metadata_builders: int = 1, num_metadata_builders: int = 1,
): ):
kv_cache_spec_builder = ( kv_cache_spec_builder = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment