Commit fcc9c9ea authored by luopl's avatar luopl
Browse files

feat:新增step3.5-mtp3功能

parent 9dc40d38
...@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, model_validator ...@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.model import ModelConfig from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config from vllm.config.utils import config
...@@ -76,6 +77,10 @@ class SpeculativeConfig: ...@@ -76,6 +77,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered.""" `prompt_lookup_min` should be considered."""
enable_multi_layers_mtp: bool = False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1) draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1 """The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size.""" or the same as the target model's tensor parallel size."""
...@@ -110,6 +115,11 @@ class SpeculativeConfig: ...@@ -110,6 +115,11 @@ class SpeculativeConfig:
which may only be supported by certain attention backends. This currently which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation.""" only affects the EAGLE method of speculation."""
use_local_argmax_reduction: bool = False
"""Use vocab-parallel local argmax instead of all-gathering full logits
for draft token generation. Reduces communication from O(vocab_size) to
O(2 * tp_size) per token. Only applies to greedy draft selection in
non-tree speculation."""
# Ngram proposer configuration # Ngram proposer configuration
prompt_lookup_max: int | None = Field(default=None, ge=1) prompt_lookup_max: int | None = Field(default=None, ge=1)
"""Maximum size of ngram token window when using Ngram proposer, required """Maximum size of ngram token window when using Ngram proposer, required
...@@ -121,6 +131,12 @@ class SpeculativeConfig: ...@@ -121,6 +131,12 @@ class SpeculativeConfig:
speculative_token_tree: str | None = None speculative_token_tree: str | None = None
"""Specifies the tree structure for speculative token generation. """Specifies the tree structure for speculative token generation.
""" """
parallel_drafting: bool = False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine # required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model.""" """The configuration of the target model."""
...@@ -154,6 +170,10 @@ class SpeculativeConfig: ...@@ -154,6 +170,10 @@ class SpeculativeConfig:
tokens with estimated probability (based on frequency counts) greater than tokens with estimated probability (based on frequency counts) greater than
or equal to this value.""" or equal to this value."""
draft_load_config: LoadConfig | None = None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -401,7 +421,11 @@ class SpeculativeConfig: ...@@ -401,7 +421,11 @@ class SpeculativeConfig:
MTPModelTypes MTPModelTypes
): ):
self.method = "mtp" self.method = "mtp"
if self.num_speculative_tokens > 1: # if self.num_speculative_tokens > 1:
if (
self.enable_multi_layers_mtp is False
and self.num_speculative_tokens > 1
):
logger.warning( logger.warning(
"Enabling num_speculative_tokens > 1 will run" "Enabling num_speculative_tokens > 1 will run"
"multiple times of forward on same MTP layer" "multiple times of forward on same MTP layer"
...@@ -472,6 +496,17 @@ class SpeculativeConfig: ...@@ -472,6 +496,17 @@ class SpeculativeConfig:
if self.num_speculative_tokens is None: if self.num_speculative_tokens is None:
# Default to max value defined in draft model config. # Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict self.num_speculative_tokens = n_predict
elif (
self.method == "mtp"
and self.enable_multi_layers_mtp
and self.num_speculative_tokens > n_predict
):
logger.warning_once(
"For multi_layer_eagle, num_speculative_tokens "
"is greater than the layer_num, adjusting to "
"layer_num"
)
self.num_speculative_tokens = n_predict
elif ( elif (
self.num_speculative_tokens > n_predict self.num_speculative_tokens > n_predict
and self.num_speculative_tokens % n_predict != 0 and self.num_speculative_tokens % n_predict != 0
...@@ -713,12 +748,31 @@ class SpeculativeConfig: ...@@ -713,12 +748,31 @@ class SpeculativeConfig:
f"errors during speculative decoding." f"errors during speculative decoding."
) )
@property
def max_num_new_slots_for_drafting(self) -> int:
"""
Calculate the maximum number of new slots that might be added to the batch
when drafting.
"""
slots_per_req = 0 # for serial non-draft-model methods, no change needed
if self.parallel_drafting:
# For parallel drafting, we need one new slot per 'masked' token
slots_per_req = self.num_speculative_tokens - 1
if self.uses_draft_model():
# For draft model-based speculation, we need one new slot per request
# Since we do not slice the draft tokens
slots_per_req += 1
return slots_per_req
def use_eagle(self) -> bool: def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp") return self.method in ("eagle", "eagle3", "mtp")
def uses_draft_model(self) -> bool: def uses_draft_model(self) -> bool:
return self.method == "draft_model" return self.method == "draft_model"
def uses_extract_hidden_states(self) -> bool:
return self.method == "extract_hidden_states"
def __repr__(self) -> str: def __repr__(self) -> str:
method = self.method method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model model = None if method in ("ngram", "suffix") else self.draft_model_config.model
......
...@@ -160,3 +160,32 @@ class AnthropicMessagesResponse(BaseModel): ...@@ -160,3 +160,32 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context): def model_post_init(self, __context):
if not self.id: if not self.id:
self.id = f"msg_{int(time.time() * 1000)}" self.id = f"msg_{int(time.time() * 1000)}"
class AnthropicContextManagement(BaseModel):
"""Context management information for token counting."""
original_input_tokens: int
class AnthropicCountTokensRequest(BaseModel):
"""Anthropic messages.count_tokens request"""
model: str
messages: list[AnthropicMessage]
system: str | list[AnthropicContentBlock] | None = None
tool_choice: AnthropicToolChoice | None = None
tools: list[AnthropicTool] | None = None
@field_validator("model")
@classmethod
def validate_model(cls, v):
if not v:
raise ValueError("Model is required")
return v
class AnthropicCountTokensResponse(BaseModel):
"""Anthropic messages.count_tokens response"""
input_tokens: int
context_management: AnthropicContextManagement | None = None
\ No newline at end of file
This diff is collapsed.
...@@ -1239,10 +1239,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1239,10 +1239,13 @@ class OpenAIServingChat(OpenAIServing):
index = 0 index = 0
if ( if (
self._should_check_for_unstreamed_tool_arg_tokens( # self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output # delta_message, output
tool_parser
and self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output, tool_parser
) )
and tool_parser # and tool_parser
): ):
latest_delta_len = 0 latest_delta_len = 0
if ( if (
...@@ -1256,15 +1259,31 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1256,15 +1259,31 @@ class OpenAIServingChat(OpenAIServing):
latest_delta_len = len( latest_delta_len = len(
delta_message.tool_calls[0].function.arguments delta_message.tool_calls[0].function.arguments
) )
# get the expected call based on partial JSON # get the expected call based on partial JSON
# parsing which "autocompletes" the JSON # parsing which "autocompletes" the JSON.
expected_call = json.dumps( # Tool parsers (e.g. Qwen3Coder) store
tool_parser.prev_tool_call_arr[index].get( # arguments as a JSON string in
"arguments", {} # prev_tool_call_arr. Calling json.dumps()
), # on an already-serialized string would
ensure_ascii=False, # double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# expected_call = json.dumps(
# tool_parser.prev_tool_call_arr[index].get(
# "arguments", {}
# ),
# ensure_ascii=False,
# )
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
) )
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments # get what we've streamed so far for arguments
# for the current tool # for the current tool
...@@ -1848,6 +1867,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1848,6 +1867,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
delta_message: DeltaMessage | None, delta_message: DeltaMessage | None,
output: CompletionOutput, output: CompletionOutput,
tool_parser: ToolParser | None = None,
) -> bool: ) -> bool:
""" """
Check to see if we should check for unstreamed tool arguments tokens. Check to see if we should check for unstreamed tool arguments tokens.
...@@ -1866,6 +1886,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1866,6 +1886,8 @@ class OpenAIServingChat(OpenAIServing):
and delta_message.tool_calls[0] and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None and delta_message.tool_calls[0].function.arguments is not None
and tool_parser is not None
and tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens()
) )
@staticmethod @staticmethod
......
...@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple): ...@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple):
""" """
Whether this batch has active LoRA adapters. Whether this batch has active LoRA adapters.
""" """
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
are captured for each num_active_loras value. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
""" """
...@@ -191,7 +199,7 @@ class ForwardContext: ...@@ -191,7 +199,7 @@ class ForwardContext:
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]
""" """
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata attention layer to its attention metadata
Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
for each microbatch. for each microbatch.
......
...@@ -44,6 +44,23 @@ if TYPE_CHECKING: ...@@ -44,6 +44,23 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
"""
Returns num_active_loras values for cudagraph capture.
When specialize=True: powers of 2 up to max_loras, plus max_loras + 1.
When specialize=False: just [max_loras + 1].
This is the single source of truth for LoRA capture cases, used by both
CudagraphDispatcher and PunicaWrapperGPU.
"""
if not specialize:
return [max_loras + 1]
return [
n for n in range(1, max_loras + 2) if (n & (n - 1)) == 0 or n == max_loras + 1
]
_GLOBAL_LORA_ID = 0 _GLOBAL_LORA_ID = 0
......
...@@ -1028,6 +1028,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1028,6 +1028,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
shared_output=None,
routed_scaling_factor=None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.kernel is not None
assert not self.is_monolithic assert not self.is_monolithic
......
...@@ -52,7 +52,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module): ...@@ -52,7 +52,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config) # self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.lm_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer( self.mtp_block = Step3p5DecoderLayer(
vllm_config, vllm_config,
prefix=f"{prefix}.mtp_block", prefix=f"{prefix}.mtp_block",
...@@ -64,9 +65,13 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module): ...@@ -64,9 +65,13 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
embed_tokens: VocabParallelEmbedding | None = None,
spec_step_index: int = 0, spec_step_index: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
assert inputs_embeds is not None if inputs_embeds is None:
assert embed_tokens is not None
inputs_embeds = embed_tokens(input_ids)
# assert inputs_embeds is not None
inputs_embeds = self.enorm(inputs_embeds) inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states) previous_hidden_states = self.hnorm(previous_hidden_states)
...@@ -92,8 +97,10 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -92,8 +97,10 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict( self.layers = torch.nn.ModuleDict(
{ {
str(idx): Step3p5AMultiTokenPredictorLayer( str(idx): Step3p5AMultiTokenPredictorLayer(
vllm_config, # vllm_config,
f"{prefix}.layers.{idx}", # f"{prefix}.layers.{idx}",
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
) )
for idx in range( for idx in range(
self.mtp_start_layer_idx, self.mtp_start_layer_idx,
...@@ -112,14 +119,15 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -112,14 +119,15 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is None: # if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) # inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids, input_ids,
positions, positions,
previous_hidden_states, previous_hidden_states,
inputs_embeds, inputs_embeds,
self.embed_tokens,
current_step_idx, current_step_idx,
) )
...@@ -131,7 +139,8 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -131,7 +139,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor( logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) # mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
) )
return logits return logits
...@@ -257,6 +266,7 @@ class Step3p5MTP(nn.Module): ...@@ -257,6 +266,7 @@ class Step3p5MTP(nn.Module):
name = name.replace(".transformer.", ".") name = name.replace(".transformer.", ".")
if "shared_head" in name: if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head") name = name.replace("shared_head.output", "shared_head.head")
name = name.replace("shared_head", "lm_head")
if "embed_tokens" in name: if "embed_tokens" in name:
assert ( assert (
hasattr(self.config, "num_nextn_predict_layers") hasattr(self.config, "num_nextn_predict_layers")
......
...@@ -118,6 +118,11 @@ class ToolParser: ...@@ -118,6 +118,11 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!" "AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
) )
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
Whether to check for unstreamed tool-argument tokens in serving
"""
return True
class ToolParserManager: class ToolParserManager:
""" """
......
This diff is collapsed.
...@@ -7,6 +7,7 @@ import os ...@@ -7,6 +7,7 @@ import os
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload from typing import Any, NewType, TypeAlias, overload
from vllm import envs from vllm import envs
...@@ -947,6 +948,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo ...@@ -947,6 +948,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def _get_kv_cache_groups_uniform_page_size( def _get_kv_cache_groups_uniform_page_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec], kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]: ) -> list[KVCacheGroupSpec]:
""" """
...@@ -1007,6 +1009,7 @@ def _get_kv_cache_groups_uniform_page_size( ...@@ -1007,6 +1009,7 @@ def _get_kv_cache_groups_uniform_page_size(
memory per block is the same for all groups. memory per block is the same for all groups.
Args: Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns: Returns:
The generated KVCacheGroupSpecs The generated KVCacheGroupSpecs
...@@ -1030,9 +1033,9 @@ def _get_kv_cache_groups_uniform_page_size( ...@@ -1030,9 +1033,9 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better # is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30 # strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10). # sw, where the group size should be 10).
min_num_layers = min([len(layers) for layers in same_type_layers.values()]) min_num_layers = min([len(layers) for layers in same_type_layers.values()]) #12
group_size = min_num_layers group_size = min_num_layers
max_num_layers = max([len(layers) for layers in same_type_layers.values()]) max_num_layers = max([len(layers) for layers in same_type_layers.values()]) #36
if max_num_layers < min_num_layers * 1.25: if max_num_layers < min_num_layers * 1.25:
# If the number of layers is not much larger than the minimum number of layers, # If the number of layers is not much larger than the minimum number of layers,
# use the maximum number of layers as the group size to avoid too many padding # use the maximum number of layers as the group size to avoid too many padding
...@@ -1050,19 +1053,28 @@ def _get_kv_cache_groups_uniform_page_size( ...@@ -1050,19 +1053,28 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers / len(layers) * 100, num_padding_layers / len(layers) * 100,
) )
num_groups = cdiv(len(layers), group_size) num_groups = cdiv(len(layers), group_size)
# In PP case, say if we have # for support multi layer mtp, we need to
# - stage 0: full.0, sw.0, sw.1 # make all mtp layers in the same group
# - stage 1: full.1, sw.2, sw.3 if (
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3) vllm_config.speculative_config is not None
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because and vllm_config.speculative_config.enable_multi_layers_mtp
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group) ):
# and it will be padded to (full.0, padding), (sw.0, sw.1), for i in range(0, len(layers), group_size):
# (padding, padding) to ensure the number of layers in each group is grouped_layers.append(layers[i : i + group_size])
# the same and will cause memory waste. else:
# To avoid this, we assign layers[i::num_groups] to the i-th group # In PP case, say if we have
# instead of layers[i * group_size: (i + 1) * group_size] # - stage 0: full.0, sw.0, sw.1
for i in range(num_groups): # - stage 1: full.1, sw.2, sw.3
grouped_layers.append(layers[i::num_groups]) # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for i in range(num_groups):
grouped_layers.append(layers[i::num_groups])
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
...@@ -1120,7 +1132,6 @@ def get_kv_cache_config_from_groups( ...@@ -1120,7 +1132,6 @@ def get_kv_cache_config_from_groups(
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2 # full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups) group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size( page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups] [group.kv_cache_spec for group in kv_cache_groups]
) )
...@@ -1247,8 +1258,10 @@ def get_kv_cache_groups( ...@@ -1247,8 +1258,10 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers # have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page # into groups with the same number of layers, and thus same total page
# size. # size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) # return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
return _get_kv_cache_groups_uniform_page_size(
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
)
def generate_scheduler_kv_cache_config( def generate_scheduler_kv_cache_config(
kv_cache_configs: list[KVCacheConfig], kv_cache_configs: list[KVCacheConfig],
...@@ -1451,6 +1464,42 @@ def _auto_fit_max_model_len( ...@@ -1451,6 +1464,42 @@ def _auto_fit_max_model_len(
) )
def _project_kv_cache_groups_to_worker(
global_kv_cache_groups: list[KVCacheGroupSpec],
worker_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
"""
Projects global KV cache groups onto a single worker's assigned layers.
In pipeline parallelism, each worker only owns a subset of layers. This
function filters the global groups to include only layers present on the
given worker, adjusting UniformTypeKVCacheSpecs accordingly.
Args:
global_kv_cache_groups: The global KV cache groups for the whole model.
worker_spec: The KV cache spec of each layer on this worker.
Returns:
The projected KV cache groups containing only this worker's layers.
"""
projected_groups: list[KVCacheGroupSpec] = []
for group in global_kv_cache_groups:
worker_layer_names = [
layer_name for layer_name in group.layer_names if layer_name in worker_spec
]
group_spec = group.kv_cache_spec
if worker_layer_names and isinstance(group_spec, UniformTypeKVCacheSpecs):
group_spec = UniformTypeKVCacheSpecs(
block_size=group_spec.block_size,
kv_cache_specs={
layer_name: group_spec.kv_cache_specs[layer_name]
for layer_name in worker_layer_names
},
)
projected_groups.append(KVCacheGroupSpec(worker_layer_names, group_spec))
return projected_groups
def get_kv_cache_configs( def get_kv_cache_configs(
vllm_config: VllmConfig, vllm_config: VllmConfig,
kv_cache_specs: list[dict[str, KVCacheSpec]], kv_cache_specs: list[dict[str, KVCacheSpec]],
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set as AbstractSet
from dataclasses import replace
from itertools import product from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.utils import get_captured_lora_counts
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -57,6 +61,11 @@ class CudagraphDispatcher: ...@@ -57,6 +61,11 @@ class CudagraphDispatcher:
) )
self.keys_initialized = False self.keys_initialized = False
self.specialize_lora_count = (
self.vllm_config.lora_config.specialize_active_lora
if self.vllm_config.lora_config is not None
else False
)
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE self.cudagraph_mode = CUDAGraphMode.NONE
...@@ -64,6 +73,9 @@ class CudagraphDispatcher: ...@@ -64,6 +73,9 @@ class CudagraphDispatcher:
"""Pre-compute the mapping from batch size to padded graph size.""" """Pre-compute the mapping from batch size to padded graph size."""
max_size = self.compilation_config.max_cudagraph_capture_size max_size = self.compilation_config.max_cudagraph_capture_size
capture_sizes = self.compilation_config.cudagraph_capture_sizes capture_sizes = self.compilation_config.cudagraph_capture_sizes
assert capture_sizes is not None, (
"Cudagraph capture sizes must be set when cudagraphs are enabled."
)
self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1) self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
for end, start in zip( for end, start in zip(
capture_sizes + [max_size + 1], capture_sizes + [max_size + 1],
...@@ -92,8 +104,33 @@ class CudagraphDispatcher: ...@@ -92,8 +104,33 @@ class CudagraphDispatcher:
"Use values from cudagraph_capture_sizes." "Use values from cudagraph_capture_sizes."
) )
def _get_lora_cases(self) -> list[int]:
"""
Returns list of has_lora values for CUDA graph capture.
This is the single source of truth for LoRA capture cases.
"""
lora_config = self.vllm_config.lora_config
if lora_config is None:
# No LoRA configured - single case with no LoRA
return [0]
# LoRA is enabled - capture graphs based on cudagraph_specialize_lora
if self.compilation_config.cudagraph_specialize_lora:
captured_counts = get_captured_lora_counts(
lora_config.max_loras, self.specialize_lora_count
)
# Specialize: capture separate graphs for with and without LoRA
return [0] + captured_counts
else:
# No specialization: only capture graphs with LoRA active
return [lora_config.max_loras + 1]
def _create_padded_batch_descriptor( def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
) -> BatchDescriptor: ) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len uniform_decode_query_len = self.uniform_decode_query_len
...@@ -111,6 +148,7 @@ class CudagraphDispatcher: ...@@ -111,6 +148,7 @@ class CudagraphDispatcher:
num_reqs=num_reqs, num_reqs=num_reqs,
uniform=uniform_decode, uniform=uniform_decode,
has_lora=has_lora, has_lora=has_lora,
num_active_loras=num_active_loras,
) )
def add_cudagraph_key( def add_cudagraph_key(
...@@ -143,18 +181,27 @@ class CudagraphDispatcher: ...@@ -143,18 +181,27 @@ class CudagraphDispatcher:
lora_cases = [True] lora_cases = [True]
else: else:
lora_cases = [False] lora_cases = [False]
# Get LoRA cases to capture
# lora_cases = self._get_lora_cases()
self.captured_lora_counts = [
lora_count for lora_count in lora_cases if lora_count
]
# Note: we create all valid keys for cudagraph here but do not # Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy # guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered. # capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs, has_lora in product( assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when mixed mode is enabled."
)
for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases self.compilation_config.cudagraph_capture_sizes, lora_cases
): ):
self.add_cudagraph_key( self.add_cudagraph_key(
cudagraph_mode.mixed_mode(), cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor( self._create_padded_batch_descriptor(
bs, False, has_lora bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(), ).relax_for_mixed_batch_cudagraphs(),
) )
...@@ -168,15 +215,20 @@ class CudagraphDispatcher: ...@@ -168,15 +215,20 @@ class CudagraphDispatcher:
uniform_decode_query_len uniform_decode_query_len
* self.vllm_config.scheduler_config.max_num_seqs * self.vllm_config.scheduler_config.max_num_seqs
) )
assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when full mode is enabled."
)
cudagraph_capture_sizes_for_decode = [ cudagraph_capture_sizes_for_decode = [
x x
for x in self.compilation_config.cudagraph_capture_sizes for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len if x <= max_num_tokens and x >= uniform_decode_query_len
] ]
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): for bs, num_active_loras in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key( self.add_cudagraph_key(
CUDAGraphMode.FULL, CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(bs, True, has_lora), self._create_padded_batch_descriptor(
bs, True, num_active_loras > 0, num_active_loras
),
) )
self.keys_initialized = True self.keys_initialized = True
...@@ -199,14 +251,19 @@ class CudagraphDispatcher: ...@@ -199,14 +251,19 @@ class CudagraphDispatcher:
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len). length is uniform_decode_query_len).
has_lora: Whether LoRA is active. has_lora: Whether LoRA is active.
valid_modes: Set of cudagraph modes that are allowed. None means
all modes are allowed.
disable_full: If True, skip FULL cudagraph checks and disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs) cascade attention that are not supported by full cudagraphs)
""" """
# allowed_modes = valid_modes or CUDAGraphMode.valid_runtime_modes()
if ( if (
not self.keys_initialized not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size or num_tokens > self.compilation_config.max_cudagraph_capture_size
# or allowed_modes <= {CUDAGraphMode.NONE}
): ):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
PADDING_SLOT_ID = -1
class ExtractHiddenStatesProposer:
def __init__(self, vllm_config: VllmConfig, device):
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.num_speculative_tokens == 1
if vllm_config.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"disable_padded_drafter_batch is not supported with "
"extract_hidden_states method"
)
self.vllm_config = vllm_config
self.device = device
self.dtype = vllm_config.model_config.dtype
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
# Model and attention layer tracking (initialized in load_model)
self.model: nn.Module | None = None
self.attn_layer_names: list[str] = []
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
# Maximum number of tokens for buffers
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
if not layer_ids:
raise ValueError(
"eagle_aux_hidden_state_layer_ids must be set in the draft "
"model config for extract_hidden_states method"
)
self.num_hidden_states = len(layer_ids)
self.hidden_size = vllm_config.model_config.get_hidden_size()
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.num_hidden_states, self.hidden_size),
dtype=self.dtype,
device=device,
)
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
)
def propose(
self,
sampled_token_ids: torch.Tensor,
target_hidden_states: list[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
scheduler_output: SchedulerOutput,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache
without performing actual attention computation. This allows us to
extract and store hidden states for later use (e.g., KV transfer).
This proposer doesn't actually perform speculation - it returns the
sampled tokens as "draft" tokens, ensuring they always verify (match).
The main purpose is to cache hidden states, not to speculate.
Args:
sampled_token_ids: Sampled token IDs from the target model
target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer)
common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility)
Returns:
Tuple of:
- Draft tokens matching sampled tokens, shape [batch_size, 1]
- KV connector output (if KV transfer is active), else None
"""
assert self.model is not None and isinstance(target_hidden_states, list)
# target_hidden_states is a list of tensors (one per layer)
# Each tensor has shape [num_tokens, hidden_size]
# Stack to shape: [num_tokens, num_hidden_states, hidden_size]
stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
num_tokens = stacked_hidden_states.shape[0]
# Copy hidden states to buffer
self.hidden_states[:num_tokens] = stacked_hidden_states
assert self.attn_metadata_builder is not None
attn_metadata = self.attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
# We assume all cache-only layers belong to the same KV cache group,
# thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
with (
set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
),
(
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
if has_kv_transfer_group()
else nullcontext()
) as kv_connector_output,
):
self.model(
hidden_states=self.hidden_states[:num_input_tokens],
)
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return sampled_token_ids.unsqueeze(-1), kv_connector_output
def _get_slot_mapping(
self,
num_tokens: int,
slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return slot_mapping dict for cache-only attention layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if slot_mapping is not None:
num_actual = slot_mapping.shape[0]
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
if num_tokens > num_actual:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names}
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
)
num_tokens_padded = batch_desc.num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
)
)
assert not should_ubatch, (
"DBO ubatching not implemented for extract_hidden_states"
)
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.dp_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_padded,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys.
Only supports PIECEWISE cudagraphs (via mixed_mode).
Should be called after adjust_cudagraph_sizes_for_spec_decode.
"""
assert self.vllm_config.speculative_config is not None
if (
not self.vllm_config.speculative_config.enforce_eager
and cudagraph_mode.mixed_mode()
in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
):
proposer_cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
proposer_cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
assert self.model is not None, "Model must be initialized before dummy_run"
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
)
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
# Use our own slot mapping buffer during cudagraph capture.
if (
self.attn_layer_names
and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
):
self.model(
hidden_states=self.hidden_states[:num_input_tokens],
)
def _build_attn_metadata_builder(
self, draft_attn_layers: dict[str, AttentionLayerBase]
) -> AttentionMetadataBuilder:
"""Build the attention metadata builder from draft attention layers."""
if not draft_attn_layers:
raise ValueError("No attention layers found for ExtractHiddenStatesModel")
layer = next(iter(draft_attn_layers.values()))
attn_backend = layer.get_attn_backend()
return attn_backend.get_builder_cls()(
layer.get_kv_cache_spec(self.vllm_config),
self.attn_layer_names,
self.vllm_config,
self.device,
)
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Prepare next token IDs for speculative decoding.
Since num_speculative_tokens == 1, sampled_token_ids has shape
(batch_size, 1). For each request we either use the sampled token
(if valid and not discarded) or a backup token from the request state.
"""
num_reqs = gpu_input_batch.num_reqs
device = sampled_token_ids.device
# Compute backup tokens for discarded / invalid requests
backup_tokens_gpu = torch.tensor(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs)
],
dtype=torch.int32,
device=device,
)
assert discard_request_mask.dtype == torch.bool
# With num_speculative_tokens == 1, there is exactly one token
sampled = sampled_token_ids[:, 0]
is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
valid_sampled_tokens_count = is_valid.to(torch.int32)
use_sampled = is_valid & ~discard_request_mask[:num_reqs]
next_token_ids = torch.where(
use_sampled, sampled.to(torch.int32), backup_tokens_gpu
)
return next_token_ids, valid_sampled_tokens_count
def load_model(self, target_model: nn.Module) -> None:
"""Load the ExtractHiddenStatesModel model.
This method instantiates the ExtractHiddenStatesModel model which is used
to cache hidden states during speculative decoding. The model uses
cache-only attention (no computation, just caching KV states).
Args:
target_model: The target model (passed for compatibility with
EagleProposer interface, but not used here)
"""
# Get the target model's attention layers before loading draft model
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() # type: ignore[type-abstract]
)
assert self.vllm_config.speculative_config is not None
draft_model_config = self.vllm_config.speculative_config.draft_model_config
from vllm.compilation.backends import set_model_tag
with set_model_tag("extract_hidden_states"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=draft_model_config
)
# Identify draft model's attention layers (difference from target)
all_attn_layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
draft_attn_layers = {
name: layer
for name, layer in all_attn_layers.items()
if name not in target_attn_layer_names
}
self.attn_layer_names = list(draft_attn_layers.keys())
assert len(draft_attn_layers) == 1, (
"ExtractHiddenStatesModel should have exactly one "
f"attention layer, found {len(draft_attn_layers)}"
)
self.attn_metadata_builder = self._build_attn_metadata_builder(
draft_attn_layers
)
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
"""Validate all drafting layers belong to the same KV cache group.
With exactly one attention layer (asserted in load_model), this is
trivially satisfied.
"""
assert len(self.attn_layer_names) == 1
...@@ -67,3 +67,41 @@ class SpecDecodeMetadata: ...@@ -67,3 +67,41 @@ class SpecDecodeMetadata:
bonus_logits_indices=bonus_logits_indices, bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices, logits_indices=logits_indices,
) )
@dataclass
class MultiLayerEagleMetadata:
# [batch_size]
cached_len: torch.Tensor | None = None
# [batch_size, layer_num]
cached_token_ids: torch.Tensor | None = None
# [batch_size, layer_num, hidden_size]
cached_hidden_states: torch.Tensor | None = None
# [batch_size, layer_num]
cached_slot_mappings: torch.Tensor | None = None
# [batch_size, layer_num]
cached_positions: torch.Tensor | None = None
@classmethod
def make_dummy(
cls,
layer_num: int,
hidden_size: int,
device: torch.device,
) -> "MultiLayerEagleMetadata":
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
cached_token_ids = torch.zeros((1, layer_num), dtype=torch.int32, device=device)
cached_hidden_states = torch.zeros(
(1, layer_num, hidden_size), dtype=torch.float32, device=device
)
cached_slot_mappings = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
cached_positions = torch.zeros((1, layer_num), dtype=torch.int64, device=device)
return cls(
cached_len=cached_len,
cached_token_ids=cached_token_ids,
cached_hidden_states=cached_hidden_states,
cached_slot_mappings=cached_slot_mappings,
cached_positions=cached_positions,
)
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment