Commit 94823af1 authored by laibao's avatar laibao
Browse files

Revert "feat:新增step3.5-mtp3功能"

This reverts commit a1f4d869.
parent a1f4d869
......@@ -8,7 +8,6 @@ from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
......@@ -77,10 +76,6 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
enable_multi_layers_mtp: bool = False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
......@@ -115,11 +110,6 @@ class SpeculativeConfig:
which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation."""
use_local_argmax_reduction: bool = False
"""Use vocab-parallel local argmax instead of all-gathering full logits
for draft token generation. Reduces communication from O(vocab_size) to
O(2 * tp_size) per token. Only applies to greedy draft selection in
non-tree speculation."""
# Ngram proposer configuration
prompt_lookup_max: int | None = Field(default=None, ge=1)
"""Maximum size of ngram token window when using Ngram proposer, required
......@@ -131,12 +121,6 @@ class SpeculativeConfig:
speculative_token_tree: str | None = None
"""Specifies the tree structure for speculative token generation.
"""
parallel_drafting: bool = False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model."""
......@@ -170,10 +154,6 @@ class SpeculativeConfig:
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""
draft_load_config: LoadConfig | None = None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......@@ -421,11 +401,7 @@ class SpeculativeConfig:
MTPModelTypes
):
self.method = "mtp"
# if self.num_speculative_tokens > 1:
if (
self.enable_multi_layers_mtp is False
and self.num_speculative_tokens > 1
):
if self.num_speculative_tokens > 1:
logger.warning(
"Enabling num_speculative_tokens > 1 will run"
"multiple times of forward on same MTP layer"
......@@ -496,17 +472,6 @@ class SpeculativeConfig:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif (
self.method == "mtp"
and self.enable_multi_layers_mtp
and self.num_speculative_tokens > n_predict
):
logger.warning_once(
"For multi_layer_eagle, num_speculative_tokens "
"is greater than the layer_num, adjusting to "
"layer_num"
)
self.num_speculative_tokens = n_predict
elif (
self.num_speculative_tokens > n_predict
and self.num_speculative_tokens % n_predict != 0
......@@ -748,31 +713,12 @@ class SpeculativeConfig:
f"errors during speculative decoding."
)
@property
def max_num_new_slots_for_drafting(self) -> int:
"""
Calculate the maximum number of new slots that might be added to the batch
when drafting.
"""
slots_per_req = 0 # for serial non-draft-model methods, no change needed
if self.parallel_drafting:
# For parallel drafting, we need one new slot per 'masked' token
slots_per_req = self.num_speculative_tokens - 1
if self.uses_draft_model():
# For draft model-based speculation, we need one new slot per request
# Since we do not slice the draft tokens
slots_per_req += 1
return slots_per_req
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp")
def uses_draft_model(self) -> bool:
return self.method == "draft_model"
def uses_extract_hidden_states(self) -> bool:
return self.method == "extract_hidden_states"
def __repr__(self) -> str:
method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
......
......@@ -160,32 +160,3 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context):
if not self.id:
self.id = f"msg_{int(time.time() * 1000)}"
class AnthropicContextManagement(BaseModel):
"""Context management information for token counting."""
original_input_tokens: int
class AnthropicCountTokensRequest(BaseModel):
"""Anthropic messages.count_tokens request"""
model: str
messages: list[AnthropicMessage]
system: str | list[AnthropicContentBlock] | None = None
tool_choice: AnthropicToolChoice | None = None
tools: list[AnthropicTool] | None = None
@field_validator("model")
@classmethod
def validate_model(cls, v):
if not v:
raise ValueError("Model is required")
return v
class AnthropicCountTokensResponse(BaseModel):
"""Anthropic messages.count_tokens response"""
input_tokens: int
context_management: AnthropicContextManagement | None = None
\ No newline at end of file
......@@ -15,9 +15,6 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicContextManagement,
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicContentBlock,
AnthropicDelta,
AnthropicError,
......@@ -115,7 +112,6 @@ class AnthropicServingMessages(OpenAIServingChat):
# Handle complex content blocks
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content:
if block.type == "text" and block.text:
......@@ -127,8 +123,6 @@ class AnthropicServingMessages(OpenAIServingChat):
"image_url": {"url": block.source.get("data", "")},
}
)
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
# Convert tool use to function call format
tool_call = {
......@@ -163,9 +157,6 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)
# Add tool calls to the message if any
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
......@@ -306,116 +297,10 @@ class AnthropicServingMessages(OpenAIServingChat):
generator: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
try:
class _ActiveBlockState:
def __init__(self) -> None:
self.content_block_index = 0
self.block_type: str | None = None
self.block_index: int | None = None
self.block_signature: str | None = None
self.signature_emitted: bool = False
self.tool_use_id: str | None = None
def reset(self) -> None:
self.block_type = None
self.block_index = None
self.block_signature = None
self.signature_emitted = False
self.tool_use_id = None
def start(self, block: AnthropicContentBlock) -> None:
self.block_type = block.type
self.block_index = self.content_block_index
if block.type == "thinking":
self.block_signature = uuid.uuid4().hex
self.signature_emitted = False
self.tool_use_id = None
elif block.type == "tool_use":
self.block_signature = None
self.signature_emitted = True
self.tool_use_id = block.id
else:
self.block_signature = None
self.signature_emitted = True
self.tool_use_id = None
first_item = True
finish_reason = None
# content_block_index = 0
# content_block_started = False
content_block_index = 0
active_block_type: str | None = None
active_block_index: int | None = None
active_block_signature: str | None = None
signature_emitted = False
active_tool_use_id: str | None = None
# Map from tool call index to tool_use_id
tool_index_to_id: dict[int, str] = {}
def stop_active_block():
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
events: list[str] = []
if active_block_type is None:
return events
if (
active_block_type == "thinking"
and active_block_signature is not None
and not signature_emitted
):
chunk = AnthropicStreamEvent(
index=active_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="signature_delta",
signature=active_block_signature,
),
)
data = chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_delta"))
signature_emitted = True
stop_chunk = AnthropicStreamEvent(
index=active_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_stop"))
active_block_type = None
active_block_index = None
active_block_signature = None
signature_emitted = False
active_tool_use_id = None
content_block_index += 1
return events
def start_block(block: AnthropicContentBlock):
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_start",
content_block=block,
)
data = chunk.model_dump_json(exclude_unset=True)
event = wrap_data_with_event(data, "content_block_start")
active_block_type = block.type
active_block_index = content_block_index
if block.type == "thinking":
active_block_signature = uuid.uuid4().hex
signature_emitted = False
active_tool_use_id = None
elif block.type == "tool_use":
active_block_signature = None
signature_emitted = True
active_tool_use_id = block.id
else:
active_block_signature = None
signature_emitted = True
active_tool_use_id = None
return event
content_block_started = False
async for item in generator:
if item.startswith("data:"):
......@@ -441,8 +326,6 @@ class AnthropicServingMessages(OpenAIServingChat):
id=origin_chunk.id,
content=[],
model=origin_chunk.model,
stop_reason=None,
stop_sequence=None,
usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage
......@@ -458,33 +341,13 @@ class AnthropicServingMessages(OpenAIServingChat):
# last chunk including usage info
if len(origin_chunk.choices) == 0:
# if content_block_started:
# stop_chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_stop",
# )
# data = stop_chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_stop")
# stop_reason = self.stop_reason_map.get(
# finish_reason or "stop"
# )
# chunk = AnthropicStreamEvent(
# type="message_delta",
# delta=AnthropicDelta(stop_reason=stop_reason),
# usage=AnthropicUsage(
# input_tokens=origin_chunk.usage.prompt_tokens
# if origin_chunk.usage
# else 0,
# output_tokens=origin_chunk.usage.completion_tokens
# if origin_chunk.usage
# else 0,
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "message_delta")
# continue
for event in stop_active_block():
yield event
if content_block_started:
stop_chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_stop")
stop_reason = self.stop_reason_map.get(
finish_reason or "stop"
)
......@@ -503,134 +366,29 @@ class AnthropicServingMessages(OpenAIServingChat):
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "message_delta")
continue
# =========================================================
if origin_chunk.choices[0].finish_reason is not None:
finish_reason = origin_chunk.choices[0].finish_reason
# continue
continue
# content
# if origin_chunk.choices[0].delta.content is not None:
# if not content_block_started:
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_start",
# content_block=AnthropicContentBlock(
# type="text", text=""
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_start")
# content_block_started = True
# if origin_chunk.choices[0].delta.content == "":
# continue
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_delta",
# delta=AnthropicDelta(
# type="text_delta",
# text=origin_chunk.choices[0].delta.content,
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_delta")
# continue
# tool calls
# elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
# elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
# tool_call = origin_chunk.choices[0].delta.tool_calls[0]
# if tool_call.id is not None:
# if content_block_started:
# stop_chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_stop",
# )
# data = stop_chunk.model_dump_json(
# exclude_unset=True
# )
# yield wrap_data_with_event(
# data, "content_block_stop"
# )
# content_block_started = False
# content_block_index += 1
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_start",
# content_block=AnthropicContentBlock(
# type="tool_use",
# id=tool_call.id,
# name=tool_call.function.name
# if tool_call.function
# else None,
# input={},
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_start")
# content_block_started = True
# else:
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_delta",
# delta=AnthropicDelta(
# type="input_json_delta",
# partial_json=tool_call.function.arguments
# if tool_call.function
# else None,
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_delta")
# continue
# thinking / text content
reasoning_delta = origin_chunk.choices[0].delta.reasoning
if reasoning_delta is not None:
if reasoning_delta == "":
pass
else:
if active_block_type != "thinking":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="thinking", thinking=""
)
)
yield start_event
if origin_chunk.choices[0].delta.content is not None:
if not content_block_started:
chunk = AnthropicStreamEvent(
index=(
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="thinking_delta",
thinking=reasoning_delta,
index=content_block_index,
type="content_block_start",
content_block=AnthropicContentBlock(
type="text", text=""
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
yield wrap_data_with_event(data, "content_block_start")
content_block_started = True
if origin_chunk.choices[0].delta.content is not None:
if origin_chunk.choices[0].delta.content == "":
pass
else:
if active_block_type != "text":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(type="text", text="")
)
yield start_event
continue
chunk = AnthropicStreamEvent(
index=(
active_block_index
if active_block_index is not None
else content_block_index
),
index=content_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="text_delta",
......@@ -639,82 +397,55 @@ class AnthropicServingMessages(OpenAIServingChat):
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
continue
# tool calls - process all tool calls in the delta
if len(origin_chunk.choices[0].delta.tool_calls) > 0:
for tool_call in origin_chunk.choices[0].delta.tool_calls:
# tool calls
elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
tool_call = origin_chunk.choices[0].delta.tool_calls[0]
if tool_call.id is not None:
# Update mapping for incremental updates
tool_index_to_id[tool_call.index] = tool_call.id
# Only create new block if different tool call
# AND has a name
tool_name = (
tool_call.function.name
if tool_call.function
else None
if content_block_started:
stop_chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_stop",
)
if (
active_tool_use_id != tool_call.id
and tool_name is not None
):
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
name=tool_name,
input={},
data = stop_chunk.model_dump_json(
exclude_unset=True
)
yield wrap_data_with_event(
data, "content_block_stop"
)
yield start_event
# Handle initial arguments if present
if (
tool_call.function
and tool_call.function.arguments
and active_tool_use_id == tool_call.id
):
content_block_started = False
content_block_index += 1
chunk = AnthropicStreamEvent(
index=(
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments,
index=content_block_index,
type="content_block_start",
content_block=AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name
if tool_call.function
else None,
input={},
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(
data, "content_block_delta"
)
yield wrap_data_with_event(data, "content_block_start")
content_block_started = True
else:
# Incremental update - use index to find tool_use_id
tool_use_id = tool_index_to_id.get(tool_call.index)
if (
tool_use_id is not None
and tool_call.function
and tool_call.function.arguments
and active_tool_use_id == tool_use_id
):
chunk = AnthropicStreamEvent(
index=(
active_block_index
if active_block_index is not None
else content_block_index
),
index=content_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments,
partial_json=tool_call.function.arguments
if tool_call.function
else None,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(
data, "content_block_delta"
)
yield wrap_data_with_event(data, "content_block_delta")
continue
else:
error_response = AnthropicStreamEvent(
......@@ -737,31 +468,3 @@ class AnthropicServingMessages(OpenAIServingChat):
data = error_response.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "error")
yield "data: [DONE]\n\n"
async def count_tokens(
self,
request: AnthropicCountTokensRequest,
raw_request: Request | None = None,
) -> AnthropicCountTokensResponse | ErrorResponse:
"""Implements Anthropic's messages.count_tokens endpoint."""
chat_req = self._convert_anthropic_to_openai_request(request)
result = await self.render_chat_request(chat_req)
if isinstance(result, ErrorResponse):
return result
_, engine_prompts = result
input_tokens = sum( # type: ignore
len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
for prompt in engine_prompts
if "prompt_token_ids" in prompt
)
response = AnthropicCountTokensResponse(
input_tokens=input_tokens,
context_management=AnthropicContextManagement(
original_input_tokens=input_tokens
),
)
return response
\ No newline at end of file
......@@ -1239,13 +1239,10 @@ class OpenAIServingChat(OpenAIServing):
index = 0
if (
# self._should_check_for_unstreamed_tool_arg_tokens(
# delta_message, output
tool_parser
and self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output, tool_parser
self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output
)
# and tool_parser
and tool_parser
):
latest_delta_len = 0
if (
......@@ -1259,31 +1256,15 @@ class OpenAIServingChat(OpenAIServing):
latest_delta_len = len(
delta_message.tool_calls[0].function.arguments
)
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON.
# Tool parsers (e.g. Qwen3Coder) store
# arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# expected_call = json.dumps(
# tool_parser.prev_tool_call_arr[index].get(
# "arguments", {}
# ),
# ensure_ascii=False,
# )
args = tool_parser.prev_tool_call_arr[index].get(
# parsing which "autocompletes" the JSON
expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
),
ensure_ascii=False,
)
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments
# for the current tool
......@@ -1867,7 +1848,6 @@ class OpenAIServingChat(OpenAIServing):
self,
delta_message: DeltaMessage | None,
output: CompletionOutput,
tool_parser: ToolParser | None = None,
) -> bool:
"""
Check to see if we should check for unstreamed tool arguments tokens.
......@@ -1886,8 +1866,6 @@ class OpenAIServingChat(OpenAIServing):
and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
and tool_parser is not None
and tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens()
)
@staticmethod
......
......@@ -47,14 +47,6 @@ class BatchDescriptor(NamedTuple):
"""
Whether this batch has active LoRA adapters.
"""
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
are captured for each num_active_loras value. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
......
......@@ -44,23 +44,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
"""
Returns num_active_loras values for cudagraph capture.
When specialize=True: powers of 2 up to max_loras, plus max_loras + 1.
When specialize=False: just [max_loras + 1].
This is the single source of truth for LoRA capture cases, used by both
CudagraphDispatcher and PunicaWrapperGPU.
"""
if not specialize:
return [max_loras + 1]
return [
n for n in range(1, max_loras + 2) if (n & (n - 1)) == 0 or n == max_loras + 1
]
_GLOBAL_LORA_ID = 0
......
......@@ -52,8 +52,7 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.lm_head = SharedHead(config=config, quant_config=quant_config)
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer(
vllm_config,
prefix=f"{prefix}.mtp_block",
......@@ -65,13 +64,9 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
embed_tokens: VocabParallelEmbedding | None = None,
spec_step_index: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
assert embed_tokens is not None
inputs_embeds = embed_tokens(input_ids)
# assert inputs_embeds is not None
assert inputs_embeds is not None
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
......@@ -97,10 +92,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict(
{
str(idx): Step3p5AMultiTokenPredictorLayer(
# vllm_config,
# f"{prefix}.layers.{idx}",
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
vllm_config,
f"{prefix}.layers.{idx}",
)
for idx in range(
self.mtp_start_layer_idx,
......@@ -119,15 +112,14 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
inputs_embeds,
self.embed_tokens,
current_step_idx,
)
......@@ -139,8 +131,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor(
# mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
)
return logits
......@@ -266,7 +257,6 @@ class Step3p5MTP(nn.Module):
name = name.replace(".transformer.", ".")
if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head")
name = name.replace("shared_head", "lm_head")
if "embed_tokens" in name:
assert (
hasattr(self.config, "num_nextn_predict_layers")
......
......@@ -118,11 +118,6 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
Whether to check for unstreamed tool-argument tokens in serving
"""
return True
class ToolParserManager:
"""
......
......@@ -2,14 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any
# from xml.parsers.expat import ParserCreate
from xml.parsers.expat import ParserCreate
import regex as re
# from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
......@@ -26,1142 +25,1487 @@ from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
logger = init_logger(__name__)
class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
class StreamingXMLToolCallParser:
"""
Simplified streaming XML tool call parser
Supports streaming input, parsing, and output
"""
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
# Override base class type - we use string IDs for tool calls
self.current_tool_id: str | None = None # type: ignore
self.streamed_args_for_tool: list[str] = []
def __init__(self):
self.reset_streaming_state()
# Sentinel tokens for streaming mode
# Tool configuration information
self.tools: list[ChatCompletionToolsParam] | None = None
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_prefix: str = "<function="
self.function_start_token: str = "<function="
self.function_end_token: str = "</function>"
self.parameter_prefix: str = "<parameter="
self.parameter_start_token: str = "<parameter="
self.parameter_end_token: str = "</parameter>"
self.is_tool_call_started: bool = False
self.failed_count: int = 0
# Enhanced streaming state - reset for each new message
self._reset_streaming_state()
# Regex patterns
self.tool_call_complete_regex = re.compile(
r"<tool_call>(.*?)</tool_call>", re.DOTALL
)
self.tool_call_function_regex = re.compile(
r"<function(?:=|\s+)?(.*?)</function>", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"Step3p5 RL Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
# Get EOS token ID for EOS detection
self.eos_token_id = getattr(self.model_tokenizer, "eos_token_id", None)
logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
def reset_streaming_state(self):
"""Reset streaming parsing state"""
self.deltas = []
# state for streaming
self.tool_call_index = 0
self.current_call_id = None
self.last_completed_call_id = None
self.current_function_name = None
self.current_function_open = False
self.parameters = {}
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.current_param_is_first = False
self.should_emit_end_newline = False
self.start_quote_emitted = False
self.streaming_buffer = ""
self.last_processed_pos = 0
self.text_content_buffer = ""
# state for preprocessing and deferred parsing
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
# recreate parser
self.parser = ParserCreate()
self.setup_parser()
def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
"""
Skip the remaining_call calculation in serving
Parse single streaming XML chunk and return Delta response
This is the actual streaming interface that receives chunks
one by one and maintains internal state
Args:
xml_chunk: Single XML chunk string
Returns:
DeltaMessage: Contains delta information generated by this chunk,
returns empty response if no complete elements
"""
return False
# Record delta count before processing
initial_delta_count = len(self.deltas)
def _reset_streaming_state(self):
"""Reset all streaming state for a new request."""
self._processed_length: int = 0 # Position of last processed character
self._tool_call_index: int = 0 # Number of tool calls processed so far
self.streaming_request = None # Current request being processed
def _get_arguments_config(
self, func_name: str, tools: list[ChatCompletionToolsParam] | None
) -> dict:
"""Extract argument configuration for a function."""
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (
hasattr(config, "function") and hasattr(config.function, "name")
):
continue
if config.type == "function" and config.function.name == func_name:
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.warning("Tool '%s' is not defined in the tools list.", func_name)
return {}
def _convert_param_value(
self, param_value: str, param_name: str, param_config: dict, func_name: str
) -> Any:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if param_value.lower() == "null":
return None
self.streaming_buffer += xml_chunk
if param_name not in param_config:
if param_config != {}:
logger.warning(
"Parsed parameter '%s' is not defined in the tool "
"parameters for tool '%s', directly returning the "
"string value.",
param_name,
func_name,
)
return param_value
found_elements = self._process_complete_xml_elements()
if found_elements:
# If complete elements found, check if end events were missed
# some tags may not have been triggered
try:
new_deltas = self.deltas[initial_delta_count:]
# If this chunk contains </function>
# but didn't generate '}', then complete it
if (
isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]
):
param_type = str(param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
self.current_call_id is not None
and self.function_end_token in xml_chunk
):
try:
return int(param_value)
except (ValueError, TypeError):
try:
float_value = float(param_value)
if float_value.is_integer():
return int(float_value)
except (ValueError, TypeError):
pass
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, bool):
return int(literal_value)
if isinstance(literal_value, (int, float)):
return (
int(literal_value)
if float(literal_value).is_integer()
else literal_value
# - Added '}' (non-empty parameter ending)
# - Added '{}' (empty parameter function)
has_function_close = any(
(
td.tool_calls
and any(
(
tc.function
and tc.id == self.current_call_id
and isinstance(tc.function.arguments, str)
and (tc.function.arguments in ("}", "{}"))
)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not an integer "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
for tc in td.tool_calls
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, (int, float)):
return (
float(literal_value)
if float(literal_value) - int(float(literal_value)) != 0
else int(float(literal_value))
)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
normalized_value = param_value.strip().lower()
if normalized_value in ["true", "false"]:
return normalized_value == "true"
if normalized_value in ["1", "0"]:
return normalized_value == "1"
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, bool):
return literal_value
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not a boolean "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
for td in new_deltas
)
return param_value
else:
if not has_function_close:
# Close potentially unclosed element
if self.current_param_name:
self._end_element("parameter")
if self.current_function_name:
self._end_element("function")
# If this chunk contains </tool_call>
# but didn't generate final empty delta, then complete it
if (
param_type in ["object", "array", "arr"]
or param_type.startswith("dict")
or param_type.startswith("list")
self.current_call_id is not None
and self.tool_call_end_token in xml_chunk
):
try:
param_value = json.loads(param_value)
return param_value
except (json.JSONDecodeError, TypeError, ValueError):
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, (list, dict)):
return literal_value
if isinstance(literal_value, (tuple, set)):
return list(literal_value)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be parsed "
"as JSON in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
has_toolcall_close = any(
(
td.tool_calls
and any(
(
tc.type == "function"
and tc.function
and tc.function.arguments == ""
and tc.id == self.current_call_id
)
return param_value
try:
literal_value = ast.literal_eval(param_value) # safer
if isinstance(literal_value, (tuple, set)):
return list(literal_value)
if (
isinstance(literal_value, (list, dict, str, int, float, bool))
or literal_value is None
):
return literal_value
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
for tc in td.tool_calls
)
return param_value
def _parse_parameters_fallback(
self,
parameters: str,
allowed_param_names: set[str] | None = None,
) -> list[tuple[str, str]]:
"""Fallback parser for malformed parameter tags."""
param_pairs: list[tuple[str, str]] = []
pos = 0
while True:
start = parameters.find(self.parameter_prefix, pos)
if start == -1:
break
name_start = start + len(self.parameter_prefix)
name_end = parameters.find(">", name_start)
if name_end == -1:
newline_idx = parameters.find("\n", name_start)
end_tag = parameters.find(self.parameter_end_token, name_start)
next_param = parameters.find(self.parameter_prefix, name_start)
candidates = [
idx for idx in [newline_idx, end_tag, next_param] if idx != -1
]
if not candidates:
break
name_end = min(candidates)
value_start = name_end
else:
value_start = name_end + 1
param_name = parameters[name_start:name_end].strip()
next_param = parameters.find(self.parameter_prefix, value_start)
end_tag = parameters.find(self.parameter_end_token, value_start)
if end_tag == -1 or (next_param != -1 and next_param < end_tag):
end = next_param if next_param != -1 else len(parameters)
pos = end
else:
end = end_tag
pos = end + len(self.parameter_end_token)
param_value = parameters[value_start:end]
if allowed_param_names is None or param_name in allowed_param_names:
param_pairs.append((param_name, param_value))
return param_pairs
def _is_valid_json_arguments(self, arguments: str) -> bool:
"""Check if arguments can be loaded as JSON."""
try:
json.loads(arguments)
except Exception:
return False
return True
def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None:
# Extract function name
end_index = function_call_str.index(">")
# check empty function name
function_name = function_call_str[:end_index].strip()
if function_name.startswith("="):
function_name = function_name.lstrip("=").strip()
if not function_name or function_name.strip("'\"") == "":
logger.warning("Empty function name in tool call.")
return None
if function_name[0] in "\"'" and function_name[-1] == function_name[0]:
function_name = function_name[1:-1].strip()
if not function_name:
logger.warning("Empty function name in tool call.")
return None
param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1 :]
param_dict = {}
match_texts = self.tool_call_parameter_regex.findall(parameters)
use_fallback = False
if match_texts:
for match_text in match_texts:
if self.parameter_prefix in match_text or ">" not in match_text:
use_fallback = True
break
else:
use_fallback = self.parameter_prefix in parameters
if use_fallback:
allowed_param_names = (
set(param_config.keys())
if isinstance(param_config, dict) and param_config
else None
)
param_pairs = self._parse_parameters_fallback(
parameters, allowed_param_names
for td in new_deltas
)
else:
param_pairs = []
for match_text in match_texts:
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1 :])
param_pairs.append((param_name, param_value))
for param_name, param_value in param_pairs:
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = self._convert_param_value(
param_value, param_name, param_config, function_name
)
try:
arguments = json.dumps(param_dict, ensure_ascii=False)
if not has_toolcall_close:
# Close potentially unclosed element
if self.current_param_name:
self._end_element("parameter")
if self.current_function_name:
self._end_element("function")
self._end_element("tool_call")
except Exception as e:
logger.warning("Error in converting parameter value: %s", e)
return None
return ToolCall(
type="function",
function=FunctionCall(name=function_name, arguments=arguments),
logger.warning("Error with fallback parsing: %s", e)
# Merge newly generated deltas into single response
result_delta = self._merge_new_deltas_to_single_response(
initial_delta_count
)
return result_delta
else:
# No complete elements, check if there's unoutput text content
if self.text_content_buffer and self.tool_call_index == 0:
# Has text content but no tool_call yet, output text content
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer to avoid duplicate output
self.text_content_buffer = ""
return text_delta
# If this chunk contains end tags but wasn't triggered by parser,
# manually complete end events
# Only execute when still on the same call as when entered,
# to prevent accidentally closing new calls
# in multi <tool_call> scenarios
if self.current_call_id is not None and (
self.function_end_token in xml_chunk
or self.tool_call_end_token in xml_chunk
):
# Close potentially unclosed element
if self.current_param_name:
self._end_element("parameter")
if self.function_end_token in xml_chunk and self.current_function_name:
self._end_element("function")
if self.tool_call_end_token in xml_chunk:
self._end_element("tool_call")
# Return the merged delta result generated by this fallback
result_delta = self._merge_new_deltas_to_single_response(
initial_delta_count
)
return result_delta
def _get_function_calls(self, model_output: str) -> list[str]:
# Find all tool calls
raw_tool_calls = self.tool_call_complete_regex.findall(model_output)
# if no closed tool_call tags found, return empty list
if len(raw_tool_calls) == 0:
return []
raw_function_calls = []
for tool_call in raw_tool_calls:
function_matches = self.tool_call_function_regex.findall(tool_call)
raw_function_calls.extend(function_matches)
# No complete elements, return empty response
return DeltaMessage(content=None)
return raw_function_calls
def _escape_xml_special_chars(self, text: str) -> str:
"""
Escape XML special characters
Args:
text: Original text
Returns:
Escaped text
"""
xml_escapes = {
"&": "&amp;",
"<": "&lt;",
">": "&gt;",
'"': "&quot;",
"'": "&apos;",
}
def _check_format(self, model_output: str) -> bool:
"""Check if model output contains properly formatted tool call.
for char, escape in xml_escapes.items():
text = text.replace(char, escape)
Requirements:
1. Must have closed tool_call tags (<tool_call>...</tool_call>)
2. Must have closed function tags (<function=...</function>)
3. If parameter tags exist, they must be closed and correct
return text
Returns True if the format is valid, False otherwise.
def _process_complete_xml_elements(self) -> bool:
"""
# Check 1: Must have closed tool_call tags
tool_call_matches = self.tool_call_complete_regex.findall(model_output)
if len(tool_call_matches) == 0:
return False
Process complete XML elements in buffer
# Check 2: Must have closed function tags within tool_call
has_valid_function = False
for tool_call_content in tool_call_matches:
function_matches = self.tool_call_function_regex.findall(tool_call_content)
if len(function_matches) > 0:
has_valid_function = True
# Check if there's an unclosed function tag
if (
self.tool_call_prefix in tool_call_content
and self.function_end_token not in tool_call_content
):
return False
Returns:
bool: Whether complete elements were found and processed
"""
found_any = False
if not has_valid_function:
return False
while self.last_processed_pos < len(self.streaming_buffer):
# Find next complete xml element
element, end_pos = self._find_next_complete_element(self.last_processed_pos)
if element is None:
# No complete element found, wait for more data
break
# Check 3: If parameter tags exist, they must be closed and correct
for tool_call_content in tool_call_matches:
# Count opening and closing parameter tags
param_open_count = tool_call_content.count(self.parameter_prefix)
param_close_count = tool_call_content.count(self.parameter_end_token)
# Check if this element should be skipped
if self._should_skip_element(element):
self.last_processed_pos = end_pos
continue
# If there are parameter tags, they must be balanced
if param_open_count > 0:
if param_open_count != param_close_count:
return False
# Check if all parameter tags are properly closed using regex
param_matches = self.tool_call_parameter_regex.findall(
tool_call_content
# Found complete XML element, process it
try:
preprocessed_element = self._preprocess_xml_chunk(element)
# Check if this is the first tool_call start
if (
(
preprocessed_element.strip().startswith("<tool_call>")
or preprocessed_element.strip().startswith("<function name=")
)
if len(param_matches) != param_open_count:
return False
return True
def _wrap_missing_tool_call_tags(self, model_output: str) -> str:
"""Wrap bare <function=...></function> blocks with <tool_call> tags."""
and self.tool_call_index == 0
) and self.text_content_buffer:
# First tool_call starts,
# output previously collected text content first
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer for potential subsequent text content
self.text_content_buffer = ""
# If a new tool_call starts and
# there are already completed tool_calls with function name
if (
self.tool_call_prefix not in model_output
or self.function_end_token not in model_output
preprocessed_element.strip().startswith("<tool_call>")
and self.tool_call_index > 0
and self.current_call_id
and self.current_function_name
):
return model_output
def _wrap_bare_functions(text: str) -> str:
pos = 0
wrapped_parts: list[str] = []
while True:
func_idx = text.find(self.tool_call_prefix, pos)
if func_idx == -1:
wrapped_parts.append(text[pos:])
break
end_idx = text.find(self.function_end_token, func_idx)
if end_idx == -1:
wrapped_parts.append(text[pos:])
break
end_idx += len(self.function_end_token)
wrapped_parts.append(text[pos:func_idx])
wrapped_parts.append(self.tool_call_start_token)
wrapped_parts.append(text[func_idx:end_idx])
wrapped_parts.append(self.tool_call_end_token)
ws_idx = end_idx
while ws_idx < len(text) and text[ws_idx].isspace():
ws_idx += 1
if text.startswith(self.tool_call_end_token, ws_idx):
if ws_idx > end_idx:
wrapped_parts.append(text[end_idx:ws_idx])
pos = ws_idx + len(self.tool_call_end_token)
else:
pos = end_idx
return "".join(wrapped_parts)
tool_call_ranges = [
match.span()
for match in self.tool_call_complete_regex.finditer(model_output)
]
if not tool_call_ranges:
return _wrap_bare_functions(model_output)
wrapped_parts: list[str] = []
pos = 0
for start, end in tool_call_ranges:
if start < pos:
continue
wrapped_parts.append(_wrap_bare_functions(model_output[pos:start]))
wrapped_parts.append(model_output[start:end])
pos = end
wrapped_parts.append(_wrap_bare_functions(model_output[pos:]))
return "".join(wrapped_parts)
def _normalize_prev_arguments(self, args_value: Any) -> Any:
if isinstance(args_value, str):
try:
return json.loads(args_value)
except (TypeError, ValueError, json.JSONDecodeError):
return args_value
return args_value
def _update_prev_tool_call_state(self, tool_calls: list[ToolCall]) -> None:
self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear()
for tool_call in tool_calls:
if not tool_call or not tool_call.function:
continue
args_value = tool_call.function.arguments
if isinstance(args_value, str):
args_json = args_value
elif args_value is None:
args_json = ""
else:
try:
args_json = json.dumps(args_value, ensure_ascii=False)
except (TypeError, ValueError):
args_json = str(args_value)
prev_args = self._normalize_prev_arguments(args_json)
self.prev_tool_call_arr.append(
{
"name": tool_call.function.name,
"arguments": prev_args,
}
# Reset parser state but preserve generated deltas
if self.current_param_name:
self._end_element("parameter")
if self.current_function_open:
self._end_element("function")
# Output final tool_call tail delta
final_delta = DeltaMessage(
role=None,
content=None,
reasoning_content=None,
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
try:
expected_args_json = json.dumps(prev_args, ensure_ascii=False)
except (TypeError, ValueError):
expected_args_json = args_json
],
)
self._emit_delta(final_delta)
# Reset XML parser and current call state
self._reset_xml_parser_after_tool_call()
# Parse preprocessed element
self.parser.Parse(preprocessed_element, False)
found_any = True
# Serving may subtract the latest delta length from
# streamed_args_for_tool to detect unstreamed suffixes. Since this
# parser emits full arguments at once, store expected+actual so
# the subtraction yields expected_args_json and no resend occurs.
self.streamed_args_for_tool.append(expected_args_json + args_json)
except Exception as e:
logger.warning("Error when parsing XML elements: %s", e)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
try:
origin_model_output = model_output
try:
# Fallback: handle outputs without <tool_call> wrapper.
origin_model_output = self._wrap_missing_tool_call_tags(
origin_model_output
)
model_output = origin_model_output
except Exception:
pass
# Update processed position
self.last_processed_pos = end_pos
# Use streaming-like approach: process position by position
valid_tool_calls = []
content_parts = []
processed_length = 0
return found_any
while processed_length < len(model_output):
# Find next tool call start
tool_start_idx = self._find_tool_call_start(
model_output, processed_length
)
def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str:
"""
Fallback: fix incomplete <parameter=xxx or <function=xxx tags
(missing >)
Examples: <parameter=-C: -> <parameter=-C>, <parameter=parameter=-n:
-> <parameter=-n>
Also handles missing = cases: <function xxx> -> <function=xxx>,
<functionxxx> -> <function=xxx>
Only fixes tags that pass validation (parameter exists in tool definition)
"""
# First, handle missing = cases for function tags
chunk = self._fix_missing_equals_in_function_tag(chunk)
# Case 1: No more tool calls - add remaining as content
if tool_start_idx == -1:
remaining = model_output[processed_length:]
if remaining:
content_parts.append(remaining)
break
for tag_type in ["parameter", "function"]:
pattern = f"<{tag_type}="
if pattern not in chunk:
continue
start_idx = chunk.find(pattern)
after_tag = chunk[start_idx:]
gt_pos = after_tag.find(">")
lt_pos = after_tag.find("<", len(pattern))
# Case 2: Content before tool call
if tool_start_idx > processed_length:
content_before = model_output[processed_length:tool_start_idx]
# Skip whitespace-only content between tool calls
# Check if we just ended a tool call and this is pure whitespace
if processed_length > 0:
text_before = model_output[:processed_length]
# Skip if already well-formed
if (
text_before.rstrip().endswith(self.tool_call_end_token)
and content_before.strip() == ""
gt_pos != -1
and (lt_pos == -1 or gt_pos < lt_pos)
and pattern in after_tag[:gt_pos]
):
# Skip whitespace between tool calls
pass
else:
content_parts.append(content_before)
else:
content_parts.append(content_before)
continue
# Case 3: Try to find complete tool call
tool_end_idx = self._find_first_complete_tool_call_end(
model_output, tool_start_idx
# Extract tag name (stop at space, newline, or <)
content = chunk[start_idx + len(pattern) :]
end_pos = next(
(i for i, ch in enumerate(content) if ch in (" ", "\n", "<")),
len(content),
)
tag_name = content[:end_pos]
# If tool call is incomplete - add remaining as content and stop
if tool_end_idx == -1:
remaining = model_output[tool_start_idx:]
if remaining:
content_parts.append(remaining)
break
if not tag_name:
continue
# Extract and try to parse the complete tool call
tool_call_text = model_output[tool_start_idx:tool_end_idx]
parsed_result = self.extract_tool_calls_basic(tool_call_text, request)
# Remove duplicate prefix: <parameter=parameter=xxx -> <parameter=xxx
if tag_name.startswith(f"{tag_type}="):
tag_name = tag_name[len(tag_type) + 1 :]
# If parsing succeeded, record the tool call(s)
if parsed_result.tools_called and parsed_result.tool_calls:
valid_tool_calls.extend(parsed_result.tool_calls)
processed_length = tool_end_idx
else:
# Parsing failed - treat this tool call as content
content_parts.append(tool_call_text)
processed_length = tool_end_idx
# Populate prev_tool_call_arr for serving layer to set finish_reason
self._update_prev_tool_call_state(valid_tool_calls)
# Remove trailing non-alphanumeric chars (keep - and _)
while tag_name and not (
tag_name[-1].isalnum() or tag_name[-1] in ("-", "_")
):
tag_name = tag_name[:-1]
# Combine content parts
content = "".join(content_parts) if content_parts else None
if not tag_name:
continue
return ExtractedToolCallInformation(
tools_called=(len(valid_tool_calls) > 0),
tool_calls=valid_tool_calls,
content=content if content else None,
)
# Validate parameter exists in tool definition
if tag_type == "parameter" and not self._validate_parameter_name(tag_name):
continue
except Exception:
logger.warning("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
# Apply fix
chunk = chunk.replace(
f"<{tag_type}={content[:end_pos]}", f"<{tag_type}={tag_name}>", 1
)
def extract_tool_calls_basic(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
model_output = self._wrap_missing_tool_call_tags(model_output)
# Quick check to avoid unnecessary processing
if not self._check_format(model_output):
tool_call_matches = self.tool_call_complete_regex.findall(model_output)
if len(tool_call_matches) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
return chunk
try:
function_calls = self._get_function_calls(model_output)
if len(function_calls) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _fix_missing_equals_in_function_tag(self, chunk: str) -> str:
"""
Fix missing = in function tags: <function xxx> or <functionxxx>
Examples:
<function execute_bash> -> <function=execute_bash>
<functionexecute_bash> -> <function=execute_bash>
Only fixes if function name exists in tool definition
"""
# already correct
if "<function=" in chunk:
return chunk
# Pattern 1: <function xxx> (with space/newline but no =)
pattern1 = r"<function\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match1 = re.search(pattern1, chunk)
if match1:
func_name = match1.group(1).strip()
# must validate function name exists before fixing
if func_name and self._validate_function_name(func_name):
original = match1.group(0)
fixed = f"<function={func_name}>"
chunk = chunk.replace(original, fixed, 1)
return chunk
# Pattern 2: <functionxxx> (no space, no =)
# only match <function followed by letters
pattern2 = r"<function([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match2 = re.search(pattern2, chunk)
if match2:
func_name = match2.group(1).strip()
# must validate function name exists before fixing
if func_name and self._validate_function_name(func_name):
original = match2.group(0)
fixed = f"<function={func_name}>"
chunk = chunk.replace(original, fixed, 1)
return chunk
return chunk
def _validate_function_name(self, func_name: str) -> bool:
"""Check if function name exists in tool definitions"""
if not self.tools:
return False
tool_calls: list[ToolCall] = []
for function_call_str in function_calls:
tool_call = self._parse_xml_function_call(
function_call_str, request.tools
)
if tool_call:
tool_calls.append(tool_call)
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
for tool_call in tool_calls:
for tool in self.tools:
if (
not tool_call.function
or tool_call.function.arguments is None
or not self._is_valid_json_arguments(tool_call.function.arguments)
hasattr(tool, "type")
and tool.type == "function"
and hasattr(tool, "function")
and hasattr(tool.function, "name")
and tool.function.name == func_name
):
logger.warning(
"Invalid JSON arguments in tool call, falling back to content."
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
return True
# Populate prev_tool_call_arr for serving layer to set finish_reason
self._update_prev_tool_call_state(tool_calls)
return False
# Extract content before tool calls
content_index = model_output.find(self.tool_call_start_token)
content = model_output[:content_index] # .rstrip()
def _validate_parameter_name(self, param_name: str) -> bool:
"""Check if parameter exists in current function's tool definition"""
if not self.tools or not self.current_function_name:
return True
return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0),
tool_calls=tool_calls,
content=content if content else None,
)
for tool in self.tools:
if (
hasattr(tool, "type")
and tool.type == "function"
and hasattr(tool, "function")
and hasattr(tool.function, "name")
and tool.function.name == self.current_function_name
):
if not hasattr(tool.function, "parameters"):
return True
params = tool.function.parameters
if isinstance(params, dict):
properties = params.get("properties", params)
return param_name in properties
break
except Exception:
logger.warning("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
return True
def _find_first_complete_tool_call_end(self, text: str, start_pos: int = 0) -> int:
"""Find the end position of the first complete tool call.
def _should_skip_element(self, element: str) -> bool:
"""
Determine whether an element should be skipped
Args:
text: Text to search in
start_pos: Position to start searching from
element: Element to evaluate
Returns:
Position after the first </tool_call> tag, or -1 if incomplete
Example:
"<tool_call>...</tool_call>..." returns position after </tool_call>
bool: True means should skip, False means should process
"""
# Find tool call start
start_idx = text.find(self.tool_call_start_token, start_pos)
if start_idx == -1:
return -1
# Find matching end token
end_idx = text.find(
self.tool_call_end_token, start_idx + len(self.tool_call_start_token)
)
if end_idx == -1:
return -1 # Incomplete tool call
# If it's a tool_call XML tag, don't skip
if (
element.startswith(self.tool_call_start_token)
or element.startswith(self.function_start_token)
or element.startswith(self.parameter_start_token)
):
return False
# Return position after end token
return end_idx + len(self.tool_call_end_token)
# If currently not parsing tool calls and not blank,
# collect this text instead of skipping
# Only process other XML elements after tool_call appears,
# otherwise treat as plain text
if self.current_call_id is None and element:
# Collect text content to buffer
self.text_content_buffer += element
return True # Still skip, but content has been collected
# If currently parsing tool calls,
# this might be parameter value, don't skip
if self.current_call_id is not None:
return False
# Skip blank content
return not element
def _find_tool_call_start(self, text: str, start_pos: int = 0) -> int:
"""Find the start position of next tool call.
def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]:
"""
Find next complete XML element from specified position
Args:
text: Text to search in
start_pos: Position to start searching from
start_pos: Position to start searching
Returns:
Position of <tool_call> token, or -1 if not found
(Complete element string, element end position),
returns (None, start_pos) if no complete element found
"""
return text.find(self.tool_call_start_token, start_pos)
buffer = self.streaming_buffer[start_pos:]
def _extract_content_between_tool_calls_list(self, text: str) -> list[str]:
"""Extract content segments after each tool call.
if not buffer:
return None, start_pos
For n tool calls, returns n segments where segment[i] is the content
after tool_call[i] (before tool_call[i+1] or at the end).
if buffer.startswith("<"):
# Check if this is an incomplete parameter/function tag
# e.g., <parameter=-C: or <function=xxx
is_incomplete_param = (
buffer.startswith("<parameter=") and ">" not in buffer.split("\n")[0]
)
is_incomplete_func = (
buffer.startswith("<function=") and ">" not in buffer.split("\n")[0]
)
Empty or whitespace-only segments are represented as empty string "".
if is_incomplete_param or is_incomplete_func:
# Find the corresponding closing tag
tag_type = "parameter" if is_incomplete_param else "function"
closing_tag = f"</{tag_type}>"
closing_pos = buffer.find(closing_tag)
if closing_pos != -1:
# Found closing tag, return complete element including closing tag
complete_element = buffer[: closing_pos + len(closing_tag)]
return complete_element, start_pos + closing_pos + len(closing_tag)
# Need to ensure no new < appears,
# find the nearest one between < and >
tag_end = buffer.find("<", 1)
tag_end2 = buffer.find(">", 1)
if tag_end != -1 and tag_end2 != -1:
# Next nearest is <
if tag_end < tag_end2:
return buffer[:tag_end], start_pos + tag_end
# Next nearest is >, means found XML element
else:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
elif tag_end != -1:
return buffer[:tag_end], start_pos + tag_end
elif tag_end2 != -1:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
else:
# If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call> or <function=
if self.current_call_id is None:
# Check if might be start of <tool_call>
if buffer == "<tool_call>"[: len(buffer)]:
# Might be start of <tool_call>, wait for more data
return None, start_pos
elif (
buffer.startswith("<function=")
or buffer == "<function="[: len(buffer)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return None, start_pos
else:
# Not start of <tool_call> or <function=, treat as text
return buffer, start_pos + len(buffer)
else:
# When parsing tool calls,
# wait for more data to get complete tag
return None, start_pos
else:
# Find text content (until next < or buffer end)
next_tag_pos = buffer.find("<")
if next_tag_pos != -1:
# Found text content
text_content = buffer[:next_tag_pos]
return text_content, start_pos + next_tag_pos
else:
# Buffer end is all text, process
# (no longer wait for more data)
remaining = buffer
return remaining, start_pos + len(remaining)
def _merge_new_deltas_to_single_response(self, initial_count: int) -> DeltaMessage:
"""
Merge newly generated deltas from this processing
into a single DeltaMessage
Args:
text: Text containing tool calls
initial_count: Delta count before processing
Returns:
List of content segments (one per tool call)
Merged DeltaMessage containing all newly generated delta information
"""
content_segments = []
pos = 0
if len(self.deltas) <= initial_count:
return DeltaMessage(content=None)
while True:
# Find end of current tool call
end_pos = text.find(self.tool_call_end_token, pos)
if end_pos == -1:
break
# Get newly generated deltas
new_deltas = self.deltas[initial_count:]
# Move past the end token
end_pos += len(self.tool_call_end_token)
if len(new_deltas) == 1:
# Only one new delta, return directly
return new_deltas[0]
# Find start of next tool call
next_start = self._find_tool_call_start(text, end_pos)
# Merge multiple new deltas
merged_tool_calls: list[DeltaToolCall] = []
merged_content: str = ""
# Extract content between current end and next start (or text end)
content = text[end_pos:next_start] if next_start != -1 else text[end_pos:]
# Store content (empty string if whitespace-only)
content_segments.append(content if content.strip() else "")
if next_start == -1:
for delta in new_deltas:
if delta.content:
merged_content += delta.content
if delta.tool_calls:
# For tool_calls, we need to intelligently merge arguments
for tool_call in delta.tool_calls:
# Find if there's already a tool_call with the same call_id
existing_call = None
for existing in merged_tool_calls:
if existing.id == tool_call.id:
existing_call = existing
break
pos = next_start
return content_segments
if existing_call and existing_call.function:
# Merge to existing tool_call
if tool_call.function and tool_call.function.name:
existing_call.function.name = tool_call.function.name
if (
tool_call.function
and tool_call.function.arguments is not None
):
if existing_call.function.arguments is None:
existing_call.function.arguments = ""
# For streaming JSON parameters,
# simply concatenate in order
new_args = tool_call.function.arguments
existing_call.function.arguments += new_args
if tool_call.type:
existing_call.type = tool_call.type
else:
# Add new tool_call
merged_tool_calls.append(tool_call)
def _convert_tool_calls_to_deltas(
self, tool_calls: list[ToolCall], starting_index: int = 0
) -> list[DeltaToolCall]:
"""Convert complete ToolCall list to DeltaToolCall list.
return DeltaMessage(
content=merged_content if merged_content else None,
tool_calls=merged_tool_calls,
)
Returns complete tool calls without splitting into fragments.
def _preprocess_xml_chunk(self, chunk: str) -> str:
"""
Preprocess XML chunk, handle non-standard formats,
and escape special characters
Args:
tool_calls: List of tool calls to convert
starting_index: Starting index for tool calls (default 0)
chunk: Original XML chunk
Returns:
List of DeltaToolCall with complete arguments
Processed XML chunk
"""
delta_tool_calls = []
for i, tool_call in enumerate[ToolCall](tool_calls):
index = starting_index + i
tool_id = self._generate_tool_call_id()
# Create complete DeltaToolCall with full arguments
delta_tool_calls.append(
# Check if this is a tool_call related element
is_tool_call = False
if chunk.startswith(self.tool_call_start_token) or chunk.startswith(
self.tool_call_end_token
):
is_tool_call = True
# Check for function tags (including malformed ones without =)
# <function=xxx>, </function>, <function xxx>, <functionxxx>
if (
chunk.startswith(self.function_start_token)
or chunk.startswith(self.function_end_token)
or chunk.startswith("<function ")
or re.match(r"^<function[a-zA-Z_]", chunk)
): # <functionXXX without space or =
is_tool_call = True
if chunk.startswith(self.parameter_start_token) or chunk.startswith(
self.parameter_end_token
):
is_tool_call = True
# Fallback: fix incomplete <parameter= or <function= tags without
# closing >
# This handles cases like: <parameter=-C:\n or <parameter=-B 5\n
# Apply when parsing tool calls OR when chunk looks like a function/
# parameter tag
if (
self.current_call_id is not None
or chunk.startswith("<function")
or chunk.startswith("<parameter")
):
chunk = self._fix_incomplete_tag_in_chunk(chunk)
# Handle <function=name> format -> <function name="name">
processed = re.sub(r"<function=([^>]+)>", r'<function name="\1">', chunk)
# Handle <parameter=name> format -> <parameter name="name">
processed = re.sub(r"<parameter=([^>]+)>", r'<parameter name="\1">', processed)
original_chunk = chunk
# If in parameter value accumulation mode
if self._pre_inside_parameter:
# Parameter end: output accumulated raw text
# safely then return </parameter>
if processed.startswith("</parameter>"):
body_text = self._pre_param_buffer
# Trigger deferred parsing mode
# literal_eval+json output in end_element
self.defer_current_parameter = True
self.deferred_param_raw_value = body_text
# Clean up state
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
safe_text = self._escape_xml_special_chars(body_text)
return f"{safe_text}</parameter>"
else:
# If this is the first block of content after entering parameter
# evaluate if deferred parsing is needed;
# If not needed, exit accumulation mode
# and pass through directly
if self._pre_param_buffer == "":
# Get current parameter type
param_type = (
self._get_param_type(self._pre_current_param_name)
if self._pre_current_param_name
else "string"
)
# Only these types need deferred parsing to
# handle Python literals containing single quotes
is_object_type = param_type in ["object"]
is_complex_type = (
param_type in ["array", "arr", "sequence"]
or param_type.startswith("dict")
or param_type.startswith("list")
)
# Only delay when contains container symbols
# and has single quotes and is complex type
has_container_hint = (
("[" in original_chunk)
or ("{" in original_chunk)
or ("(" in original_chunk)
)
# Determine if deferred parsing is needed
need_defer = False
if is_complex_type:
# Complex type, always need deferred parsing
need_defer = True
elif (
is_object_type
and has_container_hint
and ("'" in original_chunk)
):
# Object type with container symbols
# and single quotes, need deferred parsing
need_defer = True
if not need_defer:
# No need for deferred parsing,
# exit parameter mode directly
self._pre_inside_parameter = False
return self._escape_xml_special_chars(original_chunk)
self._pre_param_buffer += original_chunk
return ""
# Parameter start: enable accumulation
if processed.startswith("<parameter name="):
m = re.match(r'<parameter name="([^"]+)">', processed)
if m:
self._pre_current_param_name = m.group(1)
self._pre_inside_parameter = True
self._pre_param_buffer = ""
return processed
# If processed doesn't contain special_token, escape processed
# This is because XML parsing encounters special characters
# and reports errors, so escaping is needed
if not is_tool_call:
processed = self._escape_xml_special_chars(processed)
return processed
def _emit_delta(self, delta: DeltaMessage):
"""Emit Delta response (streaming output)"""
self.deltas.append(delta)
def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None):
"""Before starting to process new elements,
if there are unclosed tags from before,
automatically complete their endings to the parser.
- If there are unclosed parameters,
it's equivalent to feeding `</parameter>`
- When about to start a new function or tool_call,
if there are unclosed functions, complete `</function>`.
- When about to start a new tool_call,
if there are unclosed tool_calls, complete `</tool_call>`.
"""
# First close unclosed parameters
if self.current_param_name:
self._end_element("parameter")
# If about to start new function or tool_call,
# and there are unclosed functions, close function first
if incoming_tag in ("function", "tool_call") and self.current_function_name:
self._end_element("function")
# If about to start new tool_call,
# and there are unclosed tool_calls, close tool_call first
if incoming_tag == "tool_call" and self.current_call_id:
self._end_element("tool_call")
def _start_element(self, name: str, attrs: dict[str, str]):
"""Handle XML start element events"""
if name == "root":
return
if name == "tool_call":
# Before opening new tool_call,
# automatically complete previous unclosed tags
self._auto_close_open_parameter_if_needed("tool_call")
self.parameters = {}
self.current_call_id = make_tool_call_id()
self.current_param_is_first = True
self.tool_call_index += 1
elif name.startswith("function") or (name == "function"):
# If missing tool_call, manually complete
if not self.current_call_id:
self._start_element("tool_call", {})
# Before opening new function,
# automatically complete previous unclosed tags (parameter/function)
self._auto_close_open_parameter_if_needed("function")
function_name = self._extract_function_name(name, attrs)
self.current_function_name = function_name
self.current_function_open = True
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=index,
id=tool_id,
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
name=function_name, arguments=""
),
)
]
)
self._emit_delta(delta)
elif name.startswith("parameter") or (name == "parameter"):
# If previous parameter hasn't ended normally,
# complete its end first, then start new parameter
self._auto_close_open_parameter_if_needed("parameter")
param_name = self._extract_parameter_name(name, attrs)
self.current_param_name = param_name
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False # Reset start quote flag
# Only output parameter name and colon,
# don't output quotes
# decide after parameter value type is determined
if param_name:
if not self.parameters:
# First parameter
# start JSON, only output parameter name and colon
json_start = f'{{"{param_name}": '
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_start
),
)
]
)
self._emit_delta(delta)
self.current_param_is_first = True
else:
# Subsequent parameters
# add comma and parameter name, no quotes
json_continue = f', "{param_name}": '
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_continue
),
)
]
)
self._emit_delta(delta)
self.current_param_is_first = False
def _char_data(self, data: str):
"""Handle XML character data events"""
if data and self.current_param_name:
# If preprocessing stage determines deferred parsing is needed,
# only cache character data, no streaming output
if self.defer_current_parameter:
original_data = data
if self.should_emit_end_newline:
original_data = "\n" + original_data
self.should_emit_end_newline = False
if original_data.endswith("\n"):
self.should_emit_end_newline = True
original_data = original_data[:-1]
self.current_param_value += original_data
return
param_type = self._get_param_type(self.current_param_name)
# Check if this is the first time receiving data for this parameter
# If this is the first packet of data and starts with \n, remove \n
if not self.current_param_value and data.startswith("\n"):
data = data[1:]
# Output start quote for string type (if not already output)
if (
param_type in ["string", "str", "text", "varchar", "char", "enum"]
and not self.start_quote_emitted
):
quote_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
]
)
self._emit_delta(quote_delta)
self.start_quote_emitted = True
if not data:
return
original_data = data
# Delay output of trailing newline
if self.should_emit_end_newline:
original_data = "\n" + original_data
self.should_emit_end_newline = False
if original_data.endswith("\n"):
self.should_emit_end_newline = True
original_data = original_data[:-1]
self.current_param_value += original_data
# convert parameter value by param_type
converted_value = self._convert_param_value(
self.current_param_value, param_type
)
output_data = self._convert_for_json_streaming(converted_value, param_type)
return delta_tool_calls
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming text using complete parsing.
Strategy:
1. Accumulate text in buffer and track processed position
2. In each iteration, try to extract content or complete tool calls
3. Parse complete tool calls using non-streaming method
4. Convert parsed results to delta sequence
5. Handle EOS token to flush incomplete tool calls as content
"""
# Initialize state for new request
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
delta_data = output_data[len(self.current_param_value_converted) :]
self.current_param_value_converted = output_data
# Check for EOS token
has_eos = (
self.eos_token_id is not None
and delta_token_ids
and self.eos_token_id in delta_token_ids
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=delta_data),
)
# If no delta text, check if we need to return empty delta for finish_reason
if not delta_text and not has_eos:
# Check if this is an EOS token after all tool calls are complete
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
]
)
self._emit_delta(delta)
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
# Return empty delta for finish_reason processing
return DeltaMessage(content="")
return None
# Process all available content
accumulated_deltas: list[DeltaMessage] = []
def _end_element(self, name: str):
"""Handle XML end element events"""
while self._has_unprocessed_content(current_text):
# Try to process next chunk (content or tool call)
delta = self._process_next_chunk(current_text)
if name == "root":
return
if delta is None:
# Cannot proceed further, need more tokens
break
# If function or tool_call ends and there are still unclosed parameters,
# complete parameter end first
if (
name.startswith("function") or name == "function" or name == "tool_call"
) and self.current_param_name:
self._auto_close_open_parameter_if_needed()
# Accumulate deltas
if isinstance(delta, list):
accumulated_deltas.extend(delta)
if (
name.startswith("parameter") or name == "parameter"
) and self.current_param_name:
# End current parameter
param_name = self.current_param_name
param_value = self.current_param_value
# If in deferred parsing mode,
# perform overall parsing on raw content
# accumulated in preprocessing stage and output once
if self.defer_current_parameter:
raw_text = (
self.deferred_param_raw_value
if self.deferred_param_raw_value
else param_value
)
parsed_value = None
output_arguments = None
try:
# If previously delayed trailing newline,
# add it back before parsing
if self.should_emit_end_newline:
raw_for_parse = raw_text + "\n"
else:
accumulated_deltas.append(delta)
# Handle EOS: flush any remaining incomplete tool calls as content
if has_eos:
remaining_delta = self._flush_remaining_content(current_text)
if remaining_delta:
accumulated_deltas.append(remaining_delta)
# If no remaining content but we have tool calls, return empty delta
elif len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
accumulated_deltas.append(DeltaMessage(content=""))
raw_for_parse = raw_text
parsed_value = ast.literal_eval(raw_for_parse)
output_arguments = json.dumps(parsed_value, ensure_ascii=False)
except Exception:
# Fallback: output as string as-is
output_arguments = json.dumps(raw_text, ensure_ascii=False)
parsed_value = raw_text
# Return results
return self._format_delta_result(accumulated_deltas)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=output_arguments
),
)
]
)
self._emit_delta(delta)
def _has_unprocessed_content(self, current_text: str) -> bool:
"""Check if there's unprocessed content in the buffer."""
return self._processed_length < len(current_text)
# Clean up and store
self.should_emit_end_newline = False
self.parameters[param_name] = parsed_value
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
return
def _process_next_chunk(
self, current_text: str
) -> DeltaMessage | list[DeltaMessage] | None:
"""Process next chunk: either regular content or a complete tool call.
param_type = self._get_param_type(param_name)
Args:
current_text: Current accumulated text
# convert complete parameter value by param_type
converted_value = self._convert_param_value(param_value, param_type)
Returns:
- DeltaMessage or list of DeltaMessage if processed successfully
- None if cannot proceed (need more tokens)
"""
# Find next tool call start
tool_start_idx = self._find_tool_call_start(
current_text, self._processed_length
# Decide whether to add end quote based on parameter type
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# For empty string parameters, need special handling
if not param_value and not self.start_quote_emitted:
# No start quote output,
# directly output complete empty string
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='""'),
)
# Case 1: No tool call found - return remaining content
if tool_start_idx == -1:
return self._process_content(
current_text, self._processed_length, len(current_text)
]
)
# Case 2: Content before tool call
if tool_start_idx > self._processed_length:
return self._process_content(
current_text, self._processed_length, tool_start_idx
self._emit_delta(delta)
else:
# Non-empty parameter value, output end quote
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
]
)
self._emit_delta(delta)
self.should_emit_end_newline = False
# Store converted value
self.parameters[param_name] = converted_value
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False
elif name.startswith("function") or name == "function":
# if there are parameters, close JSON object
if self.parameters:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments="}"),
)
]
)
self._emit_delta(delta)
# return empty object
else:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments="{}"),
)
]
)
self._emit_delta(delta)
self.current_function_open = False
self.current_function_name = (
None # Clear function name to prevent duplicate closing
)
# Case 3: Tool call at current position
# Find end of the first complete tool call
tool_end_idx = self._find_first_complete_tool_call_end(
current_text, tool_start_idx
elif name == "tool_call":
# Before ending tool_call,
# ensure function is closed to complete missing right brace
if self.current_function_open:
# If there are still unclosed parameters, close them first
if self.current_param_name:
self._end_element("parameter")
# Close function, ensure output '}' or '{}'
self._end_element("function")
# Final Delta
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
]
)
self._emit_delta(delta)
if tool_end_idx == -1:
# Tool call incomplete, wait for more tokens
return None
# Check if there's text content to output (between tool_calls)
if self.text_content_buffer.strip():
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Process complete tool call
return self._process_complete_tool_calls(
current_text, tool_start_idx, tool_end_idx
)
self._reset_xml_parser_after_tool_call()
def _process_content(
self, current_text: str, start_pos: int, end_pos: int
) -> DeltaMessage | None:
"""Process regular content (non-tool-call text).
def setup_parser(self):
"""Set up XML parser event handlers"""
self.parser.buffer_text = True
self.parser.StartElementHandler = self._start_element
self.parser.EndElementHandler = self._end_element
self.parser.CharacterDataHandler = self._char_data
Args:
current_text: Current accumulated text
start_pos: Start position in buffer
end_pos: End position in buffer
def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
"""Set tool configuration information"""
self.tools = tools
Returns:
DeltaMessage with content if non-empty
"""
if start_pos >= end_pos:
return None
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
"""Extract function name from various formats"""
if attrs and "name" in attrs:
return attrs["name"]
content = current_text[start_pos:end_pos]
if "=" in name:
parts = name.split("=", 1)
if len(parts) == 2 and parts[0] == "function":
return parts[1]
# Check if we're between tool calls - skip whitespace
if start_pos > 0:
# Check if text before start_pos ends with </tool_call>
text_before = current_text[:start_pos]
if (
text_before.rstrip().endswith(self.tool_call_end_token)
and content.strip() == ""
):
# We just ended a tool call, skip whitespace between tool calls
self._processed_length = end_pos
return None
# Return content if non-empty
if content:
self._processed_length = end_pos
return DeltaMessage(content=content)
def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None:
"""Extract parameter name from various formats"""
if attrs and "name" in attrs:
return attrs["name"]
# Mark as processed even if empty
self._processed_length = end_pos
return None
if "=" in name:
parts = name.split("=", 1)
if len(parts) == 2 and parts[0] == "parameter":
return parts[1]
def _flush_remaining_content(self, current_text: str) -> DeltaMessage | None:
"""Flush any remaining unprocessed content as regular content.
return None
def _get_param_type(self, param_name: str) -> str:
"""Get parameter type based on tool configuration, defaults to string
Args:
current_text: Current accumulated text
param_name: Parameter name
Used when EOS token is encountered to handle incomplete tool calls.
Returns:
Parameter type
"""
if not self._has_unprocessed_content(current_text):
return None
remaining = current_text[self._processed_length :]
if remaining:
self._processed_length = len(current_text)
return DeltaMessage(content=remaining)
if not self.tools or not self.current_function_name:
return "string"
self._processed_length = len(current_text)
return None
for tool in self.tools:
if not hasattr(tool, "type") or not (
hasattr(tool, "function") and hasattr(tool.function, "name")
):
continue
if (
tool.type == "function"
and tool.function.name == self.current_function_name
):
if not hasattr(tool.function, "parameters"):
return "string"
params = tool.function.parameters
if isinstance(params, dict) and "properties" in params:
properties = params["properties"]
if param_name in properties and isinstance(
properties[param_name], dict
):
return self.repair_param_type(
str(properties[param_name].get("type", "string"))
)
elif isinstance(params, dict) and param_name in params:
param_config = params[param_name]
if isinstance(param_config, dict):
return self.repair_param_type(
str(param_config.get("type", "string"))
)
break
return "string"
def _format_delta_result(self, deltas: list[DeltaMessage]) -> DeltaMessage | None:
"""Format delta result for return.
def repair_param_type(self, param_type: str) -> str:
"""Repair unknown parameter types by treating them as string
Args:
param_type: Parameter type
Merges all deltas into a single DeltaMessage.
Returns:
Repaired parameter type
"""
if (
param_type in ["string", "str", "text", "varchar", "char", "enum"]
or param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
or param_type.startswith("num")
or param_type.startswith("float")
or param_type in ["boolean", "bool", "binary"]
or (
param_type in ["object", "array", "arr", "sequence"]
or param_type.startswith("dict")
or param_type.startswith("list")
)
):
return param_type
else:
return "string"
def _convert_param_value(self, param_value: str, param_type: str) -> Any:
"""Convert value based on parameter type
Args:
deltas: List of delta messages
param_value: Parameter value
param_type: Parameter type
Returns:
- None if empty
- Single merged DeltaMessage with all content and tool_calls
Converted value
"""
if not deltas:
if param_value.lower() == "null":
return None
if len(deltas) == 1:
return deltas[0]
# Merge multiple deltas into one
merged_content_parts = []
merged_tool_calls = []
param_type = param_type.strip().lower()
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
):
try:
return int(param_value)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' is not an integer, degenerating to string.",
param_value,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value: float = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' is not a float, degenerating to string.",
param_value,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
return param_value == "true"
else:
return param_value
for delta in deltas:
if delta.content:
merged_content_parts.append(delta.content)
if delta.tool_calls:
merged_tool_calls.extend(delta.tool_calls)
def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str:
"""Convert converted_value based on
whether it's empty and if type is string
Args:
converted_value: Converted value
param_type: Parameter type
# Create merged DeltaMessage
merged_content = "".join(merged_content_parts) if merged_content_parts else None
Returns:
Converted string for streaming output
"""
# Check if value is empty, but exclude numeric 0
if converted_value is None or converted_value == "":
return ""
# Build kwargs - only include tool_calls if non-empty
kwargs: dict[str, Any] = {"content": merged_content}
if merged_tool_calls:
kwargs["tool_calls"] = merged_tool_calls
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# String type, remove double quotes
return json.dumps(converted_value, ensure_ascii=False)[1:-1]
else:
# Non-string type, return complete JSON string
if not isinstance(converted_value, str):
return json.dumps(converted_value, ensure_ascii=False)
else:
return converted_value
return DeltaMessage(**kwargs)
def _reset_xml_parser_after_tool_call(self):
"""
Each tool_call is treated as a separate XML document,
so we need to reset the parser after each tool_call.
"""
def _process_complete_tool_calls(
self, current_text: str, start_pos: int, end_pos: int
) -> list[DeltaMessage] | None:
"""Process complete tool calls and convert to delta sequence.
# recreate XML parser
self.parser = ParserCreate()
self.setup_parser()
# Reset current tool_call state
if self.current_call_id:
self.last_completed_call_id = self.current_call_id
self.current_call_id = None
self.current_function_name = None
self.current_function_open = False
self.parameters = {}
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.current_param_is_first = False
self.should_emit_end_newline = False
self.start_quote_emitted = False
self.text_content_buffer = ""
# Reset preprocessing and deferred parsing state
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
@ToolParserManager.register_module("step3p5")
class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser()
Args:
current_text: Current accumulated text
start_pos: Start position (should be at <tool_call>)
end_pos: End position (after </tool_call>)
# Add missing attributes for compatibility with serving_chat.py
self.prev_tool_call_arr: list[dict] = []
self.streamed_args_for_tool: list[str] = []
Returns:
List of DeltaMessage if successful, None otherwise
"""
try:
# Extract text segment containing complete tool call(s)
text_to_parse = current_text[start_pos:end_pos]
logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
# Parse using non-streaming method
result = self.extract_tool_calls_basic(
text_to_parse, self.streaming_request
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new extraction
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
result = self.parser.parse_single_streaming_chunks(model_output)
if not result.tool_calls:
return ExtractedToolCallInformation(
tool_calls=[],
tools_called=False,
content=result.content,
)
else:
tool_calls = []
for tool_call in result.tool_calls:
if tool_call.function and tool_call.function.name:
tool_calls.append(
ToolCall(
id=tool_call.id,
type=tool_call.type,
function=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
)
# Case 1: Successfully parsed tool calls
if result.tools_called and result.tool_calls:
# Note: Due to _find_first_complete_tool_call_end, we typically
# process only one tool call at a time
# but we can also process multiple tool calls below
deltas = self._build_tool_call_deltas(result.tool_calls, text_to_parse)
self._update_state_after_tool_calls(result.tool_calls, end_pos)
return deltas if deltas else None
# Update tool call tracking arrays for compatibility
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Case 2: Parsing failed - treat as regular content
self._processed_length = end_pos
return [DeltaMessage(content=text_to_parse)]
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
except Exception as e:
# Exception during parsing - treat as content
logger.debug("Failed to parse tool calls: %s, treating as content", e)
self._processed_length = end_pos
failed_text = current_text[start_pos:end_pos]
return [DeltaMessage(content=failed_text)] if failed_text else None
# Update tool call information
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
self.prev_tool_call_arr[tool_index]["arguments"] = (
tool_call.function.arguments
)
def _build_tool_call_deltas(
self, tool_calls: list[ToolCall], parsed_text: str
) -> list[DeltaMessage]:
"""Build delta messages from parsed tool calls with interleaved content.
# Update streamed arguments
if tool_call.function.arguments:
self.streamed_args_for_tool[tool_index] = (
tool_call.function.arguments
)
Args:
tool_calls: List of parsed tool calls
parsed_text: Original text that was parsed
return ExtractedToolCallInformation(
tool_calls=tool_calls,
tools_called=len(tool_calls) > 0,
content=result.content,
)
Returns:
List of DeltaMessage with tool calls and content interleaved
"""
# Extract content segments between tool calls
content_segments = self._extract_content_between_tool_calls_list(parsed_text)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
if not previous_text:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new streaming session
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
# Model sometimes outputs separately causing delta_text to be empty.
# If there were tool_calls before and all current tool_calls have ended,
# return an empty tool_call for outer streaming output
# to correctly output tool_call field
if not delta_text and delta_token_ids:
open_calls = current_text.count(
self.parser.tool_call_start_token
) - current_text.count(self.parser.tool_call_end_token)
if (
open_calls == 0
and self.parser.tool_call_index > 0
or not self.parser.tool_call_index
and current_text
):
return DeltaMessage(content="")
return None
# Convert all tool calls to DeltaToolCall list
delta_tool_calls = self._convert_tool_calls_to_deltas(
tool_calls, self._tool_call_index
# Parse the delta text and get the result
result = self.parser.parse_single_streaming_chunks(delta_text)
# Update tool call tracking arrays based on incremental parsing results
if result and result.tool_calls:
for tool_call in result.tool_calls:
if tool_call.function:
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Merge all content segments into a single string
merged_content = "".join(content_segments)
# Return a single DeltaMessage with all tool calls and content
# Build kwargs - only include non-empty fields
kwargs: dict[str, Any] = {}
if merged_content:
kwargs["content"] = merged_content
if delta_tool_calls:
kwargs["tool_calls"] = delta_tool_calls
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Only return DeltaMessage if we have content or tool_calls
if kwargs:
return [DeltaMessage(**kwargs)]
else:
return []
# Update tool name if provided
if tool_call.function.name:
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
def _update_state_after_tool_calls(
self, tool_calls: list[ToolCall], end_pos: int
) -> None:
"""Update internal state after processing tool calls.
# Update arguments incrementally
if tool_call.function.arguments is not None:
# Concatenate the incremental arguments
# to the existing streamed arguments
self.prev_tool_call_arr[tool_index]["arguments"] += (
tool_call.function.arguments
)
self.streamed_args_for_tool[tool_index] += (
tool_call.function.arguments
)
return result
Args:
tool_calls: List of processed tool calls
end_pos: End position in buffer
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
# Update processed position
self._processed_length = end_pos
# Update tool call index
self._tool_call_index += len(tool_calls)
# Update prev_tool_call_arr for finish_reason
self._update_prev_tool_call_state(tool_calls)
\ No newline at end of file
Skip the remaining_call calculation in serving_chat
"""
return False
......@@ -7,7 +7,6 @@ import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload
from vllm import envs
......@@ -948,7 +947,6 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def _get_kv_cache_groups_uniform_page_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
"""
......@@ -1009,7 +1007,6 @@ def _get_kv_cache_groups_uniform_page_size(
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
......@@ -1033,9 +1030,9 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
min_num_layers = min([len(layers) for layers in same_type_layers.values()]) #12
min_num_layers = min([len(layers) for layers in same_type_layers.values()])
group_size = min_num_layers
max_num_layers = max([len(layers) for layers in same_type_layers.values()]) #36
max_num_layers = max([len(layers) for layers in same_type_layers.values()])
if max_num_layers < min_num_layers * 1.25:
# If the number of layers is not much larger than the minimum number of layers,
# use the maximum number of layers as the group size to avoid too many padding
......@@ -1053,15 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers / len(layers) * 100,
)
num_groups = cdiv(len(layers), group_size)
# for support multi layer mtp, we need to
# make all mtp layers in the same group
if (
vllm_config.speculative_config is not None
and vllm_config.speculative_config.enable_multi_layers_mtp
):
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i : i + group_size])
else:
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
......@@ -1132,6 +1120,7 @@ def get_kv_cache_config_from_groups(
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
......@@ -1258,10 +1247,8 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
# return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
return _get_kv_cache_groups_uniform_page_size(
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
)
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
def generate_scheduler_kv_cache_config(
kv_cache_configs: list[KVCacheConfig],
......@@ -1464,42 +1451,6 @@ def _auto_fit_max_model_len(
)
def _project_kv_cache_groups_to_worker(
global_kv_cache_groups: list[KVCacheGroupSpec],
worker_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
"""
Projects global KV cache groups onto a single worker's assigned layers.
In pipeline parallelism, each worker only owns a subset of layers. This
function filters the global groups to include only layers present on the
given worker, adjusting UniformTypeKVCacheSpecs accordingly.
Args:
global_kv_cache_groups: The global KV cache groups for the whole model.
worker_spec: The KV cache spec of each layer on this worker.
Returns:
The projected KV cache groups containing only this worker's layers.
"""
projected_groups: list[KVCacheGroupSpec] = []
for group in global_kv_cache_groups:
worker_layer_names = [
layer_name for layer_name in group.layer_names if layer_name in worker_spec
]
group_spec = group.kv_cache_spec
if worker_layer_names and isinstance(group_spec, UniformTypeKVCacheSpecs):
group_spec = UniformTypeKVCacheSpecs(
block_size=group_spec.block_size,
kv_cache_specs={
layer_name: group_spec.kv_cache_specs[layer_name]
for layer_name in worker_layer_names
},
)
projected_groups.append(KVCacheGroupSpec(worker_layer_names, group_spec))
return projected_groups
def get_kv_cache_configs(
vllm_config: VllmConfig,
kv_cache_specs: list[dict[str, KVCacheSpec]],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set as AbstractSet
from dataclasses import replace
from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger
from vllm.lora.utils import get_captured_lora_counts
logger = init_logger(__name__)
......@@ -61,11 +57,6 @@ class CudagraphDispatcher:
)
self.keys_initialized = False
self.specialize_lora_count = (
self.vllm_config.lora_config.specialize_active_lora
if self.vllm_config.lora_config is not None
else False
)
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE
......@@ -73,9 +64,6 @@ class CudagraphDispatcher:
"""Pre-compute the mapping from batch size to padded graph size."""
max_size = self.compilation_config.max_cudagraph_capture_size
capture_sizes = self.compilation_config.cudagraph_capture_sizes
assert capture_sizes is not None, (
"Cudagraph capture sizes must be set when cudagraphs are enabled."
)
self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
for end, start in zip(
capture_sizes + [max_size + 1],
......@@ -104,33 +92,8 @@ class CudagraphDispatcher:
"Use values from cudagraph_capture_sizes."
)
def _get_lora_cases(self) -> list[int]:
"""
Returns list of has_lora values for CUDA graph capture.
This is the single source of truth for LoRA capture cases.
"""
lora_config = self.vllm_config.lora_config
if lora_config is None:
# No LoRA configured - single case with no LoRA
return [0]
# LoRA is enabled - capture graphs based on cudagraph_specialize_lora
if self.compilation_config.cudagraph_specialize_lora:
captured_counts = get_captured_lora_counts(
lora_config.max_loras, self.specialize_lora_count
)
# Specialize: capture separate graphs for with and without LoRA
return [0] + captured_counts
else:
# No specialization: only capture graphs with LoRA active
return [lora_config.max_loras + 1]
def _create_padded_batch_descriptor(
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
......@@ -148,7 +111,6 @@ class CudagraphDispatcher:
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
num_active_loras=num_active_loras,
)
def add_cudagraph_key(
......@@ -181,27 +143,18 @@ class CudagraphDispatcher:
lora_cases = [True]
else:
lora_cases = [False]
# Get LoRA cases to capture
# lora_cases = self._get_lora_cases()
self.captured_lora_counts = [
lora_count for lora_count in lora_cases if lora_count
]
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when mixed mode is enabled."
)
for bs, num_active_loras in product(
for bs, has_lora in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor(
bs, False, num_active_loras > 0, num_active_loras
bs, False, has_lora
).relax_for_mixed_batch_cudagraphs(),
)
......@@ -215,20 +168,15 @@ class CudagraphDispatcher:
uniform_decode_query_len
* self.vllm_config.scheduler_config.max_num_seqs
)
assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when full mode is enabled."
)
cudagraph_capture_sizes_for_decode = [
x
for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len
]
for bs, num_active_loras in product(cudagraph_capture_sizes_for_decode, lora_cases):
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(
bs, True, num_active_loras > 0, num_active_loras
),
self._create_padded_batch_descriptor(bs, True, has_lora),
)
self.keys_initialized = True
......@@ -251,19 +199,14 @@ class CudagraphDispatcher:
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
valid_modes: Set of cudagraph modes that are allowed. None means
all modes are allowed.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
"""
# allowed_modes = valid_modes or CUDAGraphMode.valid_runtime_modes()
if (
not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size
# or allowed_modes <= {CUDAGraphMode.NONE}
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......
......@@ -3,7 +3,6 @@
import ast
from dataclasses import replace
from importlib.util import find_spec
from typing import Any, cast
import numpy as np
import torch
......@@ -38,21 +37,17 @@ from vllm.v1.attention.backends.tree_attn import (
)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
extend_all_queries_by_N,
compute_new_slot_mapping,
copy_and_expand_eagle_inputs_kernel,
eagle_prepare_inputs_padded_kernel,
eagle_prepare_next_token_padded_kernel,
)
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__)
......@@ -80,33 +75,11 @@ class SpecDecodeBaseProposer:
self.max_model_len = vllm_config.model_config.max_model_len
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
self.layer_num = 1
# Unifying eagle, draft model, and parallel drafting support
self.parallel_drafting: bool = self.speculative_config.parallel_drafting
self.extra_slots_per_request = (
1 if not self.parallel_drafting else self.num_speculative_tokens
)
self.net_num_new_slots_per_request = self.extra_slots_per_request - (
1 if self.pass_hidden_states_to_model else 0
)
self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0
self.parallel_drafting_token_id: int = 0
self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
if self.parallel_drafting:
self._init_parallel_drafting_params()
self.use_local_argmax_reduction: bool = (
self.speculative_config.use_local_argmax_reduction
)
# The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
# self.max_num_tokens = (
# vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
# )
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
......@@ -120,9 +93,6 @@ class SpecDecodeBaseProposer:
vllm_config.model_config
)
self.draft_attn_groups: list[AttentionGroup] = []
self.kv_cache_gid: int = -1
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
self.attn_layer_names: list[str] = []
......@@ -146,8 +116,6 @@ class SpecDecodeBaseProposer:
# Use draft model's M-RoPE setting, not target model's
# Draft models may be text-only even if target is multimodal
self.uses_mrope = self.draft_model_config.uses_mrope
self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
......@@ -171,9 +139,6 @@ class SpecDecodeBaseProposer:
(self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
)
# Will be set when we initialize the attention backend
# self.block_size: int = -1
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
......@@ -181,26 +146,6 @@ class SpecDecodeBaseProposer:
max_num_slots_for_arange, device=device, dtype=torch.int32
)
if self.needs_extra_input_slots:
self._raise_if_padded_drafter_batch_disabled()
self._raise_if_multimodal()
self._raise_if_mrope()
self.is_rejected_token_mask: torch.Tensor | None = None
self.is_masked_token_mask: torch.Tensor | None = None
if self.needs_extra_input_slots:
# For draft models and parallel drafting, we need to keep track of
# which tokens are rejected to update the slot mapping with padding slots.
self.is_rejected_token_mask = torch.zeros(
(self.max_num_tokens,), dtype=torch.bool, device=device
)
# For parallel drafting, we also need to keep track of which tokens
# are parallel-padding tokens used to sample at later positions.
# We populate this tensor even when using draft models for simplicity.
self.is_masked_token_mask = torch.zeros(
(self.max_num_tokens,), dtype=torch.bool, device=device
)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.inputs_embeds_size),
dtype=self.dtype,
......@@ -221,6 +166,36 @@ class SpecDecodeBaseProposer:
# Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple | None = None
# if current_platform.is_rocm():
# from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata
# rocm_types = [
# TritonAttentionMetadata,
# RocmAttentionMetadata,
# ]
# # ROCM_AITER_FA is an optional backend
# if find_spec(
# AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
# ):
# from vllm.v1.attention.backends.rocm_aiter_fa import (
# AiterFlashAttentionMetadata,
# )
# rocm_types.append(AiterFlashAttentionMetadata)
# # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
# from vllm.model_executor.layers.attention.mla_attention import (
# MLACommonMetadata,
# )
# rocm_types.append(MLACommonMetadata)
# # FlexAttention backend support
# from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata
# rocm_types.append(FlexAttentionMetadata)
# self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree
......@@ -276,8 +251,7 @@ class SpecDecodeBaseProposer:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
# return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
return {name: view for name in self._draft_attn_layer_names}
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle.
......@@ -296,23 +270,6 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def adjust_input(
self,
batch_size: int,
target_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
return (
target_token_ids,
target_positions,
target_hidden_states,
common_attn_metadata,
)
def propose(
self,
# [num_tokens]
......@@ -323,10 +280,9 @@ class SpecDecodeBaseProposer:
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
token_indices_to_sample: torch.Tensor | None,
last_token_indices: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor]
......@@ -342,28 +298,12 @@ class SpecDecodeBaseProposer:
)
assert target_hidden_states.shape[-1] == self.hidden_size
(
target_token_ids,
target_positions,
target_hidden_states,
common_attn_metadata,
) = self.adjust_input(
batch_size=batch_size,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
)
num_tokens, token_indices_to_sample, common_attn_metadata = (
num_tokens, last_token_indices, common_attn_metadata = (
self.set_inputs_first_pass(
target_token_ids=target_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
last_token_indices=last_token_indices,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
......@@ -415,9 +355,6 @@ class SpecDecodeBaseProposer:
# hidden dims. E.g. large target model and small draft model.
self.hidden_states[:num_tokens] = target_hidden_states
###### step3.5-mtp3新增
draft_token_ids_list = []
for spec_step_idx in range(self.layer_num):
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
......@@ -438,13 +375,9 @@ class SpecDecodeBaseProposer:
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
if self.enable_multi_layers_mtp:
model_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
......@@ -462,65 +395,36 @@ class SpecDecodeBaseProposer:
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[token_indices_to_sample]
if self.enable_multi_layers_mtp:
logits = self.model.compute_logits(
sample_hidden_states, spec_step_idx=spec_step_idx
)
else:
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids = logits.argmax(dim=-1)
# Generate the remaining draft tokens.
draft_token_ids_list.append(draft_token_ids)
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_prob = logits.softmax(dim=-1, dtype=torch.float32)
if spec_step_idx < self.layer_num - 1:
prev_token_ids = self.input_ids[:num_tokens].clone()
hidden_states = hidden_states[:num_tokens]
next_token_ids = draft_token_ids_list[-1].int()
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
draft_token_ids = logits.argmax(dim=-1)
num_tokens, token_indices_to_sample, common_attn_metadata = (
self.set_inputs_first_pass(
target_token_ids=prev_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=hidden_states,
token_indices_to_sample=token_indices_to_sample,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
)
if envs.VLLM_REJECT_SAMPLE_OPT:
return draft_token_ids.view(-1, 1), draft_prob.view(-1, 1, logits.shape[-1])
# Early exit if all draft tokens are generated in one pass
if self.num_speculative_tokens == self.layer_num or self.parallel_drafting:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
##########################################################################
return draft_token_ids.view(-1, 1)
if self.uses_mrope:
positions = self.mrope_positions[:, token_indices_to_sample]
positions = self.mrope_positions[:, last_token_indices]
else:
positions = self.positions[token_indices_to_sample]
positions = self.positions[last_token_indices]
if self.method in (
"deepseek_mtp",
"ernie_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
"step3p5_mtp", # 新增
):
hidden_states = self.hidden_states[token_indices_to_sample]
hidden_states = self.hidden_states[last_token_indices]
else:
hidden_states = hidden_states[token_indices_to_sample]
hidden_states = hidden_states[last_token_indices]
if isinstance(attn_metadata, TreeAttentionMetadata):
######
if self.enable_multi_layers_mtp:
raise NotImplementedError(
"Speculative Decoding with multi-layer MTP and tree attention "
"is not supported yet."
)
#####
# Draft using tree attention.
draft_token_ids_list = self.propose_tree(
batch_size=batch_size,
......@@ -533,22 +437,32 @@ class SpecDecodeBaseProposer:
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
# draft_token_ids = logits.argmax(dim=-1)
draft_token_ids = logits.argmax(dim=-1)
if self.allowed_attn_types is not None and not isinstance(
attn_metadata, self.allowed_attn_types
):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > layer_num: "
"decoding with num_speculative_tokens > 1: "
f"{type(attn_metadata)}. Supported types are: "
f"{self.allowed_attn_types}"
)
cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
self._determine_batch_execution_and_padding(batch_size)
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
batch_size_dp_padded
)
input_batch_size = batch_desc.num_tokens
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
......@@ -569,7 +483,7 @@ class SpecDecodeBaseProposer:
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs_list = [draft_prob]
for token_index in range(self.num_speculative_tokens - self.layer_num):
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
......@@ -648,9 +562,23 @@ class SpecDecodeBaseProposer:
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
)
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
)
)
else:
draft_indexer_metadata = None
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions)
......@@ -713,17 +641,12 @@ class SpecDecodeBaseProposer:
target_token_ids: torch.Tensor,
next_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor | None,
last_token_indices: torch.Tensor | None,
cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
if not self.needs_extra_input_slots:
# Default EAGLE pathway: no reshaping of input tensors needed.
# Simply rotate the input ids and leave the positions unchanged,
# Inserting the next token ids at the last slot in each request.
if token_indices_to_sample is None:
token_indices_to_sample = cad.query_start_loc[1:] - 1
if last_token_indices is None:
last_token_indices = cad.query_start_loc[1:] - 1
num_tokens = target_token_ids.shape[0]
# Shift the input ids by one token.
......@@ -731,120 +654,12 @@ class SpecDecodeBaseProposer:
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[token_indices_to_sample] = next_token_ids
self.input_ids[last_token_indices] = next_token_ids
# copy inputs to buffer for cudagraph
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
target_positions = target_positions[0]
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
return num_tokens, token_indices_to_sample, cad
else:
assert self.is_rejected_token_mask is not None
assert self.is_masked_token_mask is not None
# 1.
# Call a custom triton kernel to copy input_ids and positions
# into the correct slots in the preallocated buffers self.input_ids,
# self.positions.
batch_size = cad.batch_size()
# Since we might have to copy a lot of data for prefills, we select the
# block size based on the max query length and limit to max 256 slots/block.
max_num_tokens_per_request = (
cad.max_query_len + self.net_num_new_slots_per_request
)
BLOCK_SIZE_TOKENS = min(
256, triton.next_power_of_2(max_num_tokens_per_request)
)
num_blocks = (
max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1
) // BLOCK_SIZE_TOKENS
total_num_input_tokens = target_token_ids.shape[0]
total_num_output_tokens = total_num_input_tokens + (
self.net_num_new_slots_per_request * batch_size
)
token_indices_to_sample = torch.empty(
batch_size * self.extra_slots_per_request,
dtype=torch.int32,
device=self.device,
)
# Destination indices to write target_hidden_states into drafting buffer.
out_hidden_state_mapping = torch.empty(
total_num_input_tokens, dtype=torch.int32, device=self.device
)
# Kernel grid: one program per request (row)
grid = (batch_size, num_blocks)
query_start_loc = cad.query_start_loc
query_end_loc = cad.query_start_loc[1:] - 1
if num_rejected_tokens_gpu is not None:
query_end_loc = query_end_loc - num_rejected_tokens_gpu
copy_and_expand_eagle_inputs_kernel[grid](
# (Padded) Inputs from the target model
target_token_ids_ptr=target_token_ids,
target_positions_ptr=target_positions,
next_token_ids_ptr=next_token_ids, # sampled tokens, one per request
# Outputs to the drafting buffers
out_input_ids_ptr=self.input_ids,
out_positions_ptr=self.positions, # Doesn't support mrope for now
out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
out_is_masked_token_mask_ptr=self.is_masked_token_mask,
out_new_token_indices_ptr=token_indices_to_sample,
out_hidden_state_mapping_ptr=out_hidden_state_mapping,
# Input metadata
query_start_loc_ptr=query_start_loc,
query_end_loc_ptr=query_end_loc,
padding_token_id=0,
parallel_drafting_token_id=self.parallel_drafting_token_id,
# Sizing info
# Note that we can deduce batch_size for free from the grid size
total_input_tokens=total_num_input_tokens,
num_padding_slots_per_request=self.extra_slots_per_request,
shift_input_ids=self.pass_hidden_states_to_model,
BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
)
if self.pass_hidden_states_to_model:
assert self.parallel_drafting_hidden_state_tensor is not None
self.hidden_states[out_hidden_state_mapping] = target_hidden_states
# Use torch.where to avoid DtoH sync from boolean indexing
mask = self.is_masked_token_mask[:total_num_output_tokens]
torch.where(
mask.unsqueeze(1),
self.parallel_drafting_hidden_state_tensor,
self.hidden_states[:total_num_output_tokens],
out=self.hidden_states[:total_num_output_tokens],
)
# 2.
# Recompute the slot mapping based on the new positions and
# rejection mask.
# Use the first draft attention group's kv_cache_spec for block_size
# (all draft layers share the same kv-cache group)
assert len(self.draft_attn_groups) > 0
block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
new_slot_mapping = compute_new_slot_mapping(
cad=cad,
new_positions=self.positions[:total_num_output_tokens],
is_rejected_token_mask=self.is_rejected_token_mask[
:total_num_output_tokens
],
block_size=block_size,
num_new_tokens=self.net_num_new_slots_per_request,
max_model_len=self.max_model_len,
)
# 3. Update the common attention metadata with the new (meta)data
new_cad = extend_all_queries_by_N(
cad,
N=self.net_num_new_slots_per_request,
arange=self.arange,
new_slot_mapping=new_slot_mapping,
)
return total_num_output_tokens, token_indices_to_sample, new_cad
return num_tokens, last_token_indices, cad
def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model")
......@@ -1281,28 +1096,10 @@ class SpecDecodeBaseProposer:
model = model.module
return model.__class__.__name__
def _get_model(self) -> nn.Module:
"""
Default method to call get_model(). Can be overridden by subclasses which
need to customize model loading.
"""
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
model = get_model(
vllm_config=self.vllm_config,
model_config=self.speculative_config.draft_model_config,
# load_config=self.speculative_config.draft_load_config,
)
return model
def load_model(self, target_model: nn.Module) -> None:
draft_model_config = self.vllm_config.speculative_config.draft_model_config
target_attn_layer_names = set(
get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
)
# FIXME: support hybrid kv for draft model
target_indexer_layer_names = set(
......@@ -1310,26 +1107,23 @@ class SpecDecodeBaseProposer:
self.vllm_config, DeepseekV32IndexerCache
).keys()
)
self.model = self._get_model()
# Find draft layers (attention layers added by draft model)
# all_attn_layers = get_layers_from_vllm_config(
# self.vllm_config,
# AttentionLayerBase, # type: ignore[type-abstract]
# )
# self._draft_attn_layer_names = (
# set(all_attn_layers.keys()) - target_attn_layer_names
# )
self._draft_attn_layer_names = (
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=draft_model_config
)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
- target_attn_layer_names
)
indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
)
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
self.attn_layer_names = list(self._draft_attn_layer_names - draft_indexer_layer_names)
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
self.indexer_layer_names = list(draft_indexer_layer_names)
if self.indexer_layer_names:
......@@ -1366,7 +1160,6 @@ class SpecDecodeBaseProposer:
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"HunYuanVLForConditionalGeneration",
"GlmOcrForConditionalGeneration",
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
......@@ -1384,34 +1177,12 @@ class SpecDecodeBaseProposer:
else:
target_language_model = target_model
self._maybe_share_embeddings(target_language_model)
self._maybe_share_lm_head(target_language_model)
if self.parallel_drafting and self.pass_hidden_states_to_model:
assert self.parallel_drafting_hidden_state_tensor is not None
self.parallel_drafting_hidden_state_tensor.copy_(
self.model.combine_hidden_states(
self.model.mask_hidden.view(3 * self.hidden_size)
)
if self.eagle3_use_aux_hidden_state
else self.model.mask_hidden.view(self.hidden_size)
)
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
"""
Some draft models may not have their own embedding layers, and some may
have a duplicate copy of the target model's embedding layers. In these cases,
we share the target model's embedding layers with the draft model to save
memory.
"""
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
inner_model = getattr(target_language_model, "model", None)
if inner_model is None:
raise AttributeError("Target model does not have 'model' attribute")
if hasattr(inner_model, "embed_tokens"):
target_embed_tokens = inner_model.embed_tokens
elif hasattr(inner_model, "embedding"):
target_embed_tokens = inner_model.embedding
if hasattr(target_language_model.model, "embed_tokens"):
target_embed_tokens = target_language_model.model.embed_tokens
elif hasattr(target_language_model.model, "embedding"):
target_embed_tokens = target_language_model.model.embedding
else:
raise AttributeError(
"Target model does not have 'embed_tokens' or 'embedding' attribute"
......@@ -1466,12 +1237,7 @@ class SpecDecodeBaseProposer:
" from the target model."
)
def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
"""
Some draft models may not have their own LM head, and some may have a
duplicate copy of the target model's LM head. In these cases, we share
the target model's LM head with the draft model to save memory.
"""
# share lm_head with the target model if needed
share_lm_head = False
if hasattr(self.model, "has_own_lm_head"):
# EAGLE model
......@@ -1515,30 +1281,6 @@ class SpecDecodeBaseProposer:
del self.model.lm_head
self.model.lm_head = target_language_model.lm_head
if self.use_local_argmax_reduction:
if not hasattr(self.model, "get_top_tokens"):
raise ValueError(
"use_local_argmax_reduction is enabled but draft model "
f"{self.model.__class__.__name__} does not implement "
"get_top_tokens()."
)
# Warn if draft model has vocab remapping, which forces fallback
# to the full-logits path (negating the optimization).
if (
hasattr(self.model, "draft_id_to_target_id")
and self.model.draft_id_to_target_id is not None
):
logger.warning(
"use_local_argmax_reduction is enabled but draft model "
"uses draft_id_to_target_id vocab remapping. The "
"optimization will be bypassed (falling back to full "
"logits gather + argmax)."
)
else:
logger.info(
"Using local argmax reduction for draft token generation "
"(communication: O(2*tp_size) vs O(vocab_size))."
)
@torch.inference_mode()
def dummy_run(
self,
......@@ -1569,9 +1311,9 @@ class SpecDecodeBaseProposer:
# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
self._draft_attn_layer_names
self.attn_layer_names
and slot_mappings is not None
and next(iter(self._draft_attn_layer_names)) in slot_mappings
and self.attn_layer_names[0] in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
......@@ -1665,64 +1407,6 @@ class SpecDecodeBaseProposer:
== 1
), "All drafting layers should belong to the same kv cache group"
# def initialize_attn_backend(
# self,
# kv_cache_config: KVCacheConfig,
# kernel_block_sizes: list[int] | None = None,
# ) -> None:
# """
# Initialize AttentionGroups for draft layers using kv_cache_config.
# Called from the model runner's initialize_metadata_builders.
# """
# all_attn_layers = get_layers_from_vllm_config(
# self.vllm_config,
# AttentionLayerBase, # type: ignore[type-abstract]
# )
# # Find which kv_cache_group the draft layers belong to
# self.validate_same_kv_cache_group(kv_cache_config)
# kv_cache_spec = None
# for gid, group in enumerate(kv_cache_config.kv_cache_groups):
# if self._draft_attn_layer_names & set(group.layer_names):
# self.kv_cache_gid = gid
# kv_cache_spec = group.kv_cache_spec
# break
# attention_groups: dict[tuple[str, str], AttentionGroup] = {}
# if kv_cache_spec is not None:
# for layer_name in self._draft_attn_layer_names:
# attn_backend = all_attn_layers[layer_name].get_attn_backend()
# backend_key = attn_backend.full_cls_name()
# if backend_key not in attention_groups:
# layer_kv_cache_spec = kv_cache_spec
# if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
# layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
# layer_name
# ]
# kernel_block_size = (
# kernel_block_sizes[self.kv_cache_gid]
# if kernel_block_sizes is not None
# and self.kv_cache_gid < len(kernel_block_sizes)
# else None
# )
# attn_group = AttentionGroup(
# backend=attn_backend,
# layer_names=[layer_name],
# kv_cache_spec=layer_kv_cache_spec,
# kv_cache_group_id=self.kv_cache_gid,
# )
# attn_group.create_metadata_builders(
# self.vllm_config,
# self.device,
# kernel_block_size=kernel_block_size,
# )
# attention_groups[backend_key] = attn_group
# else:
# attention_groups[backend_key].layer_names.append(layer_name)
# self.draft_attn_groups = list(attention_groups.values())
def _pad_batch_across_dp(
self,
num_tokens_unpadded: int,
......@@ -1747,50 +1431,6 @@ class SpecDecodeBaseProposer:
return num_tokens_dp_padded, num_toks_across_dp
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
)
num_tokens_padded = batch_desc.num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
)
)
assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.dp_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_padded,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
class EagleProposer(SpecDecodeBaseProposer):
def __init__(
self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
PADDING_SLOT_ID = -1
class ExtractHiddenStatesProposer:
def __init__(self, vllm_config: VllmConfig, device):
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.num_speculative_tokens == 1
if vllm_config.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"disable_padded_drafter_batch is not supported with "
"extract_hidden_states method"
)
self.vllm_config = vllm_config
self.device = device
self.dtype = vllm_config.model_config.dtype
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
# Model and attention layer tracking (initialized in load_model)
self.model: nn.Module | None = None
self.attn_layer_names: list[str] = []
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
# Maximum number of tokens for buffers
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
if not layer_ids:
raise ValueError(
"eagle_aux_hidden_state_layer_ids must be set in the draft "
"model config for extract_hidden_states method"
)
self.num_hidden_states = len(layer_ids)
self.hidden_size = vllm_config.model_config.get_hidden_size()
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.num_hidden_states, self.hidden_size),
dtype=self.dtype,
device=device,
)
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
)
def propose(
self,
sampled_token_ids: torch.Tensor,
target_hidden_states: list[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
scheduler_output: SchedulerOutput,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache
without performing actual attention computation. This allows us to
extract and store hidden states for later use (e.g., KV transfer).
This proposer doesn't actually perform speculation - it returns the
sampled tokens as "draft" tokens, ensuring they always verify (match).
The main purpose is to cache hidden states, not to speculate.
Args:
sampled_token_ids: Sampled token IDs from the target model
target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer)
common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility)
Returns:
Tuple of:
- Draft tokens matching sampled tokens, shape [batch_size, 1]
- KV connector output (if KV transfer is active), else None
"""
assert self.model is not None and isinstance(target_hidden_states, list)
# target_hidden_states is a list of tensors (one per layer)
# Each tensor has shape [num_tokens, hidden_size]
# Stack to shape: [num_tokens, num_hidden_states, hidden_size]
stacked_hidden_states = torch.stack(target_hidden_states, dim=1)
num_tokens = stacked_hidden_states.shape[0]
# Copy hidden states to buffer
self.hidden_states[:num_tokens] = stacked_hidden_states
assert self.attn_metadata_builder is not None
attn_metadata = self.attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
# We assume all cache-only layers belong to the same KV cache group,
# thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(num_tokens)
)
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
with (
set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
),
(
KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
if has_kv_transfer_group()
else nullcontext()
) as kv_connector_output,
):
self.model(
hidden_states=self.hidden_states[:num_input_tokens],
)
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return sampled_token_ids.unsqueeze(-1), kv_connector_output
def _get_slot_mapping(
self,
num_tokens: int,
slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return slot_mapping dict for cache-only attention layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if slot_mapping is not None:
num_actual = slot_mapping.shape[0]
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
if num_tokens > num_actual:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names}
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
)
num_tokens_padded = batch_desc.num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
)
)
assert not should_ubatch, (
"DBO ubatching not implemented for extract_hidden_states"
)
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.dp_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_padded,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys.
Only supports PIECEWISE cudagraphs (via mixed_mode).
Should be called after adjust_cudagraph_sizes_for_spec_decode.
"""
assert self.vllm_config.speculative_config is not None
if (
not self.vllm_config.speculative_config.enforce_eager
and cudagraph_mode.mixed_mode()
in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
):
proposer_cudagraph_mode = CUDAGraphMode.PIECEWISE
else:
proposer_cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
assert self.model is not None, "Model must be initialized before dummy_run"
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
)
if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens
# Use our own slot mapping buffer during cudagraph capture.
if (
self.attn_layer_names
and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
):
self.model(
hidden_states=self.hidden_states[:num_input_tokens],
)
def _build_attn_metadata_builder(
self, draft_attn_layers: dict[str, AttentionLayerBase]
) -> AttentionMetadataBuilder:
"""Build the attention metadata builder from draft attention layers."""
if not draft_attn_layers:
raise ValueError("No attention layers found for ExtractHiddenStatesModel")
layer = next(iter(draft_attn_layers.values()))
attn_backend = layer.get_attn_backend()
return attn_backend.get_builder_cls()(
layer.get_kv_cache_spec(self.vllm_config),
self.attn_layer_names,
self.vllm_config,
self.device,
)
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Prepare next token IDs for speculative decoding.
Since num_speculative_tokens == 1, sampled_token_ids has shape
(batch_size, 1). For each request we either use the sampled token
(if valid and not discarded) or a backup token from the request state.
"""
num_reqs = gpu_input_batch.num_reqs
device = sampled_token_ids.device
# Compute backup tokens for discarded / invalid requests
backup_tokens_gpu = torch.tensor(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs)
],
dtype=torch.int32,
device=device,
)
assert discard_request_mask.dtype == torch.bool
# With num_speculative_tokens == 1, there is exactly one token
sampled = sampled_token_ids[:, 0]
is_valid = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
valid_sampled_tokens_count = is_valid.to(torch.int32)
use_sampled = is_valid & ~discard_request_mask[:num_reqs]
next_token_ids = torch.where(
use_sampled, sampled.to(torch.int32), backup_tokens_gpu
)
return next_token_ids, valid_sampled_tokens_count
def load_model(self, target_model: nn.Module) -> None:
"""Load the ExtractHiddenStatesModel model.
This method instantiates the ExtractHiddenStatesModel model which is used
to cache hidden states during speculative decoding. The model uses
cache-only attention (no computation, just caching KV states).
Args:
target_model: The target model (passed for compatibility with
EagleProposer interface, but not used here)
"""
# Get the target model's attention layers before loading draft model
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() # type: ignore[type-abstract]
)
assert self.vllm_config.speculative_config is not None
draft_model_config = self.vllm_config.speculative_config.draft_model_config
from vllm.compilation.backends import set_model_tag
with set_model_tag("extract_hidden_states"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=draft_model_config
)
# Identify draft model's attention layers (difference from target)
all_attn_layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
draft_attn_layers = {
name: layer
for name, layer in all_attn_layers.items()
if name not in target_attn_layer_names
}
self.attn_layer_names = list(draft_attn_layers.keys())
assert len(draft_attn_layers) == 1, (
"ExtractHiddenStatesModel should have exactly one "
f"attention layer, found {len(draft_attn_layers)}"
)
self.attn_metadata_builder = self._build_attn_metadata_builder(
draft_attn_layers
)
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
"""Validate all drafting layers belong to the same KV cache group.
With exactly one attention layer (asserted in load_model), this is
trivially satisfied.
"""
assert len(self.attn_layer_names) == 1
......@@ -67,41 +67,3 @@ class SpecDecodeMetadata:
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
@dataclass
class MultiLayerEagleMetadata:
# [batch_size]
cached_len: torch.Tensor | None = None
# [batch_size, layer_num]
cached_token_ids: torch.Tensor | None = None
# [batch_size, layer_num, hidden_size]
cached_hidden_states: torch.Tensor | None = None
# [batch_size, layer_num]
cached_slot_mappings: torch.Tensor | None = None
# [batch_size, layer_num]
cached_positions: torch.Tensor | None = None
@classmethod
def make_dummy(
cls,
layer_num: int,
hidden_size: int,
device: torch.device,
) -> "MultiLayerEagleMetadata":
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
cached_token_ids = torch.zeros((1, layer_num), dtype=torch.int32, device=device)
cached_hidden_states = torch.zeros(
(1, layer_num, hidden_size), dtype=torch.float32, device=device
)
cached_slot_mappings = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
cached_positions = torch.zeros((1, layer_num), dtype=torch.int64, device=device)
return cls(
cached_len=cached_len,
cached_token_ids=cached_token_ids,
cached_hidden_states=cached_hidden_states,
cached_slot_mappings=cached_slot_mappings,
cached_positions=cached_positions,
)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
CommonAttentionMetadata,
)
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata
logger = init_logger(__name__)
BLOCK_HIDDEN = 128
BLOCK_TOKENS = 128
class MultiLayerEagleProposer(EagleProposer):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
super().__init__(vllm_config, device, runner)
self.layer_num: int = getattr(
self.speculative_config.draft_model_config.hf_text_config, "n_predict", 0
)
self.num_speculative_tokens: int = (
self.speculative_config.num_speculative_tokens
)
def adjust_input(
self,
batch_size: int,
target_token_ids: torch.Tensor,
target_positions: torch.Tensor,
target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
assert multi_layer_eagle_metadata is not None
if token_indices_to_sample is None:
token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1
MAX_SHIFT = self.layer_num
assert MAX_SHIFT > 0
prev_token_ids = target_token_ids.clone()
prev_positions = target_positions.clone()
prev_hidden_states = target_hidden_states.clone()
slot_mapping = common_attn_metadata.slot_mapping
start_token_indices = common_attn_metadata.query_start_loc[:-1]
end_token_indices = common_attn_metadata.query_start_loc[1:] - 1
pos_for_shift = (
target_positions[0] if target_positions.dim() == 2 else target_positions
)
start_token_pos = pos_for_shift[start_token_indices]
shift = torch.minimum(
end_token_indices - token_indices_to_sample,
start_token_pos,
)
shift = torch.clamp(shift, min=0)
# Metadata updates (matches the original reference implementation).
token_indices_to_sample.add_(shift)
common_attn_metadata.seq_lens.sub_(shift)
cached_lens = multi_layer_eagle_metadata.cached_len
shift = torch.minimum(shift, cached_lens)
_multi_layer_eagle_shift_and_cache(
batch_size=batch_size,
max_shift=MAX_SHIFT,
src_token_ids=target_token_ids,
dst_token_ids=prev_token_ids,
src_positions=target_positions,
dst_positions=prev_positions,
src_hidden_states=target_hidden_states,
dst_hidden_states=prev_hidden_states,
src_slot_mapping=slot_mapping,
dst_slot_mapping=slot_mapping,
start_token_indices=start_token_indices,
end_token_indices=end_token_indices,
token_indices_to_sample=token_indices_to_sample,
shift=shift,
cached_lens=cached_lens,
cached_prev_token_ids=multi_layer_eagle_metadata.cached_token_ids,
cached_prev_positions=multi_layer_eagle_metadata.cached_positions,
cached_prev_hidden_states=multi_layer_eagle_metadata.cached_hidden_states,
cached_slot_mappings=multi_layer_eagle_metadata.cached_slot_mappings,
common_attn_metadata=common_attn_metadata,
)
return prev_token_ids, prev_positions, prev_hidden_states, common_attn_metadata
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
raise Exception(
"speculative_config.disable_padded_drafter_batch"
" is not supported now for MultiLayerEagleProposer."
)
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
use_cudagraphs: bool = True,
is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
self._determine_batch_execution_and_padding(
num_tokens, use_cudagraphs=use_cudagraphs
)
)
# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
self._draft_attn_layer_names
and slot_mappings is not None
and next(iter(self._draft_attn_layer_names)) in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
adjust_input_kwargs = {
"batch_size": 1,
"target_token_ids": self.input_ids[:num_input_tokens],
"target_positions": self._get_positions(num_input_tokens),
"target_hidden_states": self.hidden_states[:num_input_tokens],
"token_indices_to_sample": torch.tensor(
[num_input_tokens - 1], dtype=torch.int32, device=self.device
),
"common_attn_metadata": CommonAttentionMetadata(
query_start_loc=torch.tensor(
[0, num_input_tokens], dtype=torch.int32, device=self.device
),
query_start_loc_cpu=torch.tensor(
[0, num_input_tokens], dtype=torch.int32, device="cpu"
),
seq_lens=torch.tensor(
[num_input_tokens], dtype=torch.int32, device=self.device
),
num_reqs=1,
num_actual_tokens=num_input_tokens,
max_query_len=num_input_tokens,
max_seq_len=self.max_model_len,
block_table_tensor=torch.tensor(
[], dtype=torch.int32, device=self.device
),
slot_mapping=self.arange[:num_input_tokens],
logits_indices_padded=None,
num_logits_indices=None,
causal=True,
encoder_seq_lens=None,
),
"multi_layer_eagle_metadata": MultiLayerEagleMetadata.make_dummy(
layer_num=self.layer_num,
hidden_size=self.hidden_size,
device=self.device,
),
}
# NOTE ensure the jit kernel in _adjust_input can be compiled
self.adjust_input(**adjust_input_kwargs)
for fwd_idx in range(self.layer_num):
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
):
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"hidden_states": self.hidden_states[:num_input_tokens],
"inputs_embeds": inputs_embeds,
"spec_step_idx": fwd_idx,
}
self.model(**model_kwargs)
def _multi_layer_eagle_shift_and_cache(
*,
batch_size: int,
max_shift: int,
src_token_ids: torch.Tensor,
dst_token_ids: torch.Tensor,
src_positions: torch.Tensor,
dst_positions: torch.Tensor,
src_hidden_states: torch.Tensor,
dst_hidden_states: torch.Tensor,
src_slot_mapping: torch.Tensor,
dst_slot_mapping: torch.Tensor,
start_token_indices: torch.Tensor,
end_token_indices: torch.Tensor,
token_indices_to_sample: torch.Tensor,
shift: torch.Tensor,
cached_lens: torch.Tensor,
cached_prev_token_ids: torch.Tensor,
cached_prev_positions: torch.Tensor,
cached_prev_hidden_states: torch.Tensor,
cached_slot_mappings: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
):
if batch_size == 0:
return
assert max_shift > 0
assert cached_prev_positions.is_contiguous()
assert cached_prev_token_ids.is_contiguous()
assert cached_prev_hidden_states.is_contiguous()
assert cached_slot_mappings.is_contiguous()
assert src_hidden_states.is_contiguous()
assert dst_hidden_states.is_contiguous()
# If src/dst are the same tensor, shifting is unsafe without a separate src.
if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
src_slot_mapping = src_slot_mapping.clone()
# Cache extraction for the next call.
store_start = torch.maximum(
start_token_indices,
(token_indices_to_sample + 1 - max_shift),
)
store_lens = torch.clamp(
token_indices_to_sample - store_start + 1,
min=0,
max=max_shift,
)
# Avoid device sync: query length == (end - start + 1) == diff of
# query_start_loc (CPU copy).
max_window_len = int(
(
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
.max()
.item()
)
num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_token_ids,
dst_token_ids,
cached_prev_token_ids,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_slot_mapping,
dst_slot_mapping,
cached_slot_mappings,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
_shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
src_positions,
dst_positions,
cached_prev_positions,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
BLOCK_TOKENS=BLOCK_TOKENS,
)
hidden_size = int(dst_hidden_states.shape[1])
# Hidden blocking avoids extremely large Triton tiles (and huge cubins)
# when hidden_size is large.
num_hidden_blocks = max(1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN)
_shift_and_gather_hidden_kernel[(batch_size, num_blocks, num_hidden_blocks)](
src_hidden_states,
dst_hidden_states,
cached_prev_hidden_states,
start_token_indices,
end_token_indices,
shift,
cached_lens,
store_start,
store_lens,
MAX_SHIFT=max_shift,
PADDED_SHIFT=triton.next_power_of_2(max_shift),
HIDDEN_SIZE=hidden_size,
BLOCK_TOKENS=BLOCK_TOKENS,
BLOCK_HIDDEN=BLOCK_HIDDEN,
num_warps=4,
)
cached_lens.copy_(store_lens)
return
@triton.jit
def _shift_and_gather_cache_1d_kernel(
src_ptr,
dst_ptr,
cached_ptr,
start_ptr,
end_ptr,
shift_ptr,
cached_len_ptr,
store_start_ptr,
store_len_ptr,
MAX_SHIFT: tl.constexpr,
PADDED_SHIFT: tl.constexpr,
BLOCK_TOKENS: tl.constexpr,
):
# Per-sequence "shift + gather" for packed 1D arrays (token ids, positions,
# slot mappings, ...).
#
# We operate on a packed batch where each sequence (request) occupies a
# contiguous window [start, end] (inclusive) in a flattened tensor.
# For the next speculative step, we build a right-shifted version of each
# window. The shift amount can differ per sequence.
#
# For a single sequence (0-based index i within its window):
# - Prefix (i < shift):
# dst[start + i] = cached[cached_len - shift + i]
# - Body (i >= shift):
# dst[start + i] = src[start + i - shift]
#
# The vacated prefix is filled from a small per-sequence cache (up to
# MAX_SHIFT elements) that stores values from previous speculative steps.
#
# Example:
# cached_tail = [a3, a4]
# src_window = [b0, b1, b2, b3, b4]
# shift = 2
# -> dst_window = [a3, a4, b0, b1, b2]
#
# After dst is produced, we refresh cached_ptr[seq, :] with a suffix of dst
# (specified by store_start / store_len) so the next call can populate its
# prefix from cache.
pid_seq = tl.program_id(0)
pid_blk = tl.program_id(1)
start = tl.load(start_ptr + pid_seq).to(tl.int32)
end = tl.load(end_ptr + pid_seq).to(tl.int32)
shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
assert cached_len >= shift
# get dst indices
base = pid_blk * BLOCK_TOKENS
k = tl.arange(0, BLOCK_TOKENS)
offs = base + k
dst_idx = start + offs
# get dst mask
window_len = end - start + 1
mask = offs < window_len
# load from cached
base_cached = cached_ptr + pid_seq * MAX_SHIFT
cached_idx = cached_len - shift + offs
cached_mask = offs < shift
val_cached = tl.load(base_cached + cached_idx, mask=mask & cached_mask, other=0)
# load from src
src_idx = start + offs - shift
val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0)
# store to dst
val = tl.where(cached_mask, val_cached, val_src)
tl.store(dst_ptr + dst_idx, val, mask=mask)
# Store into the per-sequence cache.
#
# Cache layout: [batch_size, MAX_SHIFT] (flattened). We always write the
# full MAX_SHIFT region (zero-padded when store_len < MAX_SHIFT) to keep the
# cache contiguous.
store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
m = tl.arange(0, PADDED_SHIFT)
store_mask = m < MAX_SHIFT
dst_idx = store_start + m
val = tl.load(dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0)
tl.store(base_cached + m, val, mask=store_mask)
@triton.jit
def _shift_and_gather_hidden_kernel(
src_ptr,
dst_ptr,
cached_ptr,
start_ptr,
end_ptr,
shift_ptr,
cached_len_ptr,
store_start_ptr,
store_len_ptr,
MAX_SHIFT: tl.constexpr,
PADDED_SHIFT: tl.constexpr,
HIDDEN_SIZE: tl.constexpr,
BLOCK_TOKENS: tl.constexpr,
BLOCK_HIDDEN: tl.constexpr,
):
# Per-sequence "shift + gather" for hidden states.
#
# This kernel implements the same logical transformation as
# _shift_and_gather_cache_1d_kernel, but operates on hidden states with
# shape [num_tokens, hidden_size].
#
# Layout:
# - src_ptr / dst_ptr: packed hidden states [num_tokens, hidden_size]
# - cached_ptr: per-sequence cache [batch_size, MAX_SHIFT, hidden_size]
#
# For each sequence window [start, end] (inclusive) and its shift value, for
# 0-based index i within the window:
# - Prefix (i < shift):
# dst[start + i, :] = cached[seq, cached_len - shift + i, :]
# - Body (i >= shift):
# dst[start + i, :] = src[start + i - shift, :]
#
# We tile over tokens (BLOCK_TOKENS) and hidden dim (BLOCK_HIDDEN) to avoid
# extremely large Triton tiles when hidden_size is large. As in the 1D
# kernel, we refresh cached_ptr[seq, :, :] with a suffix of dst so the next
# call can populate its prefix from cache.
pid_seq = tl.program_id(0)
pid_blk = tl.program_id(1)
pid_hid = tl.program_id(2)
start = tl.load(start_ptr + pid_seq).to(tl.int32)
end = tl.load(end_ptr + pid_seq).to(tl.int32)
shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
assert cached_len >= shift
# get dst indices
base = pid_blk * BLOCK_TOKENS
k = tl.arange(0, BLOCK_TOKENS)
tok_offs = base + k
dst_tok = start + tok_offs
n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)
dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
# get dst mask
window_len = end - start + 1
tok_mask = tok_offs < window_len
n_mask = n < HIDDEN_SIZE
mask = tok_mask[:, None] & n_mask[None, :]
# load from cached
base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT
cached_tok = cached_len - shift + tok_offs
cached_ptrs = base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
cached_mask = tok_offs < shift
val_cached = tl.load(cached_ptrs, mask=mask & cached_mask[:, None], other=0)
# load from src
src_tok = start + tok_offs - shift
src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0)
# store to dst
val = tl.where(cached_mask[:, None], val_cached, val_src)
tl.store(dst_ptrs, val, mask=mask)
# store to cached
store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
m = tl.arange(0, PADDED_SHIFT)
m_mask = (m < MAX_SHIFT) & (m < store_len)
store_tok = store_start + m
dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
store_ptrs = base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1
mask = m_mask[:, None] & n_mask[None, :]
val = tl.load(dst_ptrs, mask=mask, other=0)
tl.store(store_ptrs, val, mask=mask)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.
This version uses a fully vectorized approach with unfold and argmax for
finding the first match across all sequences in parallel.
"""
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
)
from vllm.forward_context import set_forward_context
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@support_torch_compile()
class NgramGPUKernel(nn.Module):
"""GPU-accelerated N-gram proposer using fully async tensor operations."""
def __init__(
self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda"
):
super().__init__()
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.prompt_lookup_min is not None
assert vllm_config.speculative_config.prompt_lookup_max is not None
self.min_n = vllm_config.speculative_config.prompt_lookup_min
self.max_n = vllm_config.speculative_config.prompt_lookup_max
self.k = vllm_config.speculative_config.num_speculative_tokens
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.device = device
def _find_first_and_extract_all_n_parallel(
self,
token_ids: torch.Tensor,
seq_lengths: torch.Tensor,
min_ngram_len: int,
max_ngram_len: int,
num_draft_tokens: int,
) -> torch.Tensor:
"""
Find suffix n-gram matches and extract following tokens.
Searches for the earliest prior occurrence of the trailing n-gram,
tries multiple lengths, and picks the longest valid match.
Args:
token_ids: Token IDs for each sequence
seq_lengths: Actual length of each sequence (excluding padding)
min_ngram_len: Minimum n-gram size to search for (e.g., 2)
max_ngram_len: Maximum n-gram size to search for (e.g., 5)
num_draft_tokens: Number of tokens to extract after match (k)
Returns:
Draft token predictions; -1 means invalid/no match.
"""
batch_size = token_ids.shape[0]
max_seq_len = token_ids.shape[1]
device = token_ids.device
num_ngram_sizes = max_ngram_len - min_ngram_len + 1
# All n-gram sizes to try.
ngram_lengths = torch.arange(min_ngram_len, max_ngram_len + 1, device=device)
batch_indices = torch.arange(batch_size, device=device)
# Earliest match per (sequence, ngram_len); -1 means no match.
first_match_positions = torch.full(
(batch_size, num_ngram_sizes), -1, dtype=torch.long, device=device
)
for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)):
# Sliding windows of size ngram_len; unfold is O(1) view.
search_windows = token_ids.unfold(1, ngram_len, 1)
num_windows = search_windows.shape[1]
# Trailing suffix (last ngram_len tokens) for each sequence.
suffix_starts = seq_lengths - ngram_len
suffix_indices = suffix_starts.unsqueeze(1) + torch.arange(
ngram_len, device=device
)
suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0))
# Window matches for each sequence.
matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)
# Match must leave room for at least one draft token.
max_valid_suffix_start = seq_lengths - ngram_len - 1
window_positions = torch.arange(num_windows, device=device)
valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1)
final_matches = matches & valid_mask
# Find earliest match (argmax=0 when empty; verify with has_match).
first_match_idx = torch.argmax(final_matches.int(), dim=1)
has_match = final_matches[batch_indices, first_match_idx]
# Store valid match positions (window index = position).
first_match_positions[:, i] = torch.where(has_match, first_match_idx, -1)
# Select the longest n-gram with a match.
best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1)
best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx # Flip back
# Match position for the best n-gram.
best_match_pos = first_match_positions[batch_indices, best_ngram_idx]
# Avoid data-dependent branching.
has_any_match = best_match_pos >= 0
# Length of the best matching n-gram.
best_ngram_lengths = ngram_lengths[best_ngram_idx]
# Start position right after the matched suffix.
draft_start = torch.where(
has_any_match,
best_match_pos + best_ngram_lengths,
torch.zeros_like(best_match_pos),
)
tokens_available = seq_lengths - draft_start
# Gather indices for draft tokens.
draft_indices = draft_start.unsqueeze(1) + torch.arange(
num_draft_tokens, device=device
)
draft_indices = draft_indices.clamp(min=0, max=max_seq_len - 1)
# Extract draft tokens; gather always runs.
draft_tokens = torch.gather(token_ids, 1, draft_indices)
# Mask positions beyond available tokens.
position_indices = torch.arange(num_draft_tokens, device=device).unsqueeze(0)
valid_positions = position_indices < tokens_available.unsqueeze(1)
draft_tokens = torch.where(
valid_positions,
draft_tokens,
torch.full_like(draft_tokens, -1),
)
# If no match, mask all positions.
draft_tokens = torch.where(
has_any_match.unsqueeze(1),
draft_tokens,
torch.full_like(draft_tokens, -1),
)
return draft_tokens
def forward(
self,
num_tokens_no_spec: torch.Tensor,
token_ids_gpu: torch.Tensor,
combined_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for N-gram proposal using GPU tensor operations.
Args:
num_tokens_no_spec: Number of tokens for each sequence [batch_size]
token_ids_gpu: Token IDs [batch_size, max_len]
combined_mask: Whether each sequence is valid for spec decode [batch_size]
Returns:
draft_tokens: [batch_size, k] on GPU
num_valid_draft_tokens: [batch_size] int32 on GPU, count of
leading valid (non -1) tokens per request.
"""
device = token_ids_gpu.device
# Infer batch size to preserve dynamic shape.
actual_batch_size = token_ids_gpu.shape[0]
# Allocate in forward so torch.compile can optimize.
# NOTE(patchy): Do NOT pre-allocate this as a buffer
# it breaks torch.compile
draft_tokens = torch.full(
(actual_batch_size, self.k), -1, dtype=torch.int32, device=device
)
results = self._find_first_and_extract_all_n_parallel(
token_ids_gpu,
num_tokens_no_spec,
min_ngram_len=self.min_n,
max_ngram_len=self.max_n,
num_draft_tokens=self.k,
)
draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)
# Count leading contiguous valid (non -1) tokens per request.
is_valid = draft_tokens != -1 # [batch, k]
cum_valid = is_valid.int().cumsum(dim=1) # [batch, k]
positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)
return draft_tokens, num_valid_draft_tokens
def load_model(self, *args, **kwargs):
"""No model to load for N-gram proposer."""
pass
class NgramProposerGPU:
def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.prompt_lookup_min is not None
assert vllm_config.speculative_config.prompt_lookup_max is not None
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["none"],
splitting_ops=[],
compile_sizes=[],
inductor_compile_config={
"enable_auto_functionalized_v2": False,
"max_autotune": True,
"aggressive_fusion": True,
"triton.autotune_pointwise": True,
"coordinate_descent_tuning": True,
"use_mixed_mm": False,
},
cudagraph_mode=CUDAGraphMode.NONE,
)
model_config = vllm_config.model_config
speculative_config = vllm_config.speculative_config
scheduler_config = vllm_config.scheduler_config
self.vllm_config = VllmConfig(
compilation_config=compilation_config,
model_config=model_config,
speculative_config=speculative_config,
scheduler_config=scheduler_config,
)
self.min_n = vllm_config.speculative_config.prompt_lookup_min
self.max_n = vllm_config.speculative_config.prompt_lookup_max
self.k = vllm_config.speculative_config.num_speculative_tokens
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.device = device
self.kernel = NgramGPUKernel(
vllm_config=self.vllm_config, prefix="ngram_gpu_kernel", device=device
)
self.kernel.to(device)
self.kernel.eval()
self._dummy_run()
def _dummy_run(self):
token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
batch_size=self.max_num_seqs,
max_seq_len=self.max_model_len,
pattern_len=self.k,
device=self.device,
)
combined_mask = sampled_flags & valid_mask & (num_tokens >= self.min_n)
for _ in range(3):
with set_forward_context(None, self.vllm_config):
_, _ = self.kernel(num_tokens, token_ids, combined_mask)
def _generate_dummy_data(
self,
batch_size: int,
max_seq_len: int,
pattern_len: int,
device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Generate random test data with n-gram repetitions.
Args:
batch_size: Number of sequences in the batch
max_seq_len: Maximum sequence length
pattern_len: Length of patterns to inject for matching
device: Device to place tensors on
Returns:
token_ids: [batch_size, max_seq_len] tensor
num_tokens: [batch_size] tensor
sampled_flags: [batch_size] bool tensor
valid_mask: [batch_size] bool tensor
"""
token_ids = torch.zeros(
batch_size,
max_seq_len,
dtype=torch.int32,
device=device,
)
num_tokens = torch.randint(
pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
)
sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
return token_ids, num_tokens, sampled_flags, valid_mask
def propose(
self,
num_tokens_no_spec: torch.Tensor, # [batch_size]
token_ids_gpu: torch.Tensor, # [batch_size, max_len]
valid_sampled_token_ids_gpu: torch.Tensor, # [batch_size, num_spec_tokens + 1]
valid_sampled_tokens_count: torch.Tensor, # [batch_size]
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Propose draft tokens using GPU-accelerated n-gram matching.
Scatter sampled tokens into `token_ids_gpu`, compute temporary
updated lengths, then run the kernel.
Args:
num_tokens_no_spec: Number of tokens per sequence (read-only)
token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
valid_sampled_tokens_count: Count of valid tokens per sequence
Returns:
draft_tokens: Proposed draft token IDs [batch_size, k]
num_valid_draft_tokens: Count of leading valid draft tokens
per request [batch_size]
"""
assert token_ids_gpu.device == self.device
assert num_tokens_no_spec.device == self.device
batch_size = num_tokens_no_spec.shape[0]
max_seq_len = token_ids_gpu.shape[1]
max_new_tokens = valid_sampled_token_ids_gpu.shape[1] # num_spec_tokens + 1
# Scatter newly sampled tokens into token_ids_gpu.
offsets = torch.arange(max_new_tokens, device=self.device)
write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0)
valid_write_mask = offsets.unsqueeze(0) < valid_sampled_tokens_count.unsqueeze(
1
)
in_bounds = write_positions < max_seq_len
scatter_mask = (
valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds
)
write_positions_long = write_positions.clamp(max=max_seq_len - 1).long()
existing_values = token_ids_gpu.gather(1, write_positions_long)
tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype)
tokens_to_scatter = torch.where(
scatter_mask,
tokens_cast,
existing_values,
)
token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)
num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count
# Compute validity masks.
sampled_flags = valid_sampled_tokens_count > 0
valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
with set_forward_context(None, self.vllm_config):
combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n)
with record_function_or_nullcontext("ngram_proposer_gpu: kernel"):
draft_tokens, num_valid_draft_tokens = self.kernel(
num_tokens_tmp,
token_ids_gpu,
combined_mask,
)
return draft_tokens, num_valid_draft_tokens
def update_token_ids_ngram(
self,
sampled_token_ids: torch.Tensor | list[list[int]],
gpu_input_batch: InputBatch,
token_ids_gpu: torch.Tensor,
num_tokens_no_spec: torch.Tensor,
discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepare speculative decoding inputs on device:
compute next token ids and valid counts, honoring discarded requests
and rejected tokens, without CPU-GPU sync.
"""
num_reqs = gpu_input_batch.num_reqs
if isinstance(sampled_token_ids, list):
# When disable_padded_drafter_batch=True, sampled_token_ids is
# an irregular list[list[int]] where sublists may have different
# lengths (including empty lists for discarded requests).
# Pad all sublists to the same length with -1 before converting
# to tensor.
max_len = max(
(len(sublist) for sublist in sampled_token_ids),
default=0,
)
# Ensure at least length 1 for tensor creation
max_len = max(max_len, 1)
padded_list = [
sublist + [-1] * (max_len - len(sublist))
for sublist in sampled_token_ids
]
sampled_token_ids = torch.tensor(
padded_list, dtype=torch.int32, device=self.device
)
assert isinstance(sampled_token_ids, torch.Tensor), (
"sampled_token_ids should be a torch.Tensor for ngram_gpu"
)
# Backup last valid token before speculative tokens.
backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
backup_next_token_ids = torch.gather(
token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
).squeeze(1)
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
# Invalidate sampled tokens for discarded requests.
discard_mask_expanded = discard_request_mask[:num_reqs].unsqueeze(1)
valid_sampled_token_ids_gpu.masked_fill_(discard_mask_expanded, -1)
# Mask valid tokens within each request.
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
)
# Count valid tokens per request.
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Rightmost valid index per row.
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Last valid token from each row; undefined if none.
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
).squeeze(1)
# Use last token if valid; otherwise fallback to backup.
next_token_ids = torch.where(
last_valid_indices != -1,
selected_tokens,
backup_next_token_ids,
)
return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu
def load_model(self, *args, **kwargs):
self.kernel.load_model(*args, **kwargs)
def update_scheduler_for_invalid_drafts(
num_valid_draft_tokens_event: torch.cuda.Event,
num_valid_draft_tokens_cpu: torch.Tensor,
scheduler_output: "SchedulerOutput",
req_id_to_index: dict[str, int],
) -> None:
"""Trim invalid speculative slots using per-request valid draft counts.
Args:
num_valid_draft_tokens_event: Event for async D2H completion.
num_valid_draft_tokens_cpu: CPU buffer of valid draft counts.
scheduler_output: Scheduler metadata to update in-place.
req_id_to_index: Request-id to batch-index mapping.
"""
req_data = scheduler_output.scheduled_cached_reqs
num_valid_draft_tokens_event.synchronize()
for req_id in req_data.req_ids:
req_index = req_id_to_index.get(req_id)
if req_index is None:
continue
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
if spec_token_ids is None:
continue
scheduled_k = len(spec_token_ids)
valid_k = int(num_valid_draft_tokens_cpu[req_index].item())
valid_k = max(0, min(valid_k, scheduled_k))
tokens_to_trim = scheduled_k - valid_k
scheduler_output.total_num_scheduled_tokens -= tokens_to_trim
scheduler_output.num_scheduled_tokens[req_id] -= tokens_to_trim
if valid_k == 0:
scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
else:
scheduler_output.scheduled_spec_decode_tokens[req_id] = spec_token_ids[
:valid_k
]
def update_ngram_gpu_tensors_incremental(
input_batch: InputBatch,
token_ids_gpu_tensor: torch.Tensor,
num_tokens_no_spec_gpu: torch.Tensor,
new_reqs: list[CachedRequestState],
device: torch.device,
_pinned_idx_buf: torch.Tensor,
_pinned_val_buf: torch.Tensor,
) -> None:
"""Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
for ngram GPU proposer.
"""
prev_req_id_to_index = input_batch.prev_req_id_to_index
curr_req_id_to_index = input_batch.req_id_to_index
if not curr_req_id_to_index:
return
active_indices = list(curr_req_id_to_index.values())
n_active = len(active_indices)
# Use resident pinned buffers to avoid per-call allocation.
active_idx_cpu = _pinned_idx_buf[:n_active]
active_idx_cpu.copy_(torch.as_tensor(active_indices, dtype=torch.long))
active_idx_gpu = active_idx_cpu.to(device=device, non_blocking=True)
new_req_ids = {req.req_id for req in new_reqs}
# First run, no previous state.
if prev_req_id_to_index is None:
for idx in active_indices:
num_tokens = input_batch.num_tokens_no_spec[idx]
if num_tokens > 0:
token_ids_gpu_tensor[idx, :num_tokens].copy_(
input_batch.token_ids_cpu_tensor[idx, :num_tokens],
non_blocking=True,
)
_sync_num_tokens(
input_batch,
num_tokens_no_spec_gpu,
active_idx_cpu,
active_idx_gpu,
n_active,
device,
_pinned_val_buf,
)
return
# Detect index changes for reorder.
reorder_src: list[int] = []
reorder_dst: list[int] = []
for req_id, curr_idx in curr_req_id_to_index.items():
if req_id in new_req_ids:
continue
prev_idx = prev_req_id_to_index.get(req_id)
if prev_idx is not None and prev_idx != curr_idx:
reorder_src.append(prev_idx)
reorder_dst.append(curr_idx)
if reorder_src:
src_tensor = torch.tensor(reorder_src, dtype=torch.long, device=device)
dst_tensor = torch.tensor(reorder_dst, dtype=torch.long, device=device)
temp_token_ids = token_ids_gpu_tensor[src_tensor].clone()
temp_num_tokens = num_tokens_no_spec_gpu[src_tensor].clone()
token_ids_gpu_tensor[dst_tensor] = temp_token_ids
num_tokens_no_spec_gpu[dst_tensor] = temp_num_tokens
# Full copy for new/resumed requests.
for req_state in new_reqs:
new_req_idx = curr_req_id_to_index.get(req_state.req_id)
if new_req_idx is None:
continue
num_tokens = input_batch.num_tokens_no_spec[new_req_idx]
if num_tokens > 0:
token_ids_gpu_tensor[new_req_idx, :num_tokens].copy_(
input_batch.token_ids_cpu_tensor[new_req_idx, :num_tokens],
non_blocking=True,
)
# Always batch-sync sequence lengths from CPU for ALL active requests.
_sync_num_tokens(
input_batch,
num_tokens_no_spec_gpu,
active_idx_cpu,
active_idx_gpu,
n_active,
device,
_pinned_val_buf,
)
def _sync_num_tokens(
input_batch: InputBatch,
num_tokens_no_spec_gpu: torch.Tensor,
active_idx_cpu: torch.Tensor,
active_idx_gpu: torch.Tensor,
n_active: int,
device: torch.device,
_pinned_val_buf: torch.Tensor,
) -> None:
"""Batch-sync GPU sequence lengths from CPU source of truth.
Inputs:
input_batch: Batch container with CPU length tensor.
num_tokens_no_spec_gpu: Destination GPU length tensor.
active_idx_cpu: Active request indices on CPU.
active_idx_gpu: Active request indices on GPU.
n_active: Number of active requests.
device: Target CUDA device.
_pinned_val_buf: Resident pinned int32 staging buffer.
Outputs:
None (updates num_tokens_no_spec_gpu in-place).
"""
src_cpu = input_batch.num_tokens_no_spec_cpu_tensor
vals = _pinned_val_buf[:n_active]
vals.copy_(src_cpu.index_select(0, active_idx_cpu))
num_tokens_no_spec_gpu.index_copy_(
0,
active_idx_gpu,
vals.to(device=device, non_blocking=True),
)
def copy_num_valid_draft_tokens(
num_valid_draft_tokens_cpu: torch.Tensor,
num_valid_draft_tokens_copy_stream: torch.cuda.Stream,
num_valid_draft_tokens_event: torch.cuda.Event,
num_valid_draft_tokens: torch.Tensor | None,
batch_size: int,
) -> None:
"""
Async D2H copy of per-request valid draft counts.
"""
if num_valid_draft_tokens is None:
return
num_reqs_to_copy = min(batch_size, num_valid_draft_tokens.shape[0])
if num_reqs_to_copy <= 0:
return
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(num_valid_draft_tokens_copy_stream):
num_valid_draft_tokens_copy_stream.wait_stream(default_stream)
num_valid_draft_tokens_cpu[:num_reqs_to_copy].copy_(
num_valid_draft_tokens[:num_reqs_to_copy], non_blocking=True
)
num_valid_draft_tokens_event.record()
......@@ -5,11 +5,7 @@ import torch
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
)
PADDING_SLOT_ID = -1
@triton.jit
def eagle_prepare_inputs_padded_kernel(
......@@ -186,219 +182,3 @@ class DraftProbs(ABC): # type: ignore[call-arg]
target_device=self.draft_probs.device,
pin_memory=True)
return self.draft_probs[index_tensor]
def compute_new_slot_mapping(
cad: CommonAttentionMetadata,
new_positions: torch.Tensor,
is_rejected_token_mask: torch.Tensor,
block_size: int,
num_new_tokens: int,
max_model_len: int,
):
batch_size, n_blocks_per_req = cad.block_table_tensor.shape
req_indices = torch.arange(batch_size, device=cad.query_start_loc.device)
req_indices = torch.repeat_interleave(
req_indices,
cad.naive_query_lens() + num_new_tokens,
output_size=len(new_positions),
)
# Clamp the positions to prevent an out-of-bounds error when indexing
# into block_table_tensor.
clamped_positions = torch.clamp(new_positions, max=max_model_len - 1)
block_table_indices = (
req_indices * n_blocks_per_req + clamped_positions // block_size
)
block_nums = cad.block_table_tensor.view(-1)[block_table_indices]
block_offsets = clamped_positions % block_size
new_slot_mapping = block_nums * block_size + block_offsets
# Mask out the position ids that exceed the max model length.
exceeds_max_model_len = new_positions >= max_model_len
new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
# Mask out rejected tokens to prevent saves to the KV cache.
new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
return new_slot_mapping
def extend_all_queries_by_N(
common_attn_metadata: CommonAttentionMetadata,
N: int,
arange: torch.Tensor,
new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
"""
Creates a new CommonAttentionMetadata with all query lengths increased by N.
Also all seq lens are increased by N.
This is useful e.g. in speculative decoding with parallel drafting, where we
extend each sequence by N tokens and predict all tokens in one pass.
The slot mapping is computed externally, as it requires more information.
"""
cad = common_attn_metadata
# query start loc must be increased by [+0, +N, +2N, ..., +batch_size * N]
new_query_start_loc = cad.query_start_loc + N * arange[: len(cad.query_start_loc)]
new_query_start_loc_cpu = cad.query_start_loc_cpu + N * torch.arange(
len(cad.query_start_loc_cpu), dtype=torch.int32
)
new_cad = cad.replace(
query_start_loc=new_query_start_loc,
query_start_loc_cpu=new_query_start_loc_cpu,
seq_lens=cad.seq_lens + N,
# each request is extended by N tokens -> batch_size * N tokens are added
num_actual_tokens=cad.num_actual_tokens + cad.batch_size() * N,
# All query lens increase by N, so max query len increases by N
max_query_len=cad.max_query_len + N,
max_seq_len=cad.max_seq_len + N,
slot_mapping=new_slot_mapping,
)
return new_cad
# Unified copy/expand kernel
@triton.jit
def copy_and_expand_eagle_inputs_kernel(
# (Padded) Inputs from the target model
target_token_ids_ptr, # [total_tokens_in_batch]
target_positions_ptr, # [total_tokens_in_batch]
next_token_ids_ptr, # [num_reqs]
# Outputs to the drafting buffers
out_input_ids_ptr, # [total_draft_tokens_in_batch] (output)
out_positions_ptr, # [total_draft_tokens_in_batch] (output)
out_is_rejected_token_mask_ptr, # [total_draft_tokens_in_batch] (output)
out_is_masked_token_mask_ptr, # [total_draft_tokens_in_batch] (output)
out_new_token_indices_ptr, # [num_padding_slots_per_request * num_reqs] (output)
out_hidden_state_mapping_ptr, # [total_tokens_in_batch]
# Input metadata
query_start_loc_ptr, # [num_reqs + 1], last value is the total num input tokens
query_end_loc_ptr, # [num_reqs]
padding_token_id, # tl.int32
parallel_drafting_token_id, # tl.int32
# Sizing info
total_input_tokens, # tl.int32
num_padding_slots_per_request, # tl.int32
shift_input_ids, # tl.bool
BLOCK_SIZE_TOKENS: tl.constexpr, # Blocks along token dim to handle prefills
):
"""
Copy and expand inputs from the target model to the drafting buffers for Eagle
speculative decoding. This kernel handles padding slots and parallel drafting
tokens, if enabled.
"""
request_idx = tl.program_id(axis=0)
token_batch_idx = tl.program_id(axis=1)
# Load query locations
query_start_loc = tl.load(query_start_loc_ptr + request_idx)
next_query_start_loc = tl.load(query_start_loc_ptr + request_idx + 1)
query_end_loc = tl.load(query_end_loc_ptr + request_idx)
# Calculate number of valid tokens to copy and input offset
# With shift_input_ids=True, we skip the first token
# Output layout: each request gets (input_len + num_padding_slots_per_request) slots
# But with shift, we lose one token per request
if shift_input_ids:
num_valid_tokens = query_end_loc - query_start_loc
input_offset = 1
output_start = query_start_loc + request_idx * (
num_padding_slots_per_request - 1
)
else:
num_valid_tokens = query_end_loc - query_start_loc + 1
input_offset = 0
output_start = query_start_loc + request_idx * num_padding_slots_per_request
# Number of rejected tokens from previous speculation
num_rejected = next_query_start_loc - query_end_loc - 1
# Total output tokens for this request
total_output_tokens = (
num_valid_tokens + num_padding_slots_per_request + num_rejected
)
# Process tokens in this block
j = token_batch_idx * BLOCK_SIZE_TOKENS + tl.arange(0, BLOCK_SIZE_TOKENS)
# Compute masks for different output regions:
# [0, num_valid_tokens): valid tokens copied from input
# [num_valid_tokens]: bonus token from next_token_ids
# (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request):
# parallel drafting slots
# [num_valid_tokens + num_padding_slots_per_request, total_output_tokens):
# rejected slots
in_bounds = j < total_output_tokens
is_valid_region = j < num_valid_tokens
is_bonus_region = j == num_valid_tokens
is_parallel_draft_region = (j > num_valid_tokens) & (
j < num_valid_tokens + num_padding_slots_per_request
)
is_rejected_region = j >= num_valid_tokens + num_padding_slots_per_request
# Compute output indices
out_idx = output_start + j
# For valid tokens, compute input index
in_idx = query_start_loc + input_offset + j
# Clamp to avoid out-of-bounds access (masked loads still need valid addresses)
in_idx_clamped = tl.minimum(in_idx, total_input_tokens - 1)
# Load input tokens (masked to valid region)
token_ids = tl.load(
target_token_ids_ptr + in_idx_clamped, mask=is_valid_region & in_bounds, other=0
)
# Load the starting position for this request (first position in the sequence)
start_pos = tl.load(target_positions_ptr + query_start_loc)
# Load bonus token for this request
bonus_token = tl.load(next_token_ids_ptr + request_idx)
# Build final token_ids based on region
token_ids = tl.where(is_bonus_region, bonus_token, token_ids)
token_ids = tl.where(
is_parallel_draft_region, parallel_drafting_token_id, token_ids
)
token_ids = tl.where(is_rejected_region, padding_token_id, token_ids)
# Build final positions:
# Positions are NOT shifted - they start from the first input position and increment
# Output position j gets start_pos + j
# (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...])
positions = start_pos + j
# Rejected positions are don't-care, set to 0
positions = tl.where(is_rejected_region, 0, positions)
# Compute output masks
is_rejected_out = is_rejected_region & in_bounds
is_masked_out = is_parallel_draft_region & in_bounds
# Compute indices of new tokens (bonus + parallel drafting) for sampling
# New tokens are at positions
# [num_valid_tokens, num_valid_tokens + num_padding_slots_per_request)
is_new_token_region = (j >= num_valid_tokens) & (
j < num_valid_tokens + num_padding_slots_per_request
)
new_token_local_idx = (
j - num_valid_tokens
) # 0 for bonus, 1, 2, ... for parallel drafting
new_token_out_idx = (
request_idx * num_padding_slots_per_request + new_token_local_idx
)
# Compute hidden state mapping (source index -> destination index)
# This maps each input position to its corresponding output position
# Hidden states don't get shifted, so we map all input tokens (including rejected)
if shift_input_ids:
num_input_tokens_this_request = next_query_start_loc - query_start_loc
is_input_region = j < num_input_tokens_this_request
src_idx = query_start_loc + j
tl.store(out_hidden_state_mapping_ptr + src_idx, out_idx, mask=is_input_region)
# Store outputs
tl.store(out_input_ids_ptr + out_idx, token_ids, mask=in_bounds)
tl.store(out_positions_ptr + out_idx, positions, mask=in_bounds)
tl.store(out_is_rejected_token_mask_ptr + out_idx, is_rejected_out, mask=in_bounds)
tl.store(out_is_masked_token_mask_ptr + out_idx, is_masked_out, mask=in_bounds)
tl.store(
out_new_token_indices_ptr + new_token_out_idx,
out_idx,
mask=is_new_token_region & in_bounds,
)
\ No newline at end of file
......@@ -61,13 +61,6 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None
# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds
......@@ -110,8 +103,6 @@ class InputBatch:
is_spec_decode: bool = False,
is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
):
ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
......@@ -232,45 +223,7 @@ class InputBatch:
(max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {}
......@@ -511,13 +464,6 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1
if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions
# Add request lora ID
if request.lora_request:
lora_id = request.lora_request.lora_int_id
......@@ -716,20 +662,6 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1],
)
if self.multi_layer_eagle_num > 0:
self.cached_len[i1], self.cached_len[i2] = (
self.cached_len[i2],
self.cached_len[i1],
)
self.cached_token_ids[[i1, i2], ...] = self.cached_token_ids[[i2, i1], ...]
self.cached_hidden_states[[i1, i2], ...] = self.cached_hidden_states[
[i2, i1], ...
]
self.cached_slot_mappings[[i1, i2], ...] = self.cached_slot_mappings[
[i2, i1], ...
]
self.cached_positions[[i1, i2], ...] = self.cached_positions[[i2, i1], ...]
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
......@@ -852,21 +784,6 @@ class InputBatch:
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[last_req_index]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
......
......@@ -149,15 +149,8 @@ from vllm.v1.sample.rejection_sampler_opt import OptRejectionSampler
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata, MultiLayerEagleMetadata
from vllm.v1.spec_decode.ngram_proposer_gpu import (
copy_num_valid_draft_tokens,
# update_ngram_gpu_tensors_incremental,
# update_scheduler_for_invalid_drafts,
)
from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
......@@ -323,7 +316,6 @@ class ExecuteModelState(NamedTuple):
scheduler_output: "SchedulerOutput"
logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
......@@ -344,7 +336,6 @@ class GPUModelRunner(
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
# self.offload_config = vllm_config.offload_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
......@@ -426,9 +417,6 @@ class GPUModelRunner(
# Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# multi layer eagle
self.enable_multi_layer_eagle = False
self.eplb_state: EplbState | None = None
"""
State of the expert parallelism load balancer.
......@@ -451,9 +439,6 @@ class GPUModelRunner(
self.encoder_cache: dict[str, torch.Tensor] = {}
self.use_aux_hidden_state_outputs = False
self.multi_layer_eagle_num = 0
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
......@@ -465,7 +450,6 @@ class GPUModelRunner(
| EagleProposer
| DraftModelProposer
| MedusaProposer
| ExtractHiddenStatesProposer
)
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
......@@ -478,19 +462,7 @@ class GPUModelRunner(
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
if (
self.speculative_config.enable_multi_layers_mtp
and self.speculative_config.method == "mtp"
):
self.enable_multi_layer_eagle = True
self.drafter = MultiLayerEagleProposer(
self.vllm_config, self.device, self
)
self.multi_layer_eagle_num = self.drafter.layer_num
else:
self.drafter = EagleProposer(self.vllm_config, self.device, self)
# self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = (
self.drafter.eagle3_use_aux_hidden_state
......@@ -499,11 +471,6 @@ class GPUModelRunner(
self.drafter = MedusaProposer(
vllm_config=self.vllm_config, device=self.device
)
elif self.speculative_config.method == "extract_hidden_states":
self.drafter = ExtractHiddenStatesProposer(
vllm_config=self.vllm_config, device=self.device
)
self.use_aux_hidden_state_outputs = True
else:
raise ValueError(
"Unknown speculative decoding method: "
......@@ -568,10 +535,6 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
multi_layer_eagle_num=self.multi_layer_eagle_num
if self.enable_multi_layer_eagle
else 0,
hidden_size=self.model_config.get_hidden_size(),
)
# Separate cuda stream for overlapping transfer of sampled token ids from
......@@ -660,7 +623,6 @@ class GPUModelRunner(
(3, self.max_num_tokens + 1), dtype=torch.int64
)
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0:
# Similar to mrope but use assigned dimension number for RoPE, 4 as default.
......@@ -843,6 +805,7 @@ class GPUModelRunner(
pin_memory=self.pin_memory,
with_numpy=numpy,
)
def _copy_mrope_positions_to_gpu(self, num_tokens: int) -> None:
if not self.uses_mrope:
return
......@@ -853,6 +816,10 @@ class GPUModelRunner(
non_blocking=True,
)
return
# self.mrope_positions.gpu[:, :num_tokens].copy_(
# self.mrope_positions.cpu[:, :num_tokens],
# non_blocking=True,
# )
self.mrope_positions.gpu[:, :num_tokens].copy_(
self.mrope_positions.cpu[:, :num_tokens].contiguous().pin_memory(),
non_blocking=True,
......@@ -1051,9 +1018,6 @@ class GPUModelRunner(
if self.uses_xdrope_dim > 0:
self._init_xdrope_positions(req_state)
if self.enable_multi_layer_eagle:
self._init_multi_layer_eagle_cache(req_state)
reqs_to_add.append(req_state)
# Update the states of the running/resumed requests.
......@@ -1305,24 +1269,6 @@ class GPUModelRunner(
req_state.mm_features,
)
def _init_multi_layer_eagle_cache(self, req_state: CachedRequestState):
req_state.cached_len = torch.zeros(1, dtype=torch.int64, device=self.device)
req_state.cached_hidden_states = torch.zeros(
self.multi_layer_eagle_num,
self.model_config.get_hidden_size(),
dtype=self.dtype,
device=self.device,
)
req_state.cached_token_ids = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int32, device=self.device
)
req_state.cached_positions = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int64, device=self.device
)
req_state.cached_slot_mappings = torch.zeros(
self.multi_layer_eagle_num, dtype=torch.int64, device=self.device
)
def _extract_mm_kwargs(
self,
scheduler_output: "SchedulerOutput",
......@@ -1747,17 +1693,6 @@ class GPUModelRunner(
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
self.num_decode_draft_tokens.copy_to_gpu()
if self.enable_multi_layer_eagle:
multi_layer_eagle_metadata = MultiLayerEagleMetadata(
cached_len=self.input_batch.cached_len[:num_reqs],
cached_token_ids=self.input_batch.cached_token_ids[:num_reqs],
cached_hidden_states=self.input_batch.cached_hidden_states[:num_reqs],
cached_slot_mappings=self.input_batch.cached_slot_mappings[:num_reqs],
cached_positions=self.input_batch.cached_positions[:num_reqs],
)
else:
multi_layer_eagle_metadata = None
# Hot-Swap lora model
if self.lora_config:
assert (
......@@ -1768,11 +1703,10 @@ class GPUModelRunner(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
# return (
# logits_indices,
# spec_decode_metadata,
# )
return (logits_indices, spec_decode_metadata, multi_layer_eagle_metadata)
return (
logits_indices,
spec_decode_metadata,
)
def _build_attention_metadata(
self,
......@@ -2238,9 +2172,9 @@ class GPUModelRunner(
req.mrope_positions[:, src_start:src_end].transpose(0, 1)
)
else:
self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[
:, src_start:src_end
]
self.mrope_positions.cpu[:, dst_start:dst_end] = (
req.mrope_positions[:, src_start:src_end]
)
mrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
......@@ -2251,7 +2185,9 @@ class GPUModelRunner(
assert req.mrope_position_delta is not None
if self.use_1d_mrope:
values = np.arange(
req.mrope_position_delta + num_computed_tokens + prompt_part_len,
req.mrope_position_delta
+ num_computed_tokens
+ prompt_part_len,
req.mrope_position_delta
+ num_computed_tokens
+ prompt_part_len
......@@ -3525,16 +3461,10 @@ class GPUModelRunner(
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
# logits_indices, spec_decode_metadata = self._prepare_inputs(
# scheduler_output,
# num_scheduled_tokens_np,
# )
logits_indices, spec_decode_metadata, multi_layer_eagle_metadata = (
self._prepare_inputs(
logits_indices, spec_decode_metadata = self._prepare_inputs(
scheduler_output,
num_scheduled_tokens_np,
)
)
cascade_attn_prefix_lens = None
# Disable cascade attention when using microbatching (DBO)
......@@ -3760,7 +3690,6 @@ class GPUModelRunner(
scheduler_output,
logits,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
......@@ -3798,7 +3727,6 @@ class GPUModelRunner(
scheduler_output,
logits,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
......@@ -3838,7 +3766,6 @@ class GPUModelRunner(
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata,
slot_mappings,
)
......@@ -4042,233 +3969,6 @@ class GPUModelRunner(
sampled_count_event.synchronize()
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
# def propose_draft_token_ids(
# self,
# scheduler_output: "SchedulerOutput",
# sampled_token_ids: torch.Tensor | list[list[int]],
# sampling_metadata: SamplingMetadata,
# hidden_states: torch.Tensor,
# sample_hidden_states: torch.Tensor,
# aux_hidden_states: list[torch.Tensor] | None,
# spec_decode_metadata: SpecDecodeMetadata | None,
# # multi_layer_eagle_metadata: MultiLayerEagleMetadata | None,
# common_attn_metadata: CommonAttentionMetadata,
# slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
# ) -> list[list[int]] | torch.Tensor:
# num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# spec_config = self.speculative_config
# assert spec_config is not None
# if spec_config.method == "ngram":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, NgramProposer)
# draft_token_ids = self.drafter.propose(
# sampled_token_ids,
# self.input_batch.num_tokens_no_spec,
# self.input_batch.token_ids_cpu,
# slot_mappings=slot_mappings,
# )
# elif spec_config.method == "suffix":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, SuffixDecodingProposer)
# draft_token_ids = self.drafter.propose(
# self.input_batch, sampled_token_ids, slot_mappings=slot_mappings
# )
# elif spec_config.method == "medusa":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, MedusaProposer)
# if sample_hidden_states.shape[0] == len(sampled_token_ids):
# # The input to the target model does not include draft tokens.
# hidden_states = sample_hidden_states
# else:
# indices = []
# offset = 0
# assert spec_decode_metadata is not None, (
# "No spec decode metadata for medusa"
# )
# for num_draft, tokens in zip(
# spec_decode_metadata.num_draft_tokens, sampled_token_ids
# ):
# indices.append(offset + len(tokens) - 1)
# offset += num_draft + 1
# indices = torch.tensor(indices, device=self.device)
# hidden_states = sample_hidden_states[indices]
# draft_token_ids = self.drafter.propose(
# target_hidden_states=hidden_states,
# sampling_metadata=sampling_metadata,
# slot_mappings=slot_mappings,
# )
# elif spec_config.uses_extract_hidden_states():
# assert isinstance(self.drafter, ExtractHiddenStatesProposer)
# assert isinstance(sampled_token_ids, torch.Tensor), (
# "sampled_token_ids should be a torch.Tensor for "
# "extract_hidden_states method."
# )
# if not self.use_aux_hidden_state_outputs or aux_hidden_states is None:
# raise ValueError(
# "aux_hidden_states are required when using `extract_hidden_states`"
# )
# target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
# draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
# sampled_token_ids=sampled_token_ids,
# target_hidden_states=target_hidden_states,
# common_attn_metadata=common_attn_metadata,
# scheduler_output=scheduler_output,
# slot_mappings=slot_mappings,
# )
# # Combine KVConnectorOutputs or select the non-empty one
# if self.kv_connector_output and drafter_kv_connector_output:
# self.kv_connector_output = KVConnectorOutput.merge(
# self.kv_connector_output, drafter_kv_connector_output
# )
# else:
# self.kv_connector_output = (
# self.kv_connector_output or drafter_kv_connector_output
# )
# next_token_ids, valid_sampled_tokens_count = (
# self.drafter.prepare_next_token_ids_padded(
# common_attn_metadata,
# sampled_token_ids,
# self.requests,
# self.input_batch,
# self.discard_request_mask.gpu,
# )
# )
# self._copy_valid_sampled_token_count(
# next_token_ids, valid_sampled_tokens_count
# )
# elif spec_config.use_eagle() or spec_config.uses_draft_model():
# assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
# if spec_config.disable_padded_drafter_batch:
# # When padded-batch is disabled, the sampled_token_ids should be
# # the cpu-side list[list[int]] of valid sampled tokens for each
# # request, with invalid requests having empty lists.
# assert isinstance(sampled_token_ids, list), (
# "sampled_token_ids should be a python list when"
# "padded-batch is disabled."
# )
# next_token_ids = self.drafter.prepare_next_token_ids_cpu(
# sampled_token_ids,
# self.requests,
# self.input_batch,
# scheduler_output.num_scheduled_tokens,
# )
# else:
# # When using padded-batch, the sampled_token_ids should be
# # the gpu tensor of sampled tokens for each request, of shape
# # (num_reqs, num_spec_tokens + 1) with rejected tokens having
# # value -1.
# assert isinstance(sampled_token_ids, torch.Tensor), (
# "sampled_token_ids should be a torch.Tensor when"
# "padded-batch is enabled."
# )
# next_token_ids, valid_sampled_tokens_count = (
# self.drafter.prepare_next_token_ids_padded(
# common_attn_metadata,
# sampled_token_ids,
# self.requests,
# self.input_batch,
# self.discard_request_mask.gpu,
# )
# )
# self._copy_valid_sampled_token_count(
# next_token_ids, valid_sampled_tokens_count
# )
# num_rejected_tokens_gpu = None
# if spec_decode_metadata is None:
# token_indices_to_sample = None
# # input_ids can be None for multimodal models.
# target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
# target_positions = self._get_positions(num_scheduled_tokens)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[:num_scheduled_tokens]
# else:
# if spec_config.disable_padded_drafter_batch:
# token_indices_to_sample = None
# common_attn_metadata, token_indices = self.drafter.prepare_inputs(
# common_attn_metadata,
# sampled_token_ids,
# spec_decode_metadata.num_draft_tokens,
# )
# target_token_ids = self.input_ids.gpu[token_indices]
# target_positions = self._get_positions(token_indices)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[token_indices] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[token_indices]
# else:
# (
# common_attn_metadata,
# token_indices_to_sample,
# num_rejected_tokens_gpu,
# ) = self.drafter.prepare_inputs_padded(
# common_attn_metadata,
# spec_decode_metadata,
# valid_sampled_tokens_count,
# )
# total_num_tokens = common_attn_metadata.num_actual_tokens
# # When padding the batch, token_indices is just a range
# target_token_ids = self.input_ids.gpu[:total_num_tokens]
# target_positions = self._get_positions(total_num_tokens)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[:total_num_tokens] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[:total_num_tokens]
# # if self.supports_mm_inputs:
# if self.supports_mm_inputs and self.drafter.supports_mm_inputs:
# mm_embed_inputs = self._gather_mm_embeddings(
# scheduler_output,
# shift_computed_tokens=1,
# )
# else:
# mm_embed_inputs = None
# draft_result = self.drafter.propose(
# target_token_ids=target_token_ids,
# target_positions=target_positions,
# target_hidden_states=target_hidden_states,
# next_token_ids=next_token_ids,
# token_indices_to_sample=token_indices_to_sample,
# sampling_metadata=sampling_metadata,
# common_attn_metadata=common_attn_metadata,
# mm_embed_inputs=mm_embed_inputs,
# num_rejected_tokens_gpu=num_rejected_tokens_gpu,
# slot_mappings=slot_mappings,
# # multi_layer_eagle_metadata=multi_layer_eagle_metadata,
# )
# if not envs.VLLM_REJECT_SAMPLE_OPT:
# draft_token_ids = draft_result
# else:
# draft_token_ids, draft_probs = draft_result
# if envs.VLLM_REJECT_SAMPLE_OPT:
# draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
# if self.draft_probs is None:
# self.draft_probs = DraftProbs(
# draft_probs, draft_req_ids)
# else:
# self.draft_probs.update(draft_probs, draft_req_ids)
# return draft_token_ids
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
......@@ -4278,7 +3978,6 @@ class GPUModelRunner(
sample_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]] | torch.Tensor:
......@@ -4286,8 +3985,6 @@ class GPUModelRunner(
spec_config = self.speculative_config
assert spec_config is not None
if spec_config.method == "ngram":
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.drafter.propose(
......@@ -4296,15 +3993,6 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings,
)
if isinstance(self.drafter, NgramProposer):
assert isinstance(sampled_token_ids, list), (
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids = self.drafter.propose(
sampled_token_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
)
elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer)
......@@ -4337,48 +4025,6 @@ class GPUModelRunner(
sampling_metadata=sampling_metadata,
slot_mappings=slot_mappings,
)
elif spec_config.uses_extract_hidden_states():
assert isinstance(self.drafter, ExtractHiddenStatesProposer)
assert isinstance(sampled_token_ids, torch.Tensor), (
"sampled_token_ids should be a torch.Tensor for "
"extract_hidden_states method."
)
if not self.use_aux_hidden_state_outputs or aux_hidden_states is None:
raise ValueError(
"aux_hidden_states are required when using `extract_hidden_states`"
)
target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
scheduler_output=scheduler_output,
slot_mappings=slot_mappings,
)
# Combine KVConnectorOutputs or select the non-empty one
if self.kv_connector_output and drafter_kv_connector_output:
self.kv_connector_output = KVConnectorOutput.merge(
self.kv_connector_output, drafter_kv_connector_output
)
else:
self.kv_connector_output = (
self.kv_connector_output or drafter_kv_connector_output
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_mask.gpu,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
......@@ -4470,7 +4116,7 @@ class GPUModelRunner(
else:
target_hidden_states = hidden_states[:total_num_tokens]
if self.supports_mm_inputs and self.drafter.supports_mm_inputs:
if self.supports_mm_inputs:
mm_embed_inputs = self._gather_mm_embeddings(
scheduler_output,
shift_computed_tokens=1,
......@@ -4483,16 +4129,28 @@ class GPUModelRunner(
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
token_indices_to_sample=token_indices_to_sample,
last_token_indices=token_indices_to_sample,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
slot_mappings=slot_mappings,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
)
return draft_result
if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_token_ids = draft_result
else:
draft_token_ids, draft_probs = draft_result
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, draft_req_ids)
else:
self.draft_probs.update(draft_probs, draft_req_ids)
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None:
allowed_config_names = {"load_config", "model_config"}
......@@ -6061,8 +5719,6 @@ class GPUModelRunner(
logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model,
multi_layer_eagle_num=self.multi_layer_eagle_num if self.enable_multi_layer_eagle else 0,
hidden_size=self.model_config.get_hidden_size(),
)
def _allocate_kv_cache_tensors(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment