Commit fcc9c9ea authored by luopl's avatar luopl
Browse files

feat:新增step3.5-mtp3功能

parent 9dc40d38
......@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
......@@ -76,6 +77,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
enable_multi_layers_mtp: bool = False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
......@@ -110,6 +115,11 @@ class SpeculativeConfig:
which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation."""
use_local_argmax_reduction: bool = False
"""Use vocab-parallel local argmax instead of all-gathering full logits
for draft token generation. Reduces communication from O(vocab_size) to
O(2 * tp_size) per token. Only applies to greedy draft selection in
non-tree speculation."""
# Ngram proposer configuration
prompt_lookup_max: int | None = Field(default=None, ge=1)
"""Maximum size of ngram token window when using Ngram proposer, required
......@@ -121,6 +131,12 @@ class SpeculativeConfig:
speculative_token_tree: str | None = None
"""Specifies the tree structure for speculative token generation.
"""
parallel_drafting: bool = False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""
......@@ -154,6 +170,10 @@ class SpeculativeConfig:
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""
draft_load_config: LoadConfig | None = None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......@@ -401,7 +421,11 @@ class SpeculativeConfig:
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
# if self.num_speculative_tokens > 1:
if (
self.enable_multi_layers_mtp is False
and self.num_speculative_tokens > 1
):
logger.warning(
"Enabling num_speculative_tokens > 1 will run"
"multiple times of forward on same MTP layer"
......@@ -472,6 +496,17 @@ class SpeculativeConfig:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif (
self.method == "mtp"
and self.enable_multi_layers_mtp
and self.num_speculative_tokens > n_predict
):
logger.warning_once(
"For multi_layer_eagle, num_speculative_tokens "
"is greater than the layer_num, adjusting to "
"layer_num"
)
self.num_speculative_tokens = n_predict
elif (
self.num_speculative_tokens > n_predict
and self.num_speculative_tokens % n_predict != 0
......@@ -713,12 +748,31 @@ class SpeculativeConfig:
f"errors during speculative decoding."
)
@property
def max_num_new_slots_for_drafting(self) -> int:
    """
    Return the largest number of extra slots one request can add to the
    batch while drafting.

    Parallel drafting contributes ``num_speculative_tokens - 1`` slots, and
    draft-model speculation contributes one more. Serial methods without a
    draft model contribute nothing.
    """
    # Parallel drafting needs one new slot per 'masked' token.
    extra_slots = self.num_speculative_tokens - 1 if self.parallel_drafting else 0
    # Draft-model speculation needs one more slot per request because the
    # draft tokens are not sliced.
    if self.uses_draft_model():
        extra_slots += 1
    return extra_slots
def use_eagle(self) -> bool:
    """Return True when the method is "eagle", "eagle3" or "mtp"."""
    eagle_family = ("eagle", "eagle3", "mtp")
    return self.method in eagle_family
def uses_draft_model(self) -> bool:
    """Return True when the configured method is exactly "draft_model"."""
    selected_method = self.method
    return selected_method == "draft_model"
def uses_extract_hidden_states(self) -> bool:
    """Return True when the configured method is exactly "extract_hidden_states"."""
    selected_method = self.method
    return selected_method == "extract_hidden_states"
def __repr__(self) -> str:
method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
......
......@@ -160,3 +160,32 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context):
    """
    Fill in a fallback ``id`` when none was supplied.

    A falsy ``id`` is replaced by ``msg_<milliseconds since the epoch>``;
    any existing truthy ``id`` is left untouched.
    """
    if self.id:
        return
    epoch_millis = int(time.time() * 1000)
    self.id = f"msg_{epoch_millis}"
class AnthropicContextManagement(BaseModel):
"""Context management details reported alongside a token count."""
# Required integer field; presumably the input token count before any
# context management is applied — TODO confirm against the API reference.
original_input_tokens: int
class AnthropicCountTokensRequest(BaseModel):
    """Request body for the Anthropic ``messages.count_tokens`` call."""

    model: str
    messages: list[AnthropicMessage]
    system: str | list[AnthropicContentBlock] | None = None
    tool_choice: AnthropicToolChoice | None = None
    tools: list[AnthropicTool] | None = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, v):
        """Reject a falsy ``model`` value; return any other value unchanged."""
        if v:
            return v
        raise ValueError("Model is required")
class AnthropicCountTokensResponse(BaseModel):
"""Response body for the Anthropic ``messages.count_tokens`` call."""
# Required integer field; presumably the token count of the request's
# input — TODO confirm against the API reference.
input_tokens: int
# Optional context-management details; defaults to None when not provided.
context_management: AnthropicContextManagement | None = None
\ No newline at end of file
This diff is collapsed.
......@@ -1239,10 +1239,13 @@ class OpenAIServingChat(OpenAIServing):
index = 0
if (
self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output
# self._should_check_for_unstreamed_tool_arg_tokens(
# delta_message, output
tool_parser
and self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output, tool_parser
)
and tool_parser
# and tool_parser
):
latest_delta_len = 0
if (
......@@ -1256,15 +1259,31 @@ class OpenAIServingChat(OpenAIServing):
latest_delta_len = len(
delta_message.tool_calls[0].function.arguments
)
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get(
# parsing which "autocompletes" the JSON.
# Tool parsers (e.g. Qwen3Coder) store
# arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# expected_call = json.dumps(
# tool_parser.prev_tool_call_arr[index].get(
# "arguments", {}
# ),
# ensure_ascii=False,
# )
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
),
ensure_ascii=False,
)
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments
# for the current tool
......@@ -1848,6 +1867,7 @@ class OpenAIServingChat(OpenAIServing):
self,
delta_message: DeltaMessage | None,
output: CompletionOutput,
tool_parser: ToolParser | None = None,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
......@@ -1866,6 +1886,8 @@ class OpenAIServingChat(OpenAIServing):
and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
and tool_parser is not None
and tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens()
)
@staticmethod
......
......@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple):
"""
Whether this batch has active LoRA adapters.
"""
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
are captured for each num_active_loras value. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
......
......@@ -44,6 +44,23 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
    """
    Compute the num_active_loras values to capture CUDA graphs for.

    With ``specialize=False`` the only case is ``max_loras + 1``.
    With ``specialize=True`` the result is every power of two below
    ``max_loras + 1`` followed by ``max_loras + 1`` itself, in ascending
    order and without duplicates.

    Intended as the single source of truth for LoRA capture cases, shared
    by CudagraphDispatcher and PunicaWrapperGPU.
    """
    upper = max_loras + 1
    if not specialize:
        return [upper]

    counts: list[int] = []
    power = 1
    # Collect the powers of two strictly below the upper bound ...
    while power < upper:
        counts.append(power)
        power *= 2
    # ... then close with the upper bound itself, provided it is a positive
    # count (a non-positive bound yields an empty list).
    if upper >= 1:
        counts.append(upper)
    return counts
_GLOBAL_LORA_ID = 0
......
......@@ -1028,6 +1028,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
shared_output=None,
routed_scaling_factor=None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None
assert not self.is_monolithic
......
......@@ -52,7 +52,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.lm_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer(
vllm_config,
prefix=f"{prefix}.mtp_block",
......@@ -64,9 +65,13 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
embed_tokens: VocabParallelEmbedding | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
if inputs_embeds is None:
assert embed_tokens is not None
inputs_embeds = embed_tokens(input_ids)
# assert inputs_embeds is not None
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
......@@ -92,8 +97,10 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict(
{
str(idx): Step3p5AMultiTokenPredictorLayer(
vllm_config,
f"{prefix}.layers.{idx}",
# vllm_config,
# f"{prefix}.layers.{idx}",
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(
self.mtp_start_layer_idx,
......@@ -112,14 +119,15 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
inputs_embeds,
self.embed_tokens,
current_step_idx,
)
......@@ -131,7 +139,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
# mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
)
return logits
......@@ -257,6 +266,7 @@ class Step3p5MTP(nn.Module):
name = name.replace(".transformer.", ".")
if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head")
name = name.replace("shared_head", "lm_head")
if "embed_tokens" in name:
assert (
hasattr(self.config, "num_nextn_predict_layers")
......
......@@ -118,6 +118,11 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
    """
    Tell the serving layer whether to check for unstreamed tool-argument
    tokens. This base implementation always answers True.
    """
    return True
class ToolParserManager:
"""
......
This diff is collapsed.
......@@ -7,6 +7,7 @@ import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload
from vllm import envs
......@@ -947,6 +948,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def _get_kv_cache_groups_uniform_page_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
"""
......@@ -1007,6 +1009,7 @@ def _get_kv_cache_groups_uniform_page_size(
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
......@@ -1030,9 +1033,9 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
min_num_layers = min([len(layers) for layers in same_type_layers.values()])
min_num_layers = min([len(layers) for layers in same_type_layers.values()]) #12
group_size = min_num_layers
max_num_layers = max([len(layers) for layers in same_type_layers.values()])
max_num_layers = max([len(layers) for layers in same_type_layers.values()]) #36
if max_num_layers < min_num_layers * 1.25:
# If the number of layers is not much larger than the minimum number of layers,
# use the maximum number of layers as the group size to avoid too many padding
......@@ -1050,6 +1053,15 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers / len(layers) * 100,
)
num_groups = cdiv(len(layers), group_size)
# To support multi-layer MTP, split the layers into consecutive chunks of
# group_size so that all MTP layers end up in the same group.
if (
vllm_config.speculative_config is not None
and vllm_config.speculative_config.enable_multi_layers_mtp
):
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i : i + group_size])
else:
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
......@@ -1120,7 +1132,6 @@ def get_kv_cache_config_from_groups(
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
......@@ -1247,8 +1258,10 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
# return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
return _get_kv_cache_groups_uniform_page_size(
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
)
def generate_scheduler_kv_cache_config(
kv_cache_configs: list[KVCacheConfig],
......@@ -1451,6 +1464,42 @@ def _auto_fit_max_model_len(
)
def _project_kv_cache_groups_to_worker(
    global_kv_cache_groups: list[KVCacheGroupSpec],
    worker_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
    """
    Restrict the global KV cache groups to the layers owned by one worker.

    Under pipeline parallelism a worker holds only a subset of the layers.
    Each global group is narrowed to the layer names found in
    ``worker_spec``. A group whose spec is a UniformTypeKVCacheSpecs is
    rebuilt so that its per-layer specs cover only those layers.

    Args:
        global_kv_cache_groups: The KV cache groups for the whole model.
        worker_spec: The KV cache spec of each layer on this worker.

    Returns:
        One projected group per input group, in the same order. A group
        with no layers on this worker is kept with an empty layer list and
        its original spec.
    """

    def _project(group: KVCacheGroupSpec) -> KVCacheGroupSpec:
        owned_layers = [name for name in group.layer_names if name in worker_spec]
        spec = group.kv_cache_spec
        # Only rebuild the uniform-type spec when this worker owns layers.
        if owned_layers and isinstance(spec, UniformTypeKVCacheSpecs):
            spec = UniformTypeKVCacheSpecs(
                block_size=spec.block_size,
                kv_cache_specs={
                    name: spec.kv_cache_specs[name] for name in owned_layers
                },
            )
        return KVCacheGroupSpec(owned_layers, spec)

    return [_project(group) for group in global_kv_cache_groups]
def get_kv_cache_configs(
vllm_config: VllmConfig,
kv_cache_specs: list[dict[str, KVCacheSpec]],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set as AbstractSet
from dataclasses import replace
from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
from vllm.lora.utils import get_captured_lora_counts
logger = init_logger(__name__)
......@@ -57,6 +61,11 @@ class CudagraphDispatcher:
)
self.keys_initialized = False
self.specialize_lora_count = (
self.vllm_config.lora_config.specialize_active_lora
if self.vllm_config.lora_config is not None
else False
)
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE
......@@ -64,6 +73,9 @@ class CudagraphDispatcher:
"""Pre-compute the mapping from batch size to padded graph size."""
max_size = self.compilation_config.max_cudagraph_capture_size
capture_sizes = self.compilation_config.cudagraph_capture_sizes
assert capture_sizes is not None, (
"Cudagraph capture sizes must be set when cudagraphs are enabled."
)
self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
for end, start in zip(
capture_sizes + [max_size + 1],
......@@ -92,8 +104,33 @@ class CudagraphDispatcher:
"Use values from cudagraph_capture_sizes."
)
def _get_lora_cases(self) -> list[int]:
"""
Returns list of has_lora values for CUDA graph capture.
This is the single source of truth for LoRA capture cases.
"""
lora_config = self.vllm_config.lora_config
if lora_config is None:
# No LoRA configured - single case with no LoRA
return [0]
# LoRA is enabled - capture graphs based on cudagraph_specialize_lora
if self.compilation_config.cudagraph_specialize_lora:
captured_counts = get_captured_lora_counts(
lora_config.max_loras, self.specialize_lora_count
)
# Specialize: capture separate graphs for with and without LoRA
return [0] + captured_counts
else:
# No specialization: only capture graphs with LoRA active
return [lora_config.max_loras + 1]
def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
......@@ -111,6 +148,7 @@ class CudagraphDispatcher:
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
num_active_loras=num_active_loras,
)
def add_cudagraph_key(
......@@ -143,18 +181,27 @@ class CudagraphDispatcher:
lora_cases = [True]
else:
lora_cases = [False]
# Get LoRA cases to capture
# lora_cases = self._get_lora_cases()
self.captured_lora_counts = [
lora_count for lora_count in lora_cases if lora_count
]
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs, has_lora in product(
assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when mixed mode is enabled."
)
for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor(
bs, False, has_lora
bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(),
)
......@@ -168,15 +215,20 @@ class CudagraphDispatcher:
uniform_decode_query_len
* self.vllm_config.scheduler_config.max_num_seqs
)
assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when full mode is enabled."
)
cudagraph_capture_sizes_for_decode = [
x
for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
for bs, num_active_loras in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(bs, True, has_lora),
self._create_padded_batch_descriptor(
bs, True, num_active_loras > 0, num_active_loras
),
)
self.keys_initialized = True
......@@ -199,14 +251,19 @@ class CudagraphDispatcher:
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
valid_modes: Set of cudagraph modes that are allowed. None means
all modes are allowed.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
"""
# allowed_modes = valid_modes or CUDAGraphMode.valid_runtime_modes()
if (
not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size
# or allowed_modes <= {CUDAGraphMode.NONE}
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
PADDING_SLOT_ID = -1
class ExtractHiddenStatesProposer:
def __init__(self, vllm_config: VllmConfig, device):
    """Validate the speculative config and allocate the proposer's buffers.

    Args:
        vllm_config: Global config. Its speculative config must be set and
            must use exactly one speculative token.
        device: Device on which the persistent buffers are allocated.

    Raises:
        ValueError: If ``disable_padded_drafter_batch`` is set, or if the
            draft model config has no ``eagle_aux_hidden_state_layer_ids``.
    """
    spec_config = vllm_config.speculative_config
    assert spec_config is not None
    assert spec_config.num_speculative_tokens == 1
    if spec_config.disable_padded_drafter_batch:
        raise ValueError(
            "disable_padded_drafter_batch is not supported with "
            "extract_hidden_states method"
        )

    self.vllm_config = vllm_config
    self.device = device
    self.dtype = vllm_config.model_config.dtype
    self.dp_rank = vllm_config.parallel_config.data_parallel_rank

    # Model and attention layer tracking (initialized in load_model)
    self.model: nn.Module | None = None
    self.attn_layer_names: list[str] = []
    self.attn_metadata_builder: AttentionMetadataBuilder | None = None

    # Buffers hold max_num_batched_tokens plus one entry per sequence.
    scheduler_config = vllm_config.scheduler_config
    max_batch_size = scheduler_config.max_num_seqs
    self.max_num_tokens = scheduler_config.max_num_batched_tokens + max_batch_size

    self.hf_config = spec_config.draft_model_config.hf_config
    aux_layer_ids = getattr(
        self.hf_config, "eagle_aux_hidden_state_layer_ids", None
    )
    if not aux_layer_ids:
        raise ValueError(
            "eagle_aux_hidden_state_layer_ids must be set in the draft "
            "model config for extract_hidden_states method"
        )
    self.num_hidden_states = len(aux_layer_ids)
    self.hidden_size = vllm_config.model_config.get_hidden_size()

    # Persistent input buffer: one row of hidden states per aux layer and
    # per token.
    buffer_shape = (self.max_num_tokens, self.num_hidden_states, self.hidden_size)
    self.hidden_states = torch.zeros(buffer_shape, dtype=self.dtype, device=device)

    self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

    self._slot_mapping_buffer = torch.zeros(
        self.max_num_tokens, dtype=torch.int64, device=device
    )
def propose(
    self,
    sampled_token_ids: torch.Tensor,
    target_hidden_states: list[torch.Tensor],
    common_attn_metadata: CommonAttentionMetadata,
    scheduler_output: SchedulerOutput,
    slot_mappings: dict[str, torch.Tensor]
    | list[dict[str, torch.Tensor]]
    | None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
    """Run the ExtractHiddenStatesModel and echo the sampled tokens as drafts.

    The ExtractHiddenStatesModel stores the hidden states in the KV cache
    without computing attention, so they can be used later (e.g. for KV
    transfer). No real speculation happens: the sampled tokens are returned
    as the "draft" tokens, so they always verify.

    Args:
        sampled_token_ids: Token IDs sampled from the target model.
        target_hidden_states: Hidden state tensors from the target model,
            one per aux hidden state layer.
        common_attn_metadata: Attention metadata.
        scheduler_output: Scheduler output, passed to the KV connector.
        slot_mappings: Not read by this method; accepted for interface
            compatibility.

    Returns:
        Tuple of:
            - ``sampled_token_ids`` with a trailing dimension of size 1.
            - The KV connector output if a KV transfer group is active,
              else None.
    """
    assert self.model is not None and isinstance(target_hidden_states, list)

    # Each per-layer tensor is [num_tokens, hidden_size]; stacking on dim 1
    # gives [num_tokens, num_hidden_states, hidden_size].
    stacked = torch.stack(target_hidden_states, dim=1)
    num_tokens = stacked.shape[0]

    # Stage the hidden states in the persistent buffer.
    self.hidden_states[:num_tokens] = stacked

    assert self.attn_metadata_builder is not None
    attn_metadata = self.attn_metadata_builder.build_for_drafting(
        common_attn_metadata=common_attn_metadata, draft_index=0
    )

    # All cache-only layers are assumed to belong to the same KV cache
    # group, so they share a single attention metadata object.
    per_layer_attn_metadata = dict.fromkeys(self.attn_layer_names, attn_metadata)

    cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
        self._determine_batch_execution_and_padding(num_tokens)
    )
    if num_tokens_across_dp is not None:
        num_tokens_across_dp[self.dp_rank] = num_input_tokens

    forward_slot_mapping = self._get_slot_mapping(
        num_input_tokens, common_attn_metadata.slot_mapping
    )

    with (
        set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=forward_slot_mapping,
        ),
        (
            KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
            if has_kv_transfer_group()
            else nullcontext()
        ) as kv_connector_output,
    ):
        self.model(
            hidden_states=self.hidden_states[:num_input_tokens],
        )

    # The sampled tokens double as the "draft" tokens; the trailing
    # dimension of size 1 matches num_speculative_tokens=1.
    draft_token_ids = sampled_token_ids.unsqueeze(-1)
    return draft_token_ids, kv_connector_output
def _get_slot_mapping(
self,
num_tokens: int,
slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return slot_mapping dict for cache-only attention layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if slot_mapping is not None:
num_actual = slot_mapping.shape[0]
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
if num_tokens > num_actual:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names}
def _determine_batch_execution_and_padding(
    self,
    num_tokens: int,
    use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
    """Pick the cudagraph mode and padded token count for a batch.

    Args:
        num_tokens: Unpadded number of tokens in the batch.
        use_cudagraphs: When False, dispatch is restricted to
            ``CUDAGraphMode.NONE``.

    Returns:
        Tuple of the cudagraph mode, the padded token count, and the
        per-rank token counts (None unless data_parallel_size > 1 and the
        coordination step returns them).
    """
    allowed_modes = None if use_cudagraphs else {CUDAGraphMode.NONE}
    cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
        num_tokens,
        valid_modes=allowed_modes,
    )
    num_tokens_padded = batch_desc.num_tokens

    # Without data parallelism there is nothing to coordinate.
    if self.vllm_config.parallel_config.data_parallel_size <= 1:
        return cudagraph_mode, num_tokens_padded, None

    # Extra coordination when running data-parallel since we need to
    # coordinate across ranks
    # TODO(Flechman): support DBO ubatching
    should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
        coordinate_batch_across_dp(
            num_tokens_unpadded=num_tokens,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=False,
            num_tokens_padded=num_tokens_padded,
            cudagraph_mode=cudagraph_mode.value,
        )
    )
    assert not should_ubatch, (
        "DBO ubatching not implemented for extract_hidden_states"
    )

    if num_tokens_across_dp is not None:
        rank = self.dp_rank
        # Adopt the token count agreed upon across DP ranks.
        num_tokens_padded = int(num_tokens_across_dp[rank].item())
        # Dispatch again with the DP padding so the batch descriptor
        # matches the synced mode and size.
        cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
            num_tokens_padded,
            valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
        )
        # The agreed token count must survive the re-dispatch, otherwise
        # num_tokens_across_dp would no longer be valid.
        assert batch_desc.num_tokens == num_tokens_padded
        num_tokens_across_dp[rank] = num_tokens_padded

    return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
    """Initialize the cudagraph dispatcher keys for this proposer.

    The proposer runs with PIECEWISE cudagraphs when the speculative config
    does not enforce eager mode and the mixed mode of ``cudagraph_mode`` is
    PIECEWISE or FULL. Otherwise it runs without cudagraphs.

    Should be called after adjust_cudagraph_sizes_for_spec_decode.
    """
    spec_config = self.vllm_config.speculative_config
    assert spec_config is not None

    graphs_allowed = not spec_config.enforce_eager and (
        cudagraph_mode.mixed_mode()
        in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
    )
    proposer_cudagraph_mode = (
        CUDAGraphMode.PIECEWISE if graphs_allowed else CUDAGraphMode.NONE
    )

    self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)
@torch.inference_mode()
def dummy_run(
    self,
    num_tokens: int,
    use_cudagraphs: bool = True,
    is_graph_capturing: bool = False,
    slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
    """Run the model once on the persistent hidden-state buffer.

    Args:
        num_tokens: Unpadded number of tokens for the dummy batch.
        use_cudagraphs: Forwarded to the batch/padding decision.
        is_graph_capturing: Not read by this method.
        slot_mappings: Optional slot mappings keyed by layer name. If the
            first attention layer of this proposer appears in it, the
            proposer's own slot mapping buffer is used instead.
    """
    assert self.model is not None, "Model must be initialized before dummy_run"

    cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
        self._determine_batch_execution_and_padding(
            num_tokens, use_cudagraphs=use_cudagraphs
        )
    )
    if num_tokens_across_dp is not None:
        num_tokens_across_dp[self.dp_rank] = num_input_tokens

    # Use our own slot mapping buffer during cudagraph capture.
    covers_own_layers = bool(
        self.attn_layer_names
        and slot_mappings is not None
        and self.attn_layer_names[0] in slot_mappings
    )
    slot_mapping_dict = (
        self._get_slot_mapping(num_input_tokens)
        if covers_own_layers
        else (slot_mappings or {})
    )

    forward_context = set_forward_context(
        None,
        self.vllm_config,
        num_tokens=num_input_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        slot_mapping=slot_mapping_dict,
    )
    with forward_context:
        self.model(
            hidden_states=self.hidden_states[:num_input_tokens],
        )
def _build_attn_metadata_builder(
self, draft_attn_layers: dict[str, AttentionLayerBase]
) -> AttentionMetadataBuilder:
"""Build the attention metadata builder from draft attention layers."""
if not draft_attn_layers:
raise ValueError("No attention layers found for ExtractHiddenStatesModel")
layer = next(iter(draft_attn_layers.values()))
attn_backend = layer.get_attn_backend()
return attn_backend.get_builder_cls()(
layer.get_kv_cache_spec(self.vllm_config),
self.attn_layer_names,
self.vllm_config,
self.device,
)
def prepare_next_token_ids_padded(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    sampled_token_ids: torch.Tensor,
    requests: dict[str, CachedRequestState],
    gpu_input_batch: InputBatch,
    discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare next token IDs for speculative decoding.

    Since num_speculative_tokens == 1, sampled_token_ids has shape
    (batch_size, 1). For each request we either use the sampled token
    (if valid and not discarded) or a backup token from the request state.

    Args:
        common_attn_metadata: Supplies `seq_lens_cpu`; entry `i` is the index
            passed to `get_token_id` for request `i`.
        sampled_token_ids: Sampled tokens; only column 0 is read.
        requests: Request states keyed by request id.
        gpu_input_batch: Supplies `num_reqs`, `req_ids` and `vocab_size`.
        discard_request_mask: Boolean mask; the first `num_reqs` entries mark
            requests whose sampled token must be ignored.

    Returns:
        A tuple `(next_token_ids, valid_sampled_tokens_count)`, both int32.
        The count is 1 where the sampled token lies in `[0, vocab_size)` and
        0 otherwise; it does not take the discard mask into account.
    """
    num_reqs = gpu_input_batch.num_reqs
    device = sampled_token_ids.device

    # Compute backup tokens for discarded / invalid requests.
    # Convert the sequence lengths to Python ints in one bulk call rather
    # than issuing a separate `.item()` per request inside the loop.
    seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].tolist()
    req_ids = gpu_input_batch.req_ids
    backup_tokens_gpu = torch.tensor(
        [requests[req_ids[i]].get_token_id(seq_lens[i]) for i in range(num_reqs)],
        dtype=torch.int32,
        device=device,
    )

    assert discard_request_mask.dtype == torch.bool

    # With num_speculative_tokens == 1, there is exactly one token
    sampled = sampled_token_ids[:, 0]
    is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
    valid_sampled_tokens_count = is_valid.to(torch.int32)

    # A sampled token is kept only when it is in range AND not discarded.
    use_sampled = is_valid & ~discard_request_mask[:num_reqs]
    next_token_ids = torch.where(
        use_sampled, sampled.to(torch.int32), backup_tokens_gpu
    )
    return next_token_ids, valid_sampled_tokens_count
def load_model(self, target_model: nn.Module) -> None:
    """Instantiate the ExtractHiddenStatesModel and set up its attention layer.

    The ExtractHiddenStatesModel is used to cache hidden states during
    speculative decoding. It uses cache-only attention (no computation, just
    caching KV states).

    After this call `self.model`, `self.attn_layer_names` and
    `self.attn_metadata_builder` are populated.

    Args:
        target_model: The target model (passed for compatibility with
            EagleProposer interface, but not used here)
    """
    # Snapshot the attention layer names that exist BEFORE the draft model is
    # built, so the layers it registers can be told apart afterwards.
    preexisting_layer_names = set(
        get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)  # type: ignore[type-abstract]
    )

    spec_config = self.vllm_config.speculative_config
    assert spec_config is not None
    draft_model_config = spec_config.draft_model_config

    from vllm.compilation.backends import set_model_tag

    with set_model_tag("extract_hidden_states"):
        self.model = get_model(
            vllm_config=self.vllm_config, model_config=draft_model_config
        )

    # Whatever attention layers appeared since the snapshot belong to the
    # draft model.
    layers_after_load = get_layers_from_vllm_config(
        self.vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    draft_attn_layers = {}
    for layer_name, layer in layers_after_load.items():
        if layer_name in preexisting_layer_names:
            continue
        draft_attn_layers[layer_name] = layer

    self.attn_layer_names = list(draft_attn_layers)
    num_draft_layers = len(draft_attn_layers)
    assert num_draft_layers == 1, (
        "ExtractHiddenStatesModel should have exactly one "
        f"attention layer, found {num_draft_layers}"
    )

    self.attn_metadata_builder = self._build_attn_metadata_builder(
        draft_attn_layers
    )
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
    """Check that every drafting layer sits in one KV cache group.

    Exactly one attention layer is expected (load_model asserts this), so the
    condition holds trivially; only the layer count is re-checked here and
    `kv_cache_config` is not inspected.
    """
    num_draft_layers = len(self.attn_layer_names)
    assert num_draft_layers == 1
......@@ -67,3 +67,41 @@ class SpecDecodeMetadata:
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
@dataclass
class MultiLayerEagleMetadata:
# [batch_size]
cached_len: torch.Tensor | None = None
# [batch_size, layer_num]
cached_token_ids: torch.Tensor | None = None
# [batch_size, layer_num, hidden_size]
cached_hidden_states: torch.Tensor | None = None
# [batch_size, layer_num]
cached_slot_mappings: torch.Tensor | None = None
# [batch_size, layer_num]
cached_positions: torch.Tensor | None = None
@classmethod
def make_dummy(
cls,
layer_num: int,
hidden_size: int,
device: torch.device,
) -> "MultiLayerEagleMetadata":
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
cached_token_ids = torch.zeros((1, layer_num), dtype=torch.int32, device=device)
cached_hidden_states = torch.zeros(
(1, layer_num, hidden_size), dtype=torch.float32, device=device
)
cached_slot_mappings = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
cached_positions = torch.zeros((1, layer_num), dtype=torch.int64, device=device)
return cls(
cached_len=cached_len,
cached_token_ids=cached_token_ids,
cached_hidden_states=cached_hidden_states,
cached_slot_mappings=cached_slot_mappings,
cached_positions=cached_positions,
)
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -61,6 +61,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
......@@ -103,6 +110,8 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
):
ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
......@@ -223,7 +232,45 @@ class InputBatch:
(max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
......@@ -464,6 +511,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
......@@ -662,6 +716,20 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1],
)
if self.multi_layer_eagle_num > 0:
self.cached_len[i1], self.cached_len[i2] = (
self.cached_len[i2],
self.cached_len[i1],
)
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[[i2, i1], ...]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[[i2, i1], ...]
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
......@@ -784,6 +852,21 @@ class InputBatch:
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[last_req_index]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment