Commit fcc9c9ea authored by luopl's avatar luopl
Browse files

feat:新增step3.5-mtp3功能

parent 9dc40d38
...@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, model_validator ...@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from vllm.config import LoadConfig
from vllm.config.model import ModelConfig from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config from vllm.config.utils import config
...@@ -76,6 +77,10 @@ class SpeculativeConfig: ...@@ -76,6 +77,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered.""" `prompt_lookup_min` should be considered."""
enable_multi_layers_mtp: bool = False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1) draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1 """The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size.""" or the same as the target model's tensor parallel size."""
...@@ -110,6 +115,11 @@ class SpeculativeConfig: ...@@ -110,6 +115,11 @@ class SpeculativeConfig:
which may only be supported by certain attention backends. This currently which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation.""" only affects the EAGLE method of speculation."""
use_local_argmax_reduction: bool = False
"""Use vocab-parallel local argmax instead of all-gathering full logits
for draft token generation. Reduces communication from O(vocab_size) to
O(2 * tp_size) per token. Only applies to greedy draft selection in
non-tree speculation."""
# Ngram proposer configuration # Ngram proposer configuration
prompt_lookup_max: int | None = Field(default=None, ge=1) prompt_lookup_max: int | None = Field(default=None, ge=1)
"""Maximum size of ngram token window when using Ngram proposer, required """Maximum size of ngram token window when using Ngram proposer, required
...@@ -121,6 +131,12 @@ class SpeculativeConfig: ...@@ -121,6 +131,12 @@ class SpeculativeConfig:
speculative_token_tree: str | None = None speculative_token_tree: str | None = None
"""Specifies the tree structure for speculative token generation. """Specifies the tree structure for speculative token generation.
""" """
parallel_drafting: bool = False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine # required configuration params passed from engine
target_model_config: SkipValidation[ModelConfig] = None # type: ignore target_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the target model.""" """The configuration of the target model."""
...@@ -154,6 +170,10 @@ class SpeculativeConfig: ...@@ -154,6 +170,10 @@ class SpeculativeConfig:
tokens with estimated probability (based on frequency counts) greater than tokens with estimated probability (based on frequency counts) greater than
or equal to this value.""" or equal to this value."""
draft_load_config: LoadConfig | None = None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -401,7 +421,11 @@ class SpeculativeConfig: ...@@ -401,7 +421,11 @@ class SpeculativeConfig:
MTPModelTypes MTPModelTypes
): ):
self.method = "mtp" self.method = "mtp"
if self.num_speculative_tokens > 1: # if self.num_speculative_tokens > 1:
if (
self.enable_multi_layers_mtp is False
and self.num_speculative_tokens > 1
):
logger.warning( logger.warning(
"Enabling num_speculative_tokens > 1 will run" "Enabling num_speculative_tokens > 1 will run"
"multiple times of forward on same MTP layer" "multiple times of forward on same MTP layer"
...@@ -472,6 +496,17 @@ class SpeculativeConfig: ...@@ -472,6 +496,17 @@ class SpeculativeConfig:
if self.num_speculative_tokens is None: if self.num_speculative_tokens is None:
# Default to max value defined in draft model config. # Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict self.num_speculative_tokens = n_predict
elif (
self.method == "mtp"
and self.enable_multi_layers_mtp
and self.num_speculative_tokens > n_predict
):
logger.warning_once(
"For multi_layer_eagle, num_speculative_tokens "
"is greater than the layer_num, adjusting to "
"layer_num"
)
self.num_speculative_tokens = n_predict
elif ( elif (
self.num_speculative_tokens > n_predict self.num_speculative_tokens > n_predict
and self.num_speculative_tokens % n_predict != 0 and self.num_speculative_tokens % n_predict != 0
...@@ -713,12 +748,31 @@ class SpeculativeConfig: ...@@ -713,12 +748,31 @@ class SpeculativeConfig:
f"errors during speculative decoding." f"errors during speculative decoding."
) )
@property
def max_num_new_slots_for_drafting(self) -> int:
    """
    Upper bound on the per-request slots that drafting may add to the batch.

    Serial methods that do not use a separate draft model need no extra
    slots. Parallel drafting contributes ``num_speculative_tokens - 1``
    slots (one per masked token), and draft-model speculation contributes
    one more because the draft tokens are not sliced.
    """
    # One extra slot per masked token when drafting in parallel, else none.
    extra_slots = (
        self.num_speculative_tokens - 1 if self.parallel_drafting else 0
    )
    if self.uses_draft_model():
        # Draft-model speculation keeps the unsliced draft tokens around.
        extra_slots += 1
    return extra_slots
def use_eagle(self) -> bool: def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "mtp") return self.method in ("eagle", "eagle3", "mtp")
def uses_draft_model(self) -> bool: def uses_draft_model(self) -> bool:
return self.method == "draft_model" return self.method == "draft_model"
def uses_extract_hidden_states(self) -> bool:
    """Return True when the configured method is ``extract_hidden_states``."""
    selected_method = self.method
    return selected_method == "extract_hidden_states"
def __repr__(self) -> str: def __repr__(self) -> str:
method = self.method method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model model = None if method in ("ngram", "suffix") else self.draft_model_config.model
......
...@@ -160,3 +160,32 @@ class AnthropicMessagesResponse(BaseModel): ...@@ -160,3 +160,32 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context): def model_post_init(self, __context):
if not self.id: if not self.id:
self.id = f"msg_{int(time.time() * 1000)}" self.id = f"msg_{int(time.time() * 1000)}"
class AnthropicContextManagement(BaseModel):
    """Context-management details attached to a token-count response."""

    # Number of input tokens counted for the original request.
    original_input_tokens: int
class AnthropicCountTokensRequest(BaseModel):
    """Request body for the Anthropic ``messages.count_tokens`` endpoint."""

    # Model name; must be non-empty (enforced by ``validate_model``).
    model: str
    # Conversation messages whose tokens are counted.
    messages: list[AnthropicMessage]
    # Optional system prompt, either plain text or content blocks.
    system: str | list[AnthropicContentBlock] | None = None
    # Optional tool selection and tool definitions.
    tool_choice: AnthropicToolChoice | None = None
    tools: list[AnthropicTool] | None = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, v):
        """Return ``v`` unchanged, raising ``ValueError`` when it is falsy."""
        if v:
            return v
        raise ValueError("Model is required")
class AnthropicCountTokensResponse(BaseModel):
    """Response body for the Anthropic ``messages.count_tokens`` endpoint."""

    # Total number of input tokens for the request.
    input_tokens: int
    # Optional context-management details; omitted when None.
    context_management: AnthropicContextManagement | None = None
\ No newline at end of file
...@@ -15,6 +15,9 @@ from fastapi import Request ...@@ -15,6 +15,9 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import ( from vllm.entrypoints.anthropic.protocol import (
AnthropicContextManagement,
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicContentBlock, AnthropicContentBlock,
AnthropicDelta, AnthropicDelta,
AnthropicError, AnthropicError,
...@@ -112,6 +115,7 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -112,6 +115,7 @@ class AnthropicServingMessages(OpenAIServingChat):
# Handle complex content blocks # Handle complex content blocks
content_parts: list[dict[str, Any]] = [] content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = [] tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content: for block in msg.content:
if block.type == "text" and block.text: if block.type == "text" and block.text:
...@@ -123,6 +127,8 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -123,6 +127,8 @@ class AnthropicServingMessages(OpenAIServingChat):
"image_url": {"url": block.source.get("data", "")}, "image_url": {"url": block.source.get("data", "")},
} }
) )
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use": elif block.type == "tool_use":
# Convert tool use to function call format # Convert tool use to function call format
tool_call = { tool_call = {
...@@ -157,6 +163,9 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -157,6 +163,9 @@ class AnthropicServingMessages(OpenAIServingChat):
} }
) )
if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)
# Add tool calls to the message if any # Add tool calls to the message if any
if tool_calls: if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore openai_msg["tool_calls"] = tool_calls # type: ignore
...@@ -297,10 +306,116 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -297,10 +306,116 @@ class AnthropicServingMessages(OpenAIServingChat):
generator: AsyncGenerator[str, None], generator: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
try: try:
class _ActiveBlockState:
    """
    Bookkeeping for the content block currently being streamed.

    Tracks the running block index plus the type, index, signature and
    tool-use id of the block that is open (all ``None`` when no block is
    open).

    NOTE(review): no reference to this class is visible in the surrounding
    hunk; the streaming loop keeps equivalent state in closure variables.
    Confirm whether this class is still needed.
    """

    def __init__(self) -> None:
        # Index handed to the next block that is started.
        self.content_block_index = 0
        self.block_type: str | None = None
        self.block_index: int | None = None
        self.block_signature: str | None = None
        self.signature_emitted: bool = False
        self.tool_use_id: str | None = None

    def reset(self) -> None:
        """Forget the open block; ``content_block_index`` is left untouched."""
        self.block_type = self.block_index = None
        self.block_signature = None
        self.signature_emitted = False
        self.tool_use_id = None

    def start(self, block: AnthropicContentBlock) -> None:
        """Record ``block`` as the open block at the current block index."""
        kind = block.type
        self.block_type = kind
        self.block_index = self.content_block_index

        is_thinking = kind == "thinking"
        # Only thinking blocks carry a signature that still has to be sent;
        # every other block type is treated as already "emitted".
        self.block_signature = uuid.uuid4().hex if is_thinking else None
        self.signature_emitted = not is_thinking
        self.tool_use_id = block.id if kind == "tool_use" else None
first_item = True first_item = True
finish_reason = None finish_reason = None
# content_block_index = 0
# content_block_started = False
content_block_index = 0 content_block_index = 0
content_block_started = False active_block_type: str | None = None
active_block_index: int | None = None
active_block_signature: str | None = None
signature_emitted = False
active_tool_use_id: str | None = None
# Map from tool call index to tool_use_id
tool_index_to_id: dict[int, str] = {}
def stop_active_block():
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
events: list[str] = []
if active_block_type is None:
return events
if (
active_block_type == "thinking"
and active_block_signature is not None
and not signature_emitted
):
chunk = AnthropicStreamEvent(
index=active_block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="signature_delta",
signature=active_block_signature,
),
)
data = chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_delta"))
signature_emitted = True
stop_chunk = AnthropicStreamEvent(
index=active_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_stop"))
active_block_type = None
active_block_index = None
active_block_signature = None
signature_emitted = False
active_tool_use_id = None
content_block_index += 1
return events
def start_block(block: AnthropicContentBlock):
nonlocal active_block_type, active_block_index, content_block_index
nonlocal active_block_signature, signature_emitted, active_tool_use_id
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_start",
content_block=block,
)
data = chunk.model_dump_json(exclude_unset=True)
event = wrap_data_with_event(data, "content_block_start")
active_block_type = block.type
active_block_index = content_block_index
if block.type == "thinking":
active_block_signature = uuid.uuid4().hex
signature_emitted = False
active_tool_use_id = None
elif block.type == "tool_use":
active_block_signature = None
signature_emitted = True
active_tool_use_id = block.id
else:
active_block_signature = None
signature_emitted = True
active_tool_use_id = None
return event
async for item in generator: async for item in generator:
if item.startswith("data:"): if item.startswith("data:"):
...@@ -326,6 +441,8 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -326,6 +441,8 @@ class AnthropicServingMessages(OpenAIServingChat):
id=origin_chunk.id, id=origin_chunk.id,
content=[], content=[],
model=origin_chunk.model, model=origin_chunk.model,
stop_reason=None,
stop_sequence=None,
usage=AnthropicUsage( usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage if origin_chunk.usage
...@@ -341,13 +458,33 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -341,13 +458,33 @@ class AnthropicServingMessages(OpenAIServingChat):
# last chunk including usage info # last chunk including usage info
if len(origin_chunk.choices) == 0: if len(origin_chunk.choices) == 0:
if content_block_started: # if content_block_started:
stop_chunk = AnthropicStreamEvent( # stop_chunk = AnthropicStreamEvent(
index=content_block_index, # index=content_block_index,
type="content_block_stop", # type="content_block_stop",
) # )
data = stop_chunk.model_dump_json(exclude_unset=True) # data = stop_chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_stop") # yield wrap_data_with_event(data, "content_block_stop")
# stop_reason = self.stop_reason_map.get(
# finish_reason or "stop"
# )
# chunk = AnthropicStreamEvent(
# type="message_delta",
# delta=AnthropicDelta(stop_reason=stop_reason),
# usage=AnthropicUsage(
# input_tokens=origin_chunk.usage.prompt_tokens
# if origin_chunk.usage
# else 0,
# output_tokens=origin_chunk.usage.completion_tokens
# if origin_chunk.usage
# else 0,
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "message_delta")
# continue
for event in stop_active_block():
yield event
stop_reason = self.stop_reason_map.get( stop_reason = self.stop_reason_map.get(
finish_reason or "stop" finish_reason or "stop"
) )
...@@ -366,86 +503,218 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -366,86 +503,218 @@ class AnthropicServingMessages(OpenAIServingChat):
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "message_delta") yield wrap_data_with_event(data, "message_delta")
continue continue
# =========================================================
if origin_chunk.choices[0].finish_reason is not None: if origin_chunk.choices[0].finish_reason is not None:
finish_reason = origin_chunk.choices[0].finish_reason finish_reason = origin_chunk.choices[0].finish_reason
continue # continue
# content # content
if origin_chunk.choices[0].delta.content is not None: # if origin_chunk.choices[0].delta.content is not None:
if not content_block_started: # if not content_block_started:
chunk = AnthropicStreamEvent( # chunk = AnthropicStreamEvent(
index=content_block_index, # index=content_block_index,
type="content_block_start", # type="content_block_start",
content_block=AnthropicContentBlock( # content_block=AnthropicContentBlock(
type="text", text="" # type="text", text=""
), # ),
) # )
data = chunk.model_dump_json(exclude_unset=True) # data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start") # yield wrap_data_with_event(data, "content_block_start")
content_block_started = True # content_block_started = True
if origin_chunk.choices[0].delta.content == "": # if origin_chunk.choices[0].delta.content == "":
continue # continue
chunk = AnthropicStreamEvent( # chunk = AnthropicStreamEvent(
index=content_block_index, # index=content_block_index,
type="content_block_delta", # type="content_block_delta",
delta=AnthropicDelta( # delta=AnthropicDelta(
type="text_delta", # type="text_delta",
text=origin_chunk.choices[0].delta.content, # text=origin_chunk.choices[0].delta.content,
), # ),
) # )
data = chunk.model_dump_json(exclude_unset=True) # data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta") # yield wrap_data_with_event(data, "content_block_delta")
continue # continue
# tool calls # tool calls
elif len(origin_chunk.choices[0].delta.tool_calls) > 0: # elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
tool_call = origin_chunk.choices[0].delta.tool_calls[0] # elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
if tool_call.id is not None: # tool_call = origin_chunk.choices[0].delta.tool_calls[0]
if content_block_started: # if tool_call.id is not None:
stop_chunk = AnthropicStreamEvent( # if content_block_started:
index=content_block_index, # stop_chunk = AnthropicStreamEvent(
type="content_block_stop", # index=content_block_index,
) # type="content_block_stop",
data = stop_chunk.model_dump_json( # )
exclude_unset=True # data = stop_chunk.model_dump_json(
) # exclude_unset=True
yield wrap_data_with_event( # )
data, "content_block_stop" # yield wrap_data_with_event(
# data, "content_block_stop"
# )
# content_block_started = False
# content_block_index += 1
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_start",
# content_block=AnthropicContentBlock(
# type="tool_use",
# id=tool_call.id,
# name=tool_call.function.name
# if tool_call.function
# else None,
# input={},
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_start")
# content_block_started = True
# else:
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_delta",
# delta=AnthropicDelta(
# type="input_json_delta",
# partial_json=tool_call.function.arguments
# if tool_call.function
# else None,
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_delta")
# continue
# thinking / text content
reasoning_delta = origin_chunk.choices[0].delta.reasoning
if reasoning_delta is not None:
if reasoning_delta == "":
pass
else:
if active_block_type != "thinking":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="thinking", thinking=""
)
) )
content_block_started = False yield start_event
content_block_index += 1
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=content_block_index, index=(
type="content_block_start", active_block_index
content_block=AnthropicContentBlock( if active_block_index is not None
type="tool_use", else content_block_index
id=tool_call.id, ),
name=tool_call.function.name type="content_block_delta",
if tool_call.function delta=AnthropicDelta(
else None, type="thinking_delta",
input={}, thinking=reasoning_delta,
), ),
) )
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start") yield wrap_data_with_event(data, "content_block_delta")
content_block_started = True
if origin_chunk.choices[0].delta.content is not None:
if origin_chunk.choices[0].delta.content == "":
pass
else: else:
if active_block_type != "text":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(type="text", text="")
)
yield start_event
chunk = AnthropicStreamEvent( chunk = AnthropicStreamEvent(
index=content_block_index, index=(
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta", type="content_block_delta",
delta=AnthropicDelta( delta=AnthropicDelta(
type="input_json_delta", type="text_delta",
partial_json=tool_call.function.arguments text=origin_chunk.choices[0].delta.content,
if tool_call.function
else None,
), ),
) )
data = chunk.model_dump_json(exclude_unset=True) data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta") yield wrap_data_with_event(data, "content_block_delta")
# tool calls - process all tool calls in the delta
if len(origin_chunk.choices[0].delta.tool_calls) > 0:
for tool_call in origin_chunk.choices[0].delta.tool_calls:
if tool_call.id is not None:
# Update mapping for incremental updates
tool_index_to_id[tool_call.index] = tool_call.id
# Only create new block if different tool call
# AND has a name
tool_name = (
tool_call.function.name
if tool_call.function
else None
)
if (
active_tool_use_id != tool_call.id
and tool_name is not None
):
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
name=tool_name,
input={},
)
)
yield start_event
# Handle initial arguments if present
if (
tool_call.function
and tool_call.function.arguments
and active_tool_use_id == tool_call.id
):
chunk = AnthropicStreamEvent(
index=(
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(
data, "content_block_delta"
)
else:
# Incremental update - use index to find tool_use_id
tool_use_id = tool_index_to_id.get(tool_call.index)
if (
tool_use_id is not None
and tool_call.function
and tool_call.function.arguments
and active_tool_use_id == tool_use_id
):
chunk = AnthropicStreamEvent(
index=(
active_block_index
if active_block_index is not None
else content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(
data, "content_block_delta"
)
continue continue
else: else:
error_response = AnthropicStreamEvent( error_response = AnthropicStreamEvent(
...@@ -468,3 +737,31 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -468,3 +737,31 @@ class AnthropicServingMessages(OpenAIServingChat):
data = error_response.model_dump_json(exclude_unset=True) data = error_response.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "error") yield wrap_data_with_event(data, "error")
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
async def count_tokens(
    self,
    request: AnthropicCountTokensRequest,
    raw_request: Request | None = None,
) -> AnthropicCountTokensResponse | ErrorResponse:
    """
    Implements Anthropic's messages.count_tokens endpoint.

    The request is converted to an OpenAI chat request and rendered; the
    lengths of ``prompt_token_ids`` over the resulting engine prompts are
    summed. An ``ErrorResponse`` from rendering is returned unchanged.
    ``raw_request`` is accepted but not read by this method.
    """
    openai_request = self._convert_anthropic_to_openai_request(request)
    rendered = await self.render_chat_request(openai_request)
    if isinstance(rendered, ErrorResponse):
        return rendered

    _, engine_prompts = rendered

    # Prompts without token ids contribute nothing to the total.
    token_total = 0
    for prompt in engine_prompts:
        if "prompt_token_ids" in prompt:
            token_total += len(prompt["prompt_token_ids"])  # type: ignore[typeddict-item, misc]

    return AnthropicCountTokensResponse(
        input_tokens=token_total,
        context_management=AnthropicContextManagement(
            original_input_tokens=token_total
        ),
    )
\ No newline at end of file
...@@ -1239,10 +1239,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1239,10 +1239,13 @@ class OpenAIServingChat(OpenAIServing):
index = 0 index = 0
if ( if (
self._should_check_for_unstreamed_tool_arg_tokens( # self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output # delta_message, output
tool_parser
and self._should_check_for_unstreamed_tool_arg_tokens(
delta_message, output, tool_parser
) )
and tool_parser # and tool_parser
): ):
latest_delta_len = 0 latest_delta_len = 0
if ( if (
...@@ -1256,15 +1259,31 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1256,15 +1259,31 @@ class OpenAIServingChat(OpenAIServing):
latest_delta_len = len( latest_delta_len = len(
delta_message.tool_calls[0].function.arguments delta_message.tool_calls[0].function.arguments
) )
# get the expected call based on partial JSON # get the expected call based on partial JSON
# parsing which "autocompletes" the JSON # parsing which "autocompletes" the JSON.
expected_call = json.dumps( # Tool parsers (e.g. Qwen3Coder) store
tool_parser.prev_tool_call_arr[index].get( # arguments as a JSON string in
"arguments", {} # prev_tool_call_arr. Calling json.dumps()
), # on an already-serialized string would
ensure_ascii=False, # double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# expected_call = json.dumps(
# tool_parser.prev_tool_call_arr[index].get(
# "arguments", {}
# ),
# ensure_ascii=False,
# )
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
) )
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments # get what we've streamed so far for arguments
# for the current tool # for the current tool
...@@ -1848,6 +1867,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1848,6 +1867,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
delta_message: DeltaMessage | None, delta_message: DeltaMessage | None,
output: CompletionOutput, output: CompletionOutput,
tool_parser: ToolParser | None = None,
) -> bool: ) -> bool:
""" """
Check to see if we should check for unstreamed tool arguments tokens. Check to see if we should check for unstreamed tool arguments tokens.
...@@ -1866,6 +1886,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1866,6 +1886,8 @@ class OpenAIServingChat(OpenAIServing):
and delta_message.tool_calls[0] and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None and delta_message.tool_calls[0].function.arguments is not None
and tool_parser is not None
and tool_parser.parser_should_check_for_unstreamed_tool_arg_tokens()
) )
@staticmethod @staticmethod
......
...@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple): ...@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple):
""" """
Whether this batch has active LoRA adapters. Whether this batch has active LoRA adapters.
""" """
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
are captured for each num_active_loras value. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
""" """
...@@ -191,7 +199,7 @@ class ForwardContext: ...@@ -191,7 +199,7 @@ class ForwardContext:
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]
""" """
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata attention layer to its attention metadata
Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
for each microbatch. for each microbatch.
......
...@@ -44,6 +44,23 @@ if TYPE_CHECKING: ...@@ -44,6 +44,23 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
    """
    Returns num_active_loras values for cudagraph capture.

    When specialize=False the result is just ``[max_loras + 1]``.
    When specialize=True the result is every power of two below
    ``max_loras + 1`` in ascending order, followed by ``max_loras + 1``
    itself (an empty list if ``max_loras + 1`` is less than 1).

    This is the single source of truth for LoRA capture cases, used by both
    CudagraphDispatcher and PunicaWrapperGPU.
    """
    upper = max_loras + 1
    if not specialize:
        return [upper]

    counts: list[int] = []
    power = 1
    # Collect 1, 2, 4, ... strictly below the upper bound.
    while power < upper:
        counts.append(power)
        power <<= 1
    # The upper bound is always captured, whether or not it is a power of two.
    if upper >= 1:
        counts.append(upper)
    return counts
_GLOBAL_LORA_ID = 0 _GLOBAL_LORA_ID = 0
......
...@@ -1028,6 +1028,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -1028,6 +1028,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
shared_output=None,
routed_scaling_factor=None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel is not None assert self.kernel is not None
assert not self.is_monolithic assert not self.is_monolithic
......
...@@ -52,7 +52,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module): ...@@ -52,7 +52,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.enorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps) self.hnorm = GemmaRMSNorm(config.hidden_size, config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.shared_head = SharedHead(config=config, quant_config=quant_config) # self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.lm_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = Step3p5DecoderLayer( self.mtp_block = Step3p5DecoderLayer(
vllm_config, vllm_config,
prefix=f"{prefix}.mtp_block", prefix=f"{prefix}.mtp_block",
...@@ -64,9 +65,13 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module): ...@@ -64,9 +65,13 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
embed_tokens: VocabParallelEmbedding | None = None,
spec_step_index: int = 0, spec_step_index: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
assert inputs_embeds is not None if inputs_embeds is None:
assert embed_tokens is not None
inputs_embeds = embed_tokens(input_ids)
# assert inputs_embeds is not None
inputs_embeds = self.enorm(inputs_embeds) inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states) previous_hidden_states = self.hnorm(previous_hidden_states)
...@@ -92,8 +97,10 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -92,8 +97,10 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict( self.layers = torch.nn.ModuleDict(
{ {
str(idx): Step3p5AMultiTokenPredictorLayer( str(idx): Step3p5AMultiTokenPredictorLayer(
vllm_config, # vllm_config,
f"{prefix}.layers.{idx}", # f"{prefix}.layers.{idx}",
vllm_config=vllm_config,
prefix=f"{prefix}.layers.{idx}",
) )
for idx in range( for idx in range(
self.mtp_start_layer_idx, self.mtp_start_layer_idx,
...@@ -112,14 +119,15 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -112,14 +119,15 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
spec_step_idx: int = 0, spec_step_idx: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is None: # if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) # inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids, input_ids,
positions, positions,
previous_hidden_states, previous_hidden_states,
inputs_embeds, inputs_embeds,
self.embed_tokens,
current_step_idx, current_step_idx,
) )
...@@ -131,7 +139,8 @@ class Step3p5AMultiTokenPredictor(nn.Module): ...@@ -131,7 +139,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
logits = self.logits_processor( logits = self.logits_processor(
mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states) # mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
mtp_layer.lm_head.head, mtp_layer.lm_head(hidden_states)
) )
return logits return logits
...@@ -257,6 +266,7 @@ class Step3p5MTP(nn.Module): ...@@ -257,6 +266,7 @@ class Step3p5MTP(nn.Module):
name = name.replace(".transformer.", ".") name = name.replace(".transformer.", ".")
if "shared_head" in name: if "shared_head" in name:
name = name.replace("shared_head.output", "shared_head.head") name = name.replace("shared_head.output", "shared_head.head")
name = name.replace("shared_head", "lm_head")
if "embed_tokens" in name: if "embed_tokens" in name:
assert ( assert (
hasattr(self.config, "num_nextn_predict_layers") hasattr(self.config, "num_nextn_predict_layers")
......
...@@ -118,6 +118,11 @@ class ToolParser: ...@@ -118,6 +118,11 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!" "AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
) )
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
    """Report whether serving should check for unstreamed tool-argument tokens.

    Returns:
        Always ``True`` in this implementation.
    """
    return True
class ToolParserManager: class ToolParserManager:
""" """
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
import json import json
import uuid
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
from xml.parsers.expat import ParserCreate # from xml.parsers.expat import ParserCreate
import regex as re import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id # from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam, ChatCompletionToolsParam,
...@@ -25,1487 +26,1142 @@ from vllm.logger import init_logger ...@@ -25,1487 +26,1142 @@ from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
ToolParserManager,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
class StreamingXMLToolCallParser: class Step3p5ToolParser(ToolParser):
""" def __init__(self, tokenizer: TokenizerLike):
Simplified streaming XML tool call parser super().__init__(tokenizer)
Supports streaming input, parsing, and output
"""
def __init__(self): self.current_tool_name_sent: bool = False
self.reset_streaming_state() self.prev_tool_call_arr: list[dict] = []
# Override base class type - we use string IDs for tool calls
self.current_tool_id: str | None = None # type: ignore
self.streamed_args_for_tool: list[str] = []
# Tool configuration information # Sentinel tokens for streaming mode
self.tools: list[ChatCompletionToolsParam] | None = None
self.tool_call_start_token: str = "<tool_call>" self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>" self.tool_call_end_token: str = "</tool_call>"
self.function_start_token: str = "<function=" self.tool_call_prefix: str = "<function="
self.function_end_token: str = "</function>" self.function_end_token: str = "</function>"
self.parameter_start_token: str = "<parameter=" self.parameter_prefix: str = "<parameter="
self.parameter_end_token: str = "</parameter>" self.parameter_end_token: str = "</parameter>"
self.is_tool_call_started: bool = False
self.failed_count: int = 0
def reset_streaming_state(self): # Enhanced streaming state - reset for each new message
"""Reset streaming parsing state""" self._reset_streaming_state()
self.deltas = []
# state for streaming
self.tool_call_index = 0
self.current_call_id = None
self.last_completed_call_id = None
self.current_function_name = None
self.current_function_open = False
self.parameters = {}
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.current_param_is_first = False
self.should_emit_end_newline = False
self.start_quote_emitted = False
self.streaming_buffer = ""
self.last_processed_pos = 0
self.text_content_buffer = ""
# state for preprocessing and deferred parsing
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
# recreate parser
self.parser = ParserCreate()
self.setup_parser()
def parse_single_streaming_chunks(self, xml_chunk: str) -> DeltaMessage:
"""
Parse single streaming XML chunk and return Delta response
This is the actual streaming interface that receives chunks
one by one and maintains internal state
Args: # Regex patterns
xml_chunk: Single XML chunk string self.tool_call_complete_regex = re.compile(
Returns: r"<tool_call>(.*?)</tool_call>", re.DOTALL
DeltaMessage: Contains delta information generated by this chunk, )
returns empty response if no complete elements self.tool_call_function_regex = re.compile(
r"<function(?:=|\s+)?(.*?)</function>", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"Step3p5 RL Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
# Get EOS token ID for EOS detection
self.eos_token_id = getattr(self.model_tokenizer, "eos_token_id", None)
logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
    """Build a fresh tool call ID.

    Returns:
        ``"call_"`` followed by the first 24 hex characters of a random UUID4.
    """
    random_hex = uuid.uuid4().hex
    return "call_" + random_hex[:24]
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
""" """
# Record delta count before processing Skip the remaining_call calculation in serving
initial_delta_count = len(self.deltas) """
return False
self.streaming_buffer += xml_chunk def _reset_streaming_state(self):
"""Reset all streaming state for a new request."""
self._processed_length: int = 0 # Position of last processed character
self._tool_call_index: int = 0 # Number of tool calls processed so far
self.streaming_request = None # Current request being processed
def _get_arguments_config(
    self, func_name: str, tools: list[ChatCompletionToolsParam] | None
) -> dict:
    """Look up the argument schema declared for the tool named ``func_name``.

    Args:
        func_name: Name of the function whose parameter schema is wanted.
        tools: Tool definitions to search, or ``None``.

    Returns:
        The ``"properties"`` entry of the matching tool's ``parameters`` dict
        when that key exists, the ``parameters`` dict itself when it does not,
        and ``{}`` when ``tools`` is ``None``, the tool has no dict-valued
        ``parameters``, or no tool matches (a warning is logged in that last
        case).
    """
    if tools is None:
        return {}

    for tool in tools:
        # Entries lacking the attributes we need to inspect are ignored.
        well_formed = (
            hasattr(tool, "type")
            and hasattr(tool, "function")
            and hasattr(tool.function, "name")
        )
        if not well_formed:
            continue
        if tool.type != "function" or tool.function.name != func_name:
            continue

        # First matching tool decides the result.
        schema = getattr(tool.function, "parameters", None)
        if not isinstance(schema, dict):
            return {}
        return schema.get("properties", schema)

    logger.warning("Tool '%s' is not defined in the tools list.", func_name)
    return {}
def _convert_param_value(
self, param_value: str, param_name: str, param_config: dict, func_name: str
) -> Any:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if param_value.lower() == "null":
return None
found_elements = self._process_complete_xml_elements() if param_name not in param_config:
if param_config != {}:
logger.warning(
"Parsed parameter '%s' is not defined in the tool "
"parameters for tool '%s', directly returning the "
"string value.",
param_name,
func_name,
)
return param_value
if found_elements: if (
# If complete elements found, check if end events were missed isinstance(param_config[param_name], dict)
# some tags may not have been triggered and "type" in param_config[param_name]
):
param_type = str(param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
):
try: try:
new_deltas = self.deltas[initial_delta_count:] return int(param_value)
# If this chunk contains </function> except (ValueError, TypeError):
# but didn't generate '}', then complete it try:
if ( float_value = float(param_value)
self.current_call_id is not None if float_value.is_integer():
and self.function_end_token in xml_chunk return int(float_value)
): except (ValueError, TypeError):
# - Added '}' (non-empty parameter ending) pass
# - Added '{}' (empty parameter function) try:
has_function_close = any( literal_value = ast.literal_eval(param_value)
( if isinstance(literal_value, bool):
td.tool_calls return int(literal_value)
and any( if isinstance(literal_value, (int, float)):
( return (
tc.function int(literal_value)
and tc.id == self.current_call_id if float(literal_value).is_integer()
and isinstance(tc.function.arguments, str) else literal_value
and (tc.function.arguments in ("}", "{}"))
)
for tc in td.tool_calls
)
) )
for td in new_deltas except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not an integer "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, (int, float)):
return (
float(literal_value)
if float(literal_value) - int(float(literal_value)) != 0
else int(float(literal_value))
)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
normalized_value = param_value.strip().lower()
if normalized_value in ["true", "false"]:
return normalized_value == "true"
if normalized_value in ["1", "0"]:
return normalized_value == "1"
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, bool):
return literal_value
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not a boolean "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
else:
if (
param_type in ["object", "array", "arr"]
or param_type.startswith("dict")
or param_type.startswith("list")
):
try:
param_value = json.loads(param_value)
return param_value
except (json.JSONDecodeError, TypeError, ValueError):
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, (list, dict)):
return literal_value
if isinstance(literal_value, (tuple, set)):
return list(literal_value)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be parsed "
"as JSON in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
) )
if not has_function_close: return param_value
# Close potentially unclosed element try:
if self.current_param_name: literal_value = ast.literal_eval(param_value) # safer
self._end_element("parameter") if isinstance(literal_value, (tuple, set)):
if self.current_function_name: return list(literal_value)
self._end_element("function")
# If this chunk contains </tool_call>
# but didn't generate final empty delta, then complete it
if ( if (
self.current_call_id is not None isinstance(literal_value, (list, dict, str, int, float, bool))
and self.tool_call_end_token in xml_chunk or literal_value is None
): ):
has_toolcall_close = any( return literal_value
( except (ValueError, SyntaxError, TypeError):
td.tool_calls pass
and any( logger.warning(
( "Parsed value '%s' of parameter '%s' cannot be converted via "
tc.type == "function" "Python `ast.literal_eval()` in tool '%s', returning raw string.",
and tc.function param_value,
and tc.function.arguments == "" param_name,
and tc.id == self.current_call_id func_name,
)
for tc in td.tool_calls
)
)
for td in new_deltas
)
if not has_toolcall_close:
# Close potentially unclosed element
if self.current_param_name:
self._end_element("parameter")
if self.current_function_name:
self._end_element("function")
self._end_element("tool_call")
except Exception as e:
logger.warning("Error with fallback parsing: %s", e)
# Merge newly generated deltas into single response
result_delta = self._merge_new_deltas_to_single_response(
initial_delta_count
) )
return result_delta return param_value
def _parse_parameters_fallback(
self,
parameters: str,
allowed_param_names: set[str] | None = None,
) -> list[tuple[str, str]]:
"""Fallback parser for malformed parameter tags."""
param_pairs: list[tuple[str, str]] = []
pos = 0
while True:
start = parameters.find(self.parameter_prefix, pos)
if start == -1:
break
name_start = start + len(self.parameter_prefix)
name_end = parameters.find(">", name_start)
if name_end == -1:
newline_idx = parameters.find("\n", name_start)
end_tag = parameters.find(self.parameter_end_token, name_start)
next_param = parameters.find(self.parameter_prefix, name_start)
candidates = [
idx for idx in [newline_idx, end_tag, next_param] if idx != -1
]
if not candidates:
break
name_end = min(candidates)
value_start = name_end
else:
value_start = name_end + 1
param_name = parameters[name_start:name_end].strip()
next_param = parameters.find(self.parameter_prefix, value_start)
end_tag = parameters.find(self.parameter_end_token, value_start)
if end_tag == -1 or (next_param != -1 and next_param < end_tag):
end = next_param if next_param != -1 else len(parameters)
pos = end
else:
end = end_tag
pos = end + len(self.parameter_end_token)
param_value = parameters[value_start:end]
if allowed_param_names is None or param_name in allowed_param_names:
param_pairs.append((param_name, param_value))
return param_pairs
def _is_valid_json_arguments(self, arguments: str) -> bool:
    """Tell whether ``arguments`` can be decoded by ``json.loads``.

    Returns:
        ``True`` when decoding succeeds, ``False`` when it raises.
    """
    parsed_ok = True
    try:
        json.loads(arguments)
    except Exception:
        # Any failure to decode counts as "not valid JSON".
        parsed_ok = False
    return parsed_ok
def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None:
# Extract function name
end_index = function_call_str.index(">")
# check empty function name
function_name = function_call_str[:end_index].strip()
if function_name.startswith("="):
function_name = function_name.lstrip("=").strip()
if not function_name or function_name.strip("'\"") == "":
logger.warning("Empty function name in tool call.")
return None
if function_name[0] in "\"'" and function_name[-1] == function_name[0]:
function_name = function_name[1:-1].strip()
if not function_name:
logger.warning("Empty function name in tool call.")
return None
param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1 :]
param_dict = {}
match_texts = self.tool_call_parameter_regex.findall(parameters)
use_fallback = False
if match_texts:
for match_text in match_texts:
if self.parameter_prefix in match_text or ">" not in match_text:
use_fallback = True
break
else: else:
# No complete elements, check if there's unoutput text content use_fallback = self.parameter_prefix in parameters
if self.text_content_buffer and self.tool_call_index == 0:
# Has text content but no tool_call yet, output text content
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer to avoid duplicate output
self.text_content_buffer = ""
return text_delta
# If this chunk contains end tags but wasn't triggered by parser,
# manually complete end events
# Only execute when still on the same call as when entered,
# to prevent accidentally closing new calls
# in multi <tool_call> scenarios
if self.current_call_id is not None and (
self.function_end_token in xml_chunk
or self.tool_call_end_token in xml_chunk
):
# Close potentially unclosed element
if self.current_param_name:
self._end_element("parameter")
if self.function_end_token in xml_chunk and self.current_function_name:
self._end_element("function")
if self.tool_call_end_token in xml_chunk:
self._end_element("tool_call")
# Return the merged delta result generated by this fallback
result_delta = self._merge_new_deltas_to_single_response(
initial_delta_count
)
return result_delta
# No complete elements, return empty response if use_fallback:
return DeltaMessage(content=None) allowed_param_names = (
set(param_config.keys())
if isinstance(param_config, dict) and param_config
else None
)
param_pairs = self._parse_parameters_fallback(
parameters, allowed_param_names
)
else:
param_pairs = []
for match_text in match_texts:
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1 :])
param_pairs.append((param_name, param_value))
for param_name, param_value in param_pairs:
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = self._convert_param_value(
param_value, param_name, param_config, function_name
)
def _escape_xml_special_chars(self, text: str) -> str: try:
""" arguments = json.dumps(param_dict, ensure_ascii=False)
Escape XML special characters except Exception as e:
Args: logger.warning("Error in converting parameter value: %s", e)
text: Original text return None
Returns: return ToolCall(
Escaped text type="function",
""" function=FunctionCall(name=function_name, arguments=arguments),
xml_escapes = { )
"&": "&amp;",
"<": "&lt;",
">": "&gt;",
'"': "&quot;",
"'": "&apos;",
}
for char, escape in xml_escapes.items(): def _get_function_calls(self, model_output: str) -> list[str]:
text = text.replace(char, escape) # Find all tool calls
raw_tool_calls = self.tool_call_complete_regex.findall(model_output)
return text # if no closed tool_call tags found, return empty list
if len(raw_tool_calls) == 0:
return []
def _process_complete_xml_elements(self) -> bool: raw_function_calls = []
""" for tool_call in raw_tool_calls:
Process complete XML elements in buffer function_matches = self.tool_call_function_regex.findall(tool_call)
raw_function_calls.extend(function_matches)
Returns: return raw_function_calls
bool: Whether complete elements were found and processed
"""
found_any = False
while self.last_processed_pos < len(self.streaming_buffer): def _check_format(self, model_output: str) -> bool:
# Find next complete xml element """Check if model output contains properly formatted tool call.
element, end_pos = self._find_next_complete_element(self.last_processed_pos)
if element is None:
# No complete element found, wait for more data
break
# Check if this element should be skipped Requirements:
if self._should_skip_element(element): 1. Must have closed tool_call tags (<tool_call>...</tool_call>)
self.last_processed_pos = end_pos 2. Must have closed function tags (<function=...</function>)
continue 3. If parameter tags exist, they must be closed and correct
# Found complete XML element, process it Returns True if the format is valid, False otherwise.
try: """
preprocessed_element = self._preprocess_xml_chunk(element) # Check 1: Must have closed tool_call tags
# Check if this is the first tool_call start tool_call_matches = self.tool_call_complete_regex.findall(model_output)
if ( if len(tool_call_matches) == 0:
( return False
preprocessed_element.strip().startswith("<tool_call>")
or preprocessed_element.strip().startswith("<function name=")
)
and self.tool_call_index == 0
) and self.text_content_buffer:
# First tool_call starts,
# output previously collected text content first
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer for potential subsequent text content
self.text_content_buffer = ""
# If a new tool_call starts and
# there are already completed tool_calls with function name
if (
preprocessed_element.strip().startswith("<tool_call>")
and self.tool_call_index > 0
and self.current_call_id
and self.current_function_name
):
# Reset parser state but preserve generated deltas
if self.current_param_name:
self._end_element("parameter")
if self.current_function_open:
self._end_element("function")
# Output final tool_call tail delta
final_delta = DeltaMessage(
role=None,
content=None,
reasoning_content=None,
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
],
)
self._emit_delta(final_delta)
# Reset XML parser and current call state
self._reset_xml_parser_after_tool_call()
# Parse preprocessed element
self.parser.Parse(preprocessed_element, False)
found_any = True
except Exception as e: # Check 2: Must have closed function tags within tool_call
logger.warning("Error when parsing XML elements: %s", e) has_valid_function = False
for tool_call_content in tool_call_matches:
function_matches = self.tool_call_function_regex.findall(tool_call_content)
if len(function_matches) > 0:
has_valid_function = True
# Check if there's an unclosed function tag
if (
self.tool_call_prefix in tool_call_content
and self.function_end_token not in tool_call_content
):
return False
# Update processed position if not has_valid_function:
self.last_processed_pos = end_pos return False
return found_any # Check 3: If parameter tags exist, they must be closed and correct
for tool_call_content in tool_call_matches:
# Count opening and closing parameter tags
param_open_count = tool_call_content.count(self.parameter_prefix)
param_close_count = tool_call_content.count(self.parameter_end_token)
# If there are parameter tags, they must be balanced
if param_open_count > 0:
if param_open_count != param_close_count:
return False
# Check if all parameter tags are properly closed using regex
param_matches = self.tool_call_parameter_regex.findall(
tool_call_content
)
if len(param_matches) != param_open_count:
return False
def _fix_incomplete_tag_in_chunk(self, chunk: str) -> str: return True
"""
Fallback: fix incomplete <parameter=xxx or <function=xxx tags
(missing >)
Examples: <parameter=-C: -> <parameter=-C>, <parameter=parameter=-n:
-> <parameter=-n>
Also handles missing = cases: <function xxx> -> <function=xxx>,
<functionxxx> -> <function=xxx>
Only fixes tags that pass validation (parameter exists in tool definition)
"""
# First, handle missing = cases for function tags
chunk = self._fix_missing_equals_in_function_tag(chunk)
for tag_type in ["parameter", "function"]: def _wrap_missing_tool_call_tags(self, model_output: str) -> str:
pattern = f"<{tag_type}=" """Wrap bare <function=...></function> blocks with <tool_call> tags."""
if pattern not in chunk: if (
self.tool_call_prefix not in model_output
or self.function_end_token not in model_output
):
return model_output
def _wrap_bare_functions(text: str) -> str:
pos = 0
wrapped_parts: list[str] = []
while True:
func_idx = text.find(self.tool_call_prefix, pos)
if func_idx == -1:
wrapped_parts.append(text[pos:])
break
end_idx = text.find(self.function_end_token, func_idx)
if end_idx == -1:
wrapped_parts.append(text[pos:])
break
end_idx += len(self.function_end_token)
wrapped_parts.append(text[pos:func_idx])
wrapped_parts.append(self.tool_call_start_token)
wrapped_parts.append(text[func_idx:end_idx])
wrapped_parts.append(self.tool_call_end_token)
ws_idx = end_idx
while ws_idx < len(text) and text[ws_idx].isspace():
ws_idx += 1
if text.startswith(self.tool_call_end_token, ws_idx):
if ws_idx > end_idx:
wrapped_parts.append(text[end_idx:ws_idx])
pos = ws_idx + len(self.tool_call_end_token)
else:
pos = end_idx
return "".join(wrapped_parts)
tool_call_ranges = [
match.span()
for match in self.tool_call_complete_regex.finditer(model_output)
]
if not tool_call_ranges:
return _wrap_bare_functions(model_output)
wrapped_parts: list[str] = []
pos = 0
for start, end in tool_call_ranges:
if start < pos:
continue continue
wrapped_parts.append(_wrap_bare_functions(model_output[pos:start]))
wrapped_parts.append(model_output[start:end])
pos = end
wrapped_parts.append(_wrap_bare_functions(model_output[pos:]))
return "".join(wrapped_parts)
def _normalize_prev_arguments(self, args_value: Any) -> Any:
    """Decode a JSON string into its Python value; pass anything else through.

    Returns:
        The decoded value when ``args_value`` is a string holding valid JSON;
        otherwise ``args_value`` unchanged (non-strings and undecodable
        strings alike).
    """
    if not isinstance(args_value, str):
        return args_value
    try:
        decoded = json.loads(args_value)
    except (TypeError, ValueError, json.JSONDecodeError):
        return args_value
    return decoded
def _update_prev_tool_call_state(self, tool_calls: list[ToolCall]) -> None:
self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear()
for tool_call in tool_calls:
if not tool_call or not tool_call.function:
continue
args_value = tool_call.function.arguments
if isinstance(args_value, str):
args_json = args_value
elif args_value is None:
args_json = ""
else:
try:
args_json = json.dumps(args_value, ensure_ascii=False)
except (TypeError, ValueError):
args_json = str(args_value)
prev_args = self._normalize_prev_arguments(args_json)
self.prev_tool_call_arr.append(
{
"name": tool_call.function.name,
"arguments": prev_args,
}
)
try:
expected_args_json = json.dumps(prev_args, ensure_ascii=False)
except (TypeError, ValueError):
expected_args_json = args_json
start_idx = chunk.find(pattern) # Serving may subtract the latest delta length from
after_tag = chunk[start_idx:] # streamed_args_for_tool to detect unstreamed suffixes. Since this
gt_pos = after_tag.find(">") # parser emits full arguments at once, store expected+actual so
lt_pos = after_tag.find("<", len(pattern)) # the subtraction yields expected_args_json and no resend occurs.
self.streamed_args_for_tool.append(expected_args_json + args_json)
# Skip if already well-formed def extract_tool_calls(
if ( self,
gt_pos != -1 model_output: str,
and (lt_pos == -1 or gt_pos < lt_pos) request: ChatCompletionRequest,
and pattern in after_tag[:gt_pos] ) -> ExtractedToolCallInformation:
): try:
continue origin_model_output = model_output
try:
# Fallback: handle outputs without <tool_call> wrapper.
origin_model_output = self._wrap_missing_tool_call_tags(
origin_model_output
)
model_output = origin_model_output
except Exception:
pass
# Use streaming-like approach: process position by position
valid_tool_calls = []
content_parts = []
processed_length = 0
while processed_length < len(model_output):
# Find next tool call start
tool_start_idx = self._find_tool_call_start(
model_output, processed_length
)
# Extract tag name (stop at space, newline, or <) # Case 1: No more tool calls - add remaining as content
content = chunk[start_idx + len(pattern) :] if tool_start_idx == -1:
end_pos = next( remaining = model_output[processed_length:]
(i for i, ch in enumerate(content) if ch in (" ", "\n", "<")), if remaining:
len(content), content_parts.append(remaining)
) break
tag_name = content[:end_pos]
# Case 2: Content before tool call
if tool_start_idx > processed_length:
content_before = model_output[processed_length:tool_start_idx]
# Skip whitespace-only content between tool calls
# Check if we just ended a tool call and this is pure whitespace
if processed_length > 0:
text_before = model_output[:processed_length]
if (
text_before.rstrip().endswith(self.tool_call_end_token)
and content_before.strip() == ""
):
# Skip whitespace between tool calls
pass
else:
content_parts.append(content_before)
else:
content_parts.append(content_before)
if not tag_name: # Case 3: Try to find complete tool call
continue tool_end_idx = self._find_first_complete_tool_call_end(
model_output, tool_start_idx
)
# Remove duplicate prefix: <parameter=parameter=xxx -> <parameter=xxx # If tool call is incomplete - add remaining as content and stop
if tag_name.startswith(f"{tag_type}="): if tool_end_idx == -1:
tag_name = tag_name[len(tag_type) + 1 :] remaining = model_output[tool_start_idx:]
if remaining:
content_parts.append(remaining)
break
# Extract and try to parse the complete tool call
tool_call_text = model_output[tool_start_idx:tool_end_idx]
parsed_result = self.extract_tool_calls_basic(tool_call_text, request)
# If parsing succeeded, record the tool call(s)
if parsed_result.tools_called and parsed_result.tool_calls:
valid_tool_calls.extend(parsed_result.tool_calls)
processed_length = tool_end_idx
else:
# Parsing failed - treat this tool call as content
content_parts.append(tool_call_text)
processed_length = tool_end_idx
# Remove trailing non-alphanumeric chars (keep - and _) # Populate prev_tool_call_arr for serving layer to set finish_reason
while tag_name and not ( self._update_prev_tool_call_state(valid_tool_calls)
tag_name[-1].isalnum() or tag_name[-1] in ("-", "_")
):
tag_name = tag_name[:-1]
if not tag_name: # Combine content parts
continue content = "".join(content_parts) if content_parts else None
# Validate parameter exists in tool definition return ExtractedToolCallInformation(
if tag_type == "parameter" and not self._validate_parameter_name(tag_name): tools_called=(len(valid_tool_calls) > 0),
continue tool_calls=valid_tool_calls,
content=content if content else None,
)
# Apply fix except Exception:
chunk = chunk.replace( logger.warning("Error in extracting tool call from response.")
f"<{tag_type}={content[:end_pos]}", f"<{tag_type}={tag_name}>", 1 return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
) )
return chunk def extract_tool_calls_basic(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
model_output = self._wrap_missing_tool_call_tags(model_output)
# Quick check to avoid unnecessary processing
if not self._check_format(model_output):
tool_call_matches = self.tool_call_complete_regex.findall(model_output)
if len(tool_call_matches) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _fix_missing_equals_in_function_tag(self, chunk: str) -> str: try:
""" function_calls = self._get_function_calls(model_output)
Fix missing = in function tags: <function xxx> or <functionxxx> if len(function_calls) == 0:
Examples: return ExtractedToolCallInformation(
<function execute_bash> -> <function=execute_bash> tools_called=False, tool_calls=[], content=model_output
<functionexecute_bash> -> <function=execute_bash> )
Only fixes if function name exists in tool definition
"""
# already correct
if "<function=" in chunk:
return chunk
# Pattern 1: <function xxx> (with space/newline but no =)
pattern1 = r"<function\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match1 = re.search(pattern1, chunk)
if match1:
func_name = match1.group(1).strip()
# must validate function name exists before fixing
if func_name and self._validate_function_name(func_name):
original = match1.group(0)
fixed = f"<function={func_name}>"
chunk = chunk.replace(original, fixed, 1)
return chunk
# Pattern 2: <functionxxx> (no space, no =)
# only match <function followed by letters
pattern2 = r"<function([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match2 = re.search(pattern2, chunk)
if match2:
func_name = match2.group(1).strip()
# must validate function name exists before fixing
if func_name and self._validate_function_name(func_name):
original = match2.group(0)
fixed = f"<function={func_name}>"
chunk = chunk.replace(original, fixed, 1)
return chunk
return chunk
def _validate_function_name(self, func_name: str) -> bool:
"""Check if function name exists in tool definitions"""
if not self.tools:
return False
for tool in self.tools: tool_calls: list[ToolCall] = []
if ( for function_call_str in function_calls:
hasattr(tool, "type") tool_call = self._parse_xml_function_call(
and tool.type == "function" function_call_str, request.tools
and hasattr(tool, "function") )
and hasattr(tool.function, "name") if tool_call:
and tool.function.name == func_name tool_calls.append(tool_call)
): if not tool_calls:
return True return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
for tool_call in tool_calls:
if (
not tool_call.function
or tool_call.function.arguments is None
or not self._is_valid_json_arguments(tool_call.function.arguments)
):
logger.warning(
"Invalid JSON arguments in tool call, falling back to content."
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
return False # Populate prev_tool_call_arr for serving layer to set finish_reason
self._update_prev_tool_call_state(tool_calls)
def _validate_parameter_name(self, param_name: str) -> bool: # Extract content before tool calls
"""Check if parameter exists in current function's tool definition""" content_index = model_output.find(self.tool_call_start_token)
if not self.tools or not self.current_function_name: content = model_output[:content_index] # .rstrip()
return True
for tool in self.tools: return ExtractedToolCallInformation(
if ( tools_called=(len(tool_calls) > 0),
hasattr(tool, "type") tool_calls=tool_calls,
and tool.type == "function" content=content if content else None,
and hasattr(tool, "function") )
and hasattr(tool.function, "name")
and tool.function.name == self.current_function_name
):
if not hasattr(tool.function, "parameters"):
return True
params = tool.function.parameters
if isinstance(params, dict):
properties = params.get("properties", params)
return param_name in properties
break
return True except Exception:
logger.warning("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _should_skip_element(self, element: str) -> bool: def _find_first_complete_tool_call_end(self, text: str, start_pos: int = 0) -> int:
""" """Find the end position of the first complete tool call.
Determine whether an element should be skipped
Args: Args:
element: Element to evaluate text: Text to search in
start_pos: Position to start searching from
Returns: Returns:
bool: True means should skip, False means should process Position after the first </tool_call> tag, or -1 if incomplete
"""
# If it's a tool_call XML tag, don't skip Example:
if ( "<tool_call>...</tool_call>..." returns position after </tool_call>
element.startswith(self.tool_call_start_token) """
or element.startswith(self.function_start_token) # Find tool call start
or element.startswith(self.parameter_start_token) start_idx = text.find(self.tool_call_start_token, start_pos)
): if start_idx == -1:
return False return -1
# If currently not parsing tool calls and not blank, # Find matching end token
# collect this text instead of skipping end_idx = text.find(
# Only process other XML elements after tool_call appears, self.tool_call_end_token, start_idx + len(self.tool_call_start_token)
# otherwise treat as plain text )
if self.current_call_id is None and element: if end_idx == -1:
# Collect text content to buffer return -1 # Incomplete tool call
self.text_content_buffer += element
return True # Still skip, but content has been collected
# If currently parsing tool calls,
# this might be parameter value, don't skip
if self.current_call_id is not None:
return False
# Skip blank content # Return position after end token
return not element return end_idx + len(self.tool_call_end_token)
def _find_next_complete_element(self, start_pos: int) -> tuple[str | None, int]: def _find_tool_call_start(self, text: str, start_pos: int = 0) -> int:
""" """Find the start position of next tool call.
Find next complete XML element from specified position
Args: Args:
start_pos: Position to start searching text: Text to search in
start_pos: Position to start searching from
Returns: Returns:
(Complete element string, element end position), Position of <tool_call> token, or -1 if not found
returns (None, start_pos) if no complete element found
""" """
buffer = self.streaming_buffer[start_pos:] return text.find(self.tool_call_start_token, start_pos)
if not buffer:
return None, start_pos
if buffer.startswith("<"): def _extract_content_between_tool_calls_list(self, text: str) -> list[str]:
# Check if this is an incomplete parameter/function tag """Extract content segments after each tool call.
# e.g., <parameter=-C: or <function=xxx
is_incomplete_param = (
buffer.startswith("<parameter=") and ">" not in buffer.split("\n")[0]
)
is_incomplete_func = (
buffer.startswith("<function=") and ">" not in buffer.split("\n")[0]
)
if is_incomplete_param or is_incomplete_func: For n tool calls, returns n segments where segment[i] is the content
# Find the corresponding closing tag after tool_call[i] (before tool_call[i+1] or at the end).
tag_type = "parameter" if is_incomplete_param else "function"
closing_tag = f"</{tag_type}>"
closing_pos = buffer.find(closing_tag)
if closing_pos != -1:
# Found closing tag, return complete element including closing tag
complete_element = buffer[: closing_pos + len(closing_tag)]
return complete_element, start_pos + closing_pos + len(closing_tag)
# Need to ensure no new < appears,
# find the nearest one between < and >
tag_end = buffer.find("<", 1)
tag_end2 = buffer.find(">", 1)
if tag_end != -1 and tag_end2 != -1:
# Next nearest is <
if tag_end < tag_end2:
return buffer[:tag_end], start_pos + tag_end
# Next nearest is >, means found XML element
else:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
elif tag_end != -1:
return buffer[:tag_end], start_pos + tag_end
elif tag_end2 != -1:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
else:
# If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call> or <function=
if self.current_call_id is None:
# Check if might be start of <tool_call>
if buffer == "<tool_call>"[: len(buffer)]:
# Might be start of <tool_call>, wait for more data
return None, start_pos
elif (
buffer.startswith("<function=")
or buffer == "<function="[: len(buffer)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return None, start_pos
else:
# Not start of <tool_call> or <function=, treat as text
return buffer, start_pos + len(buffer)
else:
# When parsing tool calls,
# wait for more data to get complete tag
return None, start_pos
else:
# Find text content (until next < or buffer end)
next_tag_pos = buffer.find("<")
if next_tag_pos != -1:
# Found text content
text_content = buffer[:next_tag_pos]
return text_content, start_pos + next_tag_pos
else:
# Buffer end is all text, process
# (no longer wait for more data)
remaining = buffer
return remaining, start_pos + len(remaining)
def _merge_new_deltas_to_single_response(self, initial_count: int) -> DeltaMessage: Empty or whitespace-only segments are represented as empty string "".
"""
Merge newly generated deltas from this processing
into a single DeltaMessage
Args: Args:
initial_count: Delta count before processing text: Text containing tool calls
Returns: Returns:
Merged DeltaMessage containing all newly generated delta information List of content segments (one per tool call)
""" """
if len(self.deltas) <= initial_count: content_segments = []
return DeltaMessage(content=None) pos = 0
# Get newly generated deltas while True:
new_deltas = self.deltas[initial_count:] # Find end of current tool call
end_pos = text.find(self.tool_call_end_token, pos)
if end_pos == -1:
break
if len(new_deltas) == 1: # Move past the end token
# Only one new delta, return directly end_pos += len(self.tool_call_end_token)
return new_deltas[0]
# Merge multiple new deltas # Find start of next tool call
merged_tool_calls: list[DeltaToolCall] = [] next_start = self._find_tool_call_start(text, end_pos)
merged_content: str = ""
for delta in new_deltas: # Extract content between current end and next start (or text end)
if delta.content: content = text[end_pos:next_start] if next_start != -1 else text[end_pos:]
merged_content += delta.content
if delta.tool_calls:
# For tool_calls, we need to intelligently merge arguments
for tool_call in delta.tool_calls:
# Find if there's already a tool_call with the same call_id
existing_call = None
for existing in merged_tool_calls:
if existing.id == tool_call.id:
existing_call = existing
break
if existing_call and existing_call.function:
# Merge to existing tool_call
if tool_call.function and tool_call.function.name:
existing_call.function.name = tool_call.function.name
if (
tool_call.function
and tool_call.function.arguments is not None
):
if existing_call.function.arguments is None:
existing_call.function.arguments = ""
# For streaming JSON parameters,
# simply concatenate in order
new_args = tool_call.function.arguments
existing_call.function.arguments += new_args
if tool_call.type:
existing_call.type = tool_call.type
else:
# Add new tool_call
merged_tool_calls.append(tool_call)
return DeltaMessage( # Store content (empty string if whitespace-only)
content=merged_content if merged_content else None, content_segments.append(content if content.strip() else "")
tool_calls=merged_tool_calls,
)
def _preprocess_xml_chunk(self, chunk: str) -> str: if next_start == -1:
""" break
Preprocess XML chunk, handle non-standard formats, pos = next_start
and escape special characters
Args: return content_segments
chunk: Original XML chunk
Returns: def _convert_tool_calls_to_deltas(
Processed XML chunk self, tool_calls: list[ToolCall], starting_index: int = 0
""" ) -> list[DeltaToolCall]:
"""Convert complete ToolCall list to DeltaToolCall list.
# Check if this is a tool_call related element
is_tool_call = False
if chunk.startswith(self.tool_call_start_token) or chunk.startswith(
self.tool_call_end_token
):
is_tool_call = True
# Check for function tags (including malformed ones without =)
# <function=xxx>, </function>, <function xxx>, <functionxxx>
if (
chunk.startswith(self.function_start_token)
or chunk.startswith(self.function_end_token)
or chunk.startswith("<function ")
or re.match(r"^<function[a-zA-Z_]", chunk)
): # <functionXXX without space or =
is_tool_call = True
if chunk.startswith(self.parameter_start_token) or chunk.startswith(
self.parameter_end_token
):
is_tool_call = True
# Fallback: fix incomplete <parameter= or <function= tags without Returns complete tool calls without splitting into fragments.
# closing >
# This handles cases like: <parameter=-C:\n or <parameter=-B 5\n
# Apply when parsing tool calls OR when chunk looks like a function/
# parameter tag
if (
self.current_call_id is not None
or chunk.startswith("<function")
or chunk.startswith("<parameter")
):
chunk = self._fix_incomplete_tag_in_chunk(chunk)
# Handle <function=name> format -> <function name="name">
processed = re.sub(r"<function=([^>]+)>", r'<function name="\1">', chunk)
# Handle <parameter=name> format -> <parameter name="name">
processed = re.sub(r"<parameter=([^>]+)>", r'<parameter name="\1">', processed)
original_chunk = chunk
# If in parameter value accumulation mode
if self._pre_inside_parameter:
# Parameter end: output accumulated raw text
# safely then return </parameter>
if processed.startswith("</parameter>"):
body_text = self._pre_param_buffer
# Trigger deferred parsing mode
# literal_eval+json output in end_element
self.defer_current_parameter = True
self.deferred_param_raw_value = body_text
# Clean up state
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
safe_text = self._escape_xml_special_chars(body_text)
return f"{safe_text}</parameter>"
else:
# If this is the first block of content after entering parameter
# evaluate if deferred parsing is needed;
# If not needed, exit accumulation mode
# and pass through directly
if self._pre_param_buffer == "":
# Get current parameter type
param_type = (
self._get_param_type(self._pre_current_param_name)
if self._pre_current_param_name
else "string"
)
# Only these types need deferred parsing to
# handle Python literals containing single quotes
is_object_type = param_type in ["object"]
is_complex_type = (
param_type in ["array", "arr", "sequence"]
or param_type.startswith("dict")
or param_type.startswith("list")
)
# Only delay when contains container symbols Args:
# and has single quotes and is complex type tool_calls: List of tool calls to convert
has_container_hint = ( starting_index: Starting index for tool calls (default 0)
("[" in original_chunk)
or ("{" in original_chunk)
or ("(" in original_chunk)
)
# Determine if deferred parsing is needed Returns:
need_defer = False List of DeltaToolCall with complete arguments
if is_complex_type:
# Complex type, always need deferred parsing
need_defer = True
elif (
is_object_type
and has_container_hint
and ("'" in original_chunk)
):
# Object type with container symbols
# and single quotes, need deferred parsing
need_defer = True
if not need_defer:
# No need for deferred parsing,
# exit parameter mode directly
self._pre_inside_parameter = False
return self._escape_xml_special_chars(original_chunk)
self._pre_param_buffer += original_chunk
return ""
# Parameter start: enable accumulation
if processed.startswith("<parameter name="):
m = re.match(r'<parameter name="([^"]+)">', processed)
if m:
self._pre_current_param_name = m.group(1)
self._pre_inside_parameter = True
self._pre_param_buffer = ""
return processed
# If processed doesn't contain special_token, escape processed
# This is because XML parsing encounters special characters
# and reports errors, so escaping is needed
if not is_tool_call:
processed = self._escape_xml_special_chars(processed)
return processed
def _emit_delta(self, delta: DeltaMessage):
"""Emit Delta response (streaming output)"""
self.deltas.append(delta)
def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None):
"""Before starting to process new elements,
if there are unclosed tags from before,
automatically complete their endings to the parser.
- If there are unclosed parameters,
it's equivalent to feeding `</parameter>`
- When about to start a new function or tool_call,
if there are unclosed functions, complete `</function>`.
- When about to start a new tool_call,
if there are unclosed tool_calls, complete `</tool_call>`.
""" """
# First close unclosed parameters delta_tool_calls = []
if self.current_param_name: for i, tool_call in enumerate[ToolCall](tool_calls):
self._end_element("parameter") index = starting_index + i
tool_id = self._generate_tool_call_id()
# If about to start new function or tool_call,
# and there are unclosed functions, close function first # Create complete DeltaToolCall with full arguments
if incoming_tag in ("function", "tool_call") and self.current_function_name: delta_tool_calls.append(
self._end_element("function") DeltaToolCall(
index=index,
# If about to start new tool_call, id=tool_id,
# and there are unclosed tool_calls, close tool_call first function=DeltaFunctionCall(
if incoming_tag == "tool_call" and self.current_call_id: name=tool_call.function.name,
self._end_element("tool_call") arguments=tool_call.function.arguments,
),
def _start_element(self, name: str, attrs: dict[str, str]): type="function",
"""Handle XML start element events"""
if name == "root":
return
if name == "tool_call":
# Before opening new tool_call,
# automatically complete previous unclosed tags
self._auto_close_open_parameter_if_needed("tool_call")
self.parameters = {}
self.current_call_id = make_tool_call_id()
self.current_param_is_first = True
self.tool_call_index += 1
elif name.startswith("function") or (name == "function"):
# If missing tool_call, manually complete
if not self.current_call_id:
self._start_element("tool_call", {})
# Before opening new function,
# automatically complete previous unclosed tags (parameter/function)
self._auto_close_open_parameter_if_needed("function")
function_name = self._extract_function_name(name, attrs)
self.current_function_name = function_name
self.current_function_open = True
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=function_name, arguments=""
),
)
]
)
self._emit_delta(delta)
elif name.startswith("parameter") or (name == "parameter"):
# If previous parameter hasn't ended normally,
# complete its end first, then start new parameter
self._auto_close_open_parameter_if_needed("parameter")
param_name = self._extract_parameter_name(name, attrs)
self.current_param_name = param_name
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False # Reset start quote flag
# Only output parameter name and colon,
# don't output quotes
# decide after parameter value type is determined
if param_name:
if not self.parameters:
# First parameter
# start JSON, only output parameter name and colon
json_start = f'{{"{param_name}": '
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_start
),
)
]
)
self._emit_delta(delta)
self.current_param_is_first = True
else:
# Subsequent parameters
# add comma and parameter name, no quotes
json_continue = f', "{param_name}": '
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_continue
),
)
]
)
self._emit_delta(delta)
self.current_param_is_first = False
def _char_data(self, data: str):
"""Handle XML character data events"""
if data and self.current_param_name:
# If preprocessing stage determines deferred parsing is needed,
# only cache character data, no streaming output
if self.defer_current_parameter:
original_data = data
if self.should_emit_end_newline:
original_data = "\n" + original_data
self.should_emit_end_newline = False
if original_data.endswith("\n"):
self.should_emit_end_newline = True
original_data = original_data[:-1]
self.current_param_value += original_data
return
param_type = self._get_param_type(self.current_param_name)
# Check if this is the first time receiving data for this parameter
# If this is the first packet of data and starts with \n, remove \n
if not self.current_param_value and data.startswith("\n"):
data = data[1:]
# Output start quote for string type (if not already output)
if (
param_type in ["string", "str", "text", "varchar", "char", "enum"]
and not self.start_quote_emitted
):
quote_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
]
) )
self._emit_delta(quote_delta)
self.start_quote_emitted = True
if not data:
return
original_data = data
# Delay output of trailing newline
if self.should_emit_end_newline:
original_data = "\n" + original_data
self.should_emit_end_newline = False
if original_data.endswith("\n"):
self.should_emit_end_newline = True
original_data = original_data[:-1]
self.current_param_value += original_data
# convert parameter value by param_type
converted_value = self._convert_param_value(
self.current_param_value, param_type
) )
output_data = self._convert_for_json_streaming(converted_value, param_type)
delta_data = output_data[len(self.current_param_value_converted) :]
self.current_param_value_converted = output_data
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=delta_data),
)
]
)
self._emit_delta(delta)
def _end_element(self, name: str):
"""Handle XML end element events"""
if name == "root": return delta_tool_calls
return
# If function or tool_call ends and there are still unclosed parameters, def extract_tool_calls_streaming(
# complete parameter end first self,
if ( previous_text: str,
name.startswith("function") or name == "function" or name == "tool_call" current_text: str,
) and self.current_param_name: delta_text: str,
self._auto_close_open_parameter_if_needed() previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming text using complete parsing.
Strategy:
1. Accumulate text in buffer and track processed position
2. In each iteration, try to extract content or complete tool calls
3. Parse complete tool calls using non-streaming method
4. Convert parsed results to delta sequence
5. Handle EOS token to flush incomplete tool calls as content
"""
# Initialize state for new request
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
# Check for EOS token
has_eos = (
self.eos_token_id is not None
and delta_token_ids
and self.eos_token_id in delta_token_ids
)
if ( # If no delta text, check if we need to return empty delta for finish_reason
name.startswith("parameter") or name == "parameter" if not delta_text and not has_eos:
) and self.current_param_name: # Check if this is an EOS token after all tool calls are complete
# End current parameter if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
param_name = self.current_param_name # Count complete tool calls
param_value = self.current_param_value complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
# If in deferred parsing mode,
# perform overall parsing on raw content
# accumulated in preprocessing stage and output once
if self.defer_current_parameter:
raw_text = (
self.deferred_param_raw_value
if self.deferred_param_raw_value
else param_value
)
parsed_value = None
output_arguments = None
try:
# If previously delayed trailing newline,
# add it back before parsing
if self.should_emit_end_newline:
raw_for_parse = raw_text + "\n"
else:
raw_for_parse = raw_text
parsed_value = ast.literal_eval(raw_for_parse)
output_arguments = json.dumps(parsed_value, ensure_ascii=False)
except Exception:
# Fallback: output as string as-is
output_arguments = json.dumps(raw_text, ensure_ascii=False)
parsed_value = raw_text
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(
name=None, arguments=output_arguments
),
)
]
)
self._emit_delta(delta)
# Clean up and store
self.should_emit_end_newline = False
self.parameters[param_name] = parsed_value
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
return
param_type = self._get_param_type(param_name)
# convert complete parameter value by param_type
converted_value = self._convert_param_value(param_value, param_type)
# Decide whether to add end quote based on parameter type
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# For empty string parameters, need special handling
if not param_value and not self.start_quote_emitted:
# No start quote output,
# directly output complete empty string
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='""'),
)
]
)
self._emit_delta(delta)
else:
# Non-empty parameter value, output end quote
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
]
)
self._emit_delta(delta)
self.should_emit_end_newline = False
# Store converted value
self.parameters[param_name] = converted_value
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.start_quote_emitted = False
elif name.startswith("function") or name == "function":
# if there are parameters, close JSON object
if self.parameters:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments="}"),
)
]
)
self._emit_delta(delta)
# return empty object
else:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments="{}"),
)
]
) )
self._emit_delta(delta)
self.current_function_open = False
self.current_function_name = (
None # Clear function name to prevent duplicate closing
)
elif name == "tool_call": # If we have completed tool calls and populated prev_tool_call_arr
# Before ending tool_call, if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# ensure function is closed to complete missing right brace # Check if all tool calls are closed
if self.current_function_open: open_calls = current_text.count(
# If there are still unclosed parameters, close them first self.tool_call_start_token
if self.current_param_name: ) - current_text.count(self.tool_call_end_token)
self._end_element("parameter") if open_calls == 0:
# Close function, ensure output '}' or '{}' # Return empty delta for finish_reason processing
self._end_element("function") return DeltaMessage(content="")
# Final Delta return None
delta = DeltaMessage(
tool_calls=[ # Process all available content
DeltaToolCall( accumulated_deltas: list[DeltaMessage] = []
index=self.tool_call_index - 1,
id=self.current_call_id,
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
]
)
self._emit_delta(delta)
# Check if there's text content to output (between tool_calls) while self._has_unprocessed_content(current_text):
if self.text_content_buffer.strip(): # Try to process next chunk (content or tool call)
text_delta = DeltaMessage(content=self.text_content_buffer) delta = self._process_next_chunk(current_text)
self._emit_delta(text_delta)
self._reset_xml_parser_after_tool_call() if delta is None:
# Cannot proceed further, need more tokens
break
def setup_parser(self): # Accumulate deltas
"""Set up XML parser event handlers""" if isinstance(delta, list):
self.parser.buffer_text = True accumulated_deltas.extend(delta)
self.parser.StartElementHandler = self._start_element else:
self.parser.EndElementHandler = self._end_element accumulated_deltas.append(delta)
self.parser.CharacterDataHandler = self._char_data
# Handle EOS: flush any remaining incomplete tool calls as content
if has_eos:
remaining_delta = self._flush_remaining_content(current_text)
if remaining_delta:
accumulated_deltas.append(remaining_delta)
# If no remaining content but we have tool calls, return empty delta
elif len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
accumulated_deltas.append(DeltaMessage(content=""))
# Return results
return self._format_delta_result(accumulated_deltas)
def _has_unprocessed_content(self, current_text: str) -> bool:
"""Check if there's unprocessed content in the buffer."""
return self._processed_length < len(current_text)
def _process_next_chunk(
self, current_text: str
) -> DeltaMessage | list[DeltaMessage] | None:
"""Process next chunk: either regular content or a complete tool call.
def set_tools(self, tools: list[ChatCompletionToolsParam] | None): Args:
"""Set tool configuration information""" current_text: Current accumulated text
self.tools = tools
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None: Returns:
"""Extract function name from various formats""" - DeltaMessage or list of DeltaMessage if processed successfully
if attrs and "name" in attrs: - None if cannot proceed (need more tokens)
return attrs["name"] """
# Find next tool call start
tool_start_idx = self._find_tool_call_start(
current_text, self._processed_length
)
if "=" in name: # Case 1: No tool call found - return remaining content
parts = name.split("=", 1) if tool_start_idx == -1:
if len(parts) == 2 and parts[0] == "function": return self._process_content(
return parts[1] current_text, self._processed_length, len(current_text)
)
return None # Case 2: Content before tool call
if tool_start_idx > self._processed_length:
return self._process_content(
current_text, self._processed_length, tool_start_idx
)
def _extract_parameter_name(self, name: str, attrs: dict[str, str]) -> str | None: # Case 3: Tool call at current position
"""Extract parameter name from various formats""" # Find end of the first complete tool call
if attrs and "name" in attrs: tool_end_idx = self._find_first_complete_tool_call_end(
return attrs["name"] current_text, tool_start_idx
)
if "=" in name: if tool_end_idx == -1:
parts = name.split("=", 1) # Tool call incomplete, wait for more tokens
if len(parts) == 2 and parts[0] == "parameter": return None
return parts[1]
return None # Process complete tool call
return self._process_complete_tool_calls(
current_text, tool_start_idx, tool_end_idx
)
def _process_content(
self, current_text: str, start_pos: int, end_pos: int
) -> DeltaMessage | None:
"""Process regular content (non-tool-call text).
def _get_param_type(self, param_name: str) -> str:
"""Get parameter type based on tool configuration, defaults to string
Args: Args:
param_name: Parameter name current_text: Current accumulated text
start_pos: Start position in buffer
end_pos: End position in buffer
Returns: Returns:
Parameter type DeltaMessage with content if non-empty
""" """
if not self.tools or not self.current_function_name: if start_pos >= end_pos:
return "string" return None
for tool in self.tools: content = current_text[start_pos:end_pos]
if not hasattr(tool, "type") or not (
hasattr(tool, "function") and hasattr(tool.function, "name") # Check if we're between tool calls - skip whitespace
): if start_pos > 0:
continue # Check if text before start_pos ends with </tool_call>
text_before = current_text[:start_pos]
if ( if (
tool.type == "function" text_before.rstrip().endswith(self.tool_call_end_token)
and tool.function.name == self.current_function_name and content.strip() == ""
): ):
if not hasattr(tool.function, "parameters"): # We just ended a tool call, skip whitespace between tool calls
return "string" self._processed_length = end_pos
params = tool.function.parameters return None
if isinstance(params, dict) and "properties" in params:
properties = params["properties"]
if param_name in properties and isinstance(
properties[param_name], dict
):
return self.repair_param_type(
str(properties[param_name].get("type", "string"))
)
elif isinstance(params, dict) and param_name in params:
param_config = params[param_name]
if isinstance(param_config, dict):
return self.repair_param_type(
str(param_config.get("type", "string"))
)
break
return "string"
def repair_param_type(self, param_type: str) -> str: # Return content if non-empty
"""Repair unknown parameter types by treating them as string if content:
Args: self._processed_length = end_pos
param_type: Parameter type return DeltaMessage(content=content)
Returns: # Mark as processed even if empty
Repaired parameter type self._processed_length = end_pos
""" return None
if (
param_type in ["string", "str", "text", "varchar", "char", "enum"] def _flush_remaining_content(self, current_text: str) -> DeltaMessage | None:
or param_type.startswith("int") """Flush any remaining unprocessed content as regular content.
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
or param_type.startswith("num")
or param_type.startswith("float")
or param_type in ["boolean", "bool", "binary"]
or (
param_type in ["object", "array", "arr", "sequence"]
or param_type.startswith("dict")
or param_type.startswith("list")
)
):
return param_type
else:
return "string"
def _convert_param_value(self, param_value: str, param_type: str) -> Any:
"""Convert value based on parameter type
Args: Args:
param_value: Parameter value current_text: Current accumulated text
param_type: Parameter type
Returns: Used when EOS token is encountered to handle incomplete tool calls.
Converted value
""" """
if param_value.lower() == "null": if not self._has_unprocessed_content(current_text):
return None return None
param_type = param_type.strip().lower() remaining = current_text[self._processed_length :]
if param_type in ["string", "str", "text", "varchar", "char", "enum"]: if remaining:
return param_value self._processed_length = len(current_text)
elif ( return DeltaMessage(content=remaining)
param_type.startswith("int")
or param_type.startswith("uint") self._processed_length = len(current_text)
or param_type.startswith("long") return None
or param_type.startswith("short")
or param_type.startswith("unsigned") def _format_delta_result(self, deltas: list[DeltaMessage]) -> DeltaMessage | None:
): """Format delta result for return.
try:
return int(param_value) Merges all deltas into a single DeltaMessage.
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' is not an integer, degenerating to string.",
param_value,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value: float = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' is not a float, degenerating to string.",
param_value,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
return param_value == "true"
else:
return param_value
def _convert_for_json_streaming(self, converted_value: Any, param_type: str) -> str:
"""Convert converted_value based on
whether it's empty and if type is string
Args: Args:
converted_value: Converted value deltas: List of delta messages
param_type: Parameter type
Returns: Returns:
Converted string for streaming output - None if empty
- Single merged DeltaMessage with all content and tool_calls
""" """
# Check if value is empty, but exclude numeric 0 if not deltas:
if converted_value is None or converted_value == "": return None
return ""
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# String type, remove double quotes
return json.dumps(converted_value, ensure_ascii=False)[1:-1]
else:
# Non-string type, return complete JSON string
if not isinstance(converted_value, str):
return json.dumps(converted_value, ensure_ascii=False)
else:
return converted_value
def _reset_xml_parser_after_tool_call(self): if len(deltas) == 1:
""" return deltas[0]
Each tool_call is treated as a separate XML document,
so we need to reset the parser after each tool_call.
"""
# recreate XML parser # Merge multiple deltas into one
self.parser = ParserCreate() merged_content_parts = []
self.setup_parser() merged_tool_calls = []
# Reset current tool_call state
if self.current_call_id:
self.last_completed_call_id = self.current_call_id
self.current_call_id = None
self.current_function_name = None
self.current_function_open = False
self.parameters = {}
self.current_param_name = None
self.current_param_value = ""
self.current_param_value_converted = ""
self.current_param_is_first = False
self.should_emit_end_newline = False
self.start_quote_emitted = False
self.text_content_buffer = ""
# Reset preprocessing and deferred parsing state
self._pre_inside_parameter = False
self._pre_param_buffer = ""
self._pre_current_param_name = None
self.defer_current_parameter = False
self.deferred_param_raw_value = ""
@ToolParserManager.register_module("step3p5")
class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser()
# Add missing attributes for compatibility with serving_chat.py for delta in deltas:
self.prev_tool_call_arr: list[dict] = [] if delta.content:
self.streamed_args_for_tool: list[str] = [] merged_content_parts.append(delta.content)
if delta.tool_calls:
merged_tool_calls.extend(delta.tool_calls)
logger.info( # Create merged DeltaMessage
"vLLM Successfully import tool parser %s !", self.__class__.__name__ merged_content = "".join(merged_content_parts) if merged_content_parts else None
)
def extract_tool_calls( # Build kwargs - only include tool_calls if non-empty
self, kwargs: dict[str, Any] = {"content": merged_content}
model_output: str, if merged_tool_calls:
request: ChatCompletionRequest, kwargs["tool_calls"] = merged_tool_calls
) -> ExtractedToolCallInformation:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new extraction
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
result = self.parser.parse_single_streaming_chunks(model_output)
if not result.tool_calls:
return ExtractedToolCallInformation(
tool_calls=[],
tools_called=False,
content=result.content,
)
else:
tool_calls = []
for tool_call in result.tool_calls:
if tool_call.function and tool_call.function.name:
tool_calls.append(
ToolCall(
id=tool_call.id,
type=tool_call.type,
function=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
)
# Update tool call tracking arrays for compatibility return DeltaMessage(**kwargs)
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Ensure we have enough entries in our tracking arrays def _process_complete_tool_calls(
while len(self.prev_tool_call_arr) <= tool_index: self, current_text: str, start_pos: int, end_pos: int
self.prev_tool_call_arr.append({"name": "", "arguments": ""}) ) -> list[DeltaMessage] | None:
while len(self.streamed_args_for_tool) <= tool_index: """Process complete tool calls and convert to delta sequence.
self.streamed_args_for_tool.append("")
# Update tool call information Args:
self.prev_tool_call_arr[tool_index]["name"] = ( current_text: Current accumulated text
tool_call.function.name start_pos: Start position (should be at <tool_call>)
) end_pos: End position (after </tool_call>)
self.prev_tool_call_arr[tool_index]["arguments"] = (
tool_call.function.arguments
)
# Update streamed arguments Returns:
if tool_call.function.arguments: List of DeltaMessage if successful, None otherwise
self.streamed_args_for_tool[tool_index] = ( """
tool_call.function.arguments try:
) # Extract text segment containing complete tool call(s)
text_to_parse = current_text[start_pos:end_pos]
return ExtractedToolCallInformation( # Parse using non-streaming method
tool_calls=tool_calls, result = self.extract_tool_calls_basic(
tools_called=len(tool_calls) > 0, text_to_parse, self.streaming_request
content=result.content,
) )
def extract_tool_calls_streaming( # Case 1: Successfully parsed tool calls
self, if result.tools_called and result.tool_calls:
previous_text: str, # Note: Due to _find_first_complete_tool_call_end, we typically
current_text: str, # process only one tool call at a time
delta_text: str, # but we can also process multiple tool calls below
previous_token_ids: Sequence[int], deltas = self._build_tool_call_deltas(result.tool_calls, text_to_parse)
current_token_ids: Sequence[int], self._update_state_after_tool_calls(result.tool_calls, end_pos)
delta_token_ids: Sequence[int], return deltas if deltas else None
request: ChatCompletionRequest,
) -> DeltaMessage | None: # Case 2: Parsing failed - treat as regular content
if not previous_text: self._processed_length = end_pos
self.parser.reset_streaming_state() return [DeltaMessage(content=text_to_parse)]
# Reset tool call tracking arrays for new streaming session
self.prev_tool_call_arr = [] except Exception as e:
self.streamed_args_for_tool = [] # Exception during parsing - treat as content
if request: logger.debug("Failed to parse tool calls: %s, treating as content", e)
self.parser.set_tools(request.tools) self._processed_length = end_pos
failed_text = current_text[start_pos:end_pos]
# Model sometimes outputs separately causing delta_text to be empty. return [DeltaMessage(content=failed_text)] if failed_text else None
# If there were tool_calls before and all current tool_calls have ended,
# return an empty tool_call for outer streaming output def _build_tool_call_deltas(
# to correctly output tool_call field self, tool_calls: list[ToolCall], parsed_text: str
if not delta_text and delta_token_ids: ) -> list[DeltaMessage]:
open_calls = current_text.count( """Build delta messages from parsed tool calls with interleaved content.
self.parser.tool_call_start_token
) - current_text.count(self.parser.tool_call_end_token)
if (
open_calls == 0
and self.parser.tool_call_index > 0
or not self.parser.tool_call_index
and current_text
):
return DeltaMessage(content="")
return None
# Parse the delta text and get the result Args:
result = self.parser.parse_single_streaming_chunks(delta_text) tool_calls: List of parsed tool calls
parsed_text: Original text that was parsed
# Update tool call tracking arrays based on incremental parsing results
if result and result.tool_calls:
for tool_call in result.tool_calls:
if tool_call.function:
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Ensure we have enough entries in our tracking arrays Returns:
while len(self.prev_tool_call_arr) <= tool_index: List of DeltaMessage with tool calls and content interleaved
self.prev_tool_call_arr.append({"name": "", "arguments": ""}) """
while len(self.streamed_args_for_tool) <= tool_index: # Extract content segments between tool calls
self.streamed_args_for_tool.append("") content_segments = self._extract_content_between_tool_calls_list(parsed_text)
# Update tool name if provided # Convert all tool calls to DeltaToolCall list
if tool_call.function.name: delta_tool_calls = self._convert_tool_calls_to_deltas(
self.prev_tool_call_arr[tool_index]["name"] = ( tool_calls, self._tool_call_index
tool_call.function.name )
)
# Update arguments incrementally # Merge all content segments into a single string
if tool_call.function.arguments is not None: merged_content = "".join(content_segments)
# Concatenate the incremental arguments
# to the existing streamed arguments
self.prev_tool_call_arr[tool_index]["arguments"] += (
tool_call.function.arguments
)
self.streamed_args_for_tool[tool_index] += (
tool_call.function.arguments
)
return result
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool: # Return a single DeltaMessage with all tool calls and content
""" # Build kwargs - only include non-empty fields
Skip the remaining_call calculation in serving_chat kwargs: dict[str, Any] = {}
if merged_content:
kwargs["content"] = merged_content
if delta_tool_calls:
kwargs["tool_calls"] = delta_tool_calls
# Only return DeltaMessage if we have content or tool_calls
if kwargs:
return [DeltaMessage(**kwargs)]
else:
return []
def _update_state_after_tool_calls(
self, tool_calls: list[ToolCall], end_pos: int
) -> None:
"""Update internal state after processing tool calls.
Args:
tool_calls: List of processed tool calls
end_pos: End position in buffer
""" """
return False # Update processed position
self._processed_length = end_pos
# Update tool call index
self._tool_call_index += len(tool_calls)
# Update prev_tool_call_arr for finish_reason
self._update_prev_tool_call_state(tool_calls)
\ No newline at end of file
...@@ -7,6 +7,7 @@ import os ...@@ -7,6 +7,7 @@ import os
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from functools import partial
from typing import Any, NewType, TypeAlias, overload from typing import Any, NewType, TypeAlias, overload
from vllm import envs from vllm import envs
...@@ -947,6 +948,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo ...@@ -947,6 +948,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def _get_kv_cache_groups_uniform_page_size( def _get_kv_cache_groups_uniform_page_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec], kv_cache_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]: ) -> list[KVCacheGroupSpec]:
""" """
...@@ -1007,6 +1009,7 @@ def _get_kv_cache_groups_uniform_page_size( ...@@ -1007,6 +1009,7 @@ def _get_kv_cache_groups_uniform_page_size(
memory per block is the same for all groups. memory per block is the same for all groups.
Args: Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns: Returns:
The generated KVCacheGroupSpecs The generated KVCacheGroupSpecs
...@@ -1030,9 +1033,9 @@ def _get_kv_cache_groups_uniform_page_size( ...@@ -1030,9 +1033,9 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better # is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30 # strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10). # sw, where the group size should be 10).
min_num_layers = min([len(layers) for layers in same_type_layers.values()]) min_num_layers = min([len(layers) for layers in same_type_layers.values()]) #12
group_size = min_num_layers group_size = min_num_layers
max_num_layers = max([len(layers) for layers in same_type_layers.values()]) max_num_layers = max([len(layers) for layers in same_type_layers.values()]) #36
if max_num_layers < min_num_layers * 1.25: if max_num_layers < min_num_layers * 1.25:
# If the number of layers is not much larger than the minimum number of layers, # If the number of layers is not much larger than the minimum number of layers,
# use the maximum number of layers as the group size to avoid too many padding # use the maximum number of layers as the group size to avoid too many padding
...@@ -1050,19 +1053,28 @@ def _get_kv_cache_groups_uniform_page_size( ...@@ -1050,19 +1053,28 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers / len(layers) * 100, num_padding_layers / len(layers) * 100,
) )
num_groups = cdiv(len(layers), group_size) num_groups = cdiv(len(layers), group_size)
# In PP case, say if we have # for support multi layer mtp, we need to
# - stage 0: full.0, sw.0, sw.1 # make all mtp layers in the same group
# - stage 1: full.1, sw.2, sw.3 if (
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3) vllm_config.speculative_config is not None
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because and vllm_config.speculative_config.enable_multi_layers_mtp
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group) ):
# and it will be padded to (full.0, padding), (sw.0, sw.1), for i in range(0, len(layers), group_size):
# (padding, padding) to ensure the number of layers in each group is grouped_layers.append(layers[i : i + group_size])
# the same and will cause memory waste. else:
# To avoid this, we assign layers[i::num_groups] to the i-th group # In PP case, say if we have
# instead of layers[i * group_size: (i + 1) * group_size] # - stage 0: full.0, sw.0, sw.1
for i in range(num_groups): # - stage 1: full.1, sw.2, sw.3
grouped_layers.append(layers[i::num_groups]) # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for i in range(num_groups):
grouped_layers.append(layers[i::num_groups])
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
...@@ -1120,7 +1132,6 @@ def get_kv_cache_config_from_groups( ...@@ -1120,7 +1132,6 @@ def get_kv_cache_config_from_groups(
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2 # full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups) group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size( page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups] [group.kv_cache_spec for group in kv_cache_groups]
) )
...@@ -1247,8 +1258,10 @@ def get_kv_cache_groups( ...@@ -1247,8 +1258,10 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers # have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page # into groups with the same number of layers, and thus same total page
# size. # size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) # return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
return _get_kv_cache_groups_uniform_page_size(
vllm_config=vllm_config, kv_cache_spec=kv_cache_spec
)
def generate_scheduler_kv_cache_config( def generate_scheduler_kv_cache_config(
kv_cache_configs: list[KVCacheConfig], kv_cache_configs: list[KVCacheConfig],
...@@ -1451,6 +1464,42 @@ def _auto_fit_max_model_len( ...@@ -1451,6 +1464,42 @@ def _auto_fit_max_model_len(
) )
def _project_kv_cache_groups_to_worker(
    global_kv_cache_groups: list[KVCacheGroupSpec],
    worker_spec: dict[str, KVCacheSpec],
) -> list[KVCacheGroupSpec]:
    """Restrict the model-wide KV cache groups to one worker's layers.

    Each global group is mapped to exactly one output group, in the same
    order. The output group keeps only the layer names that are keys of
    ``worker_spec``. A group with no layers on this worker is still
    emitted (with an empty layer list and its original spec), so group
    indices stay aligned with the global list.

    When a group's spec is a ``UniformTypeKVCacheSpecs`` and at least one
    of its layers lives on this worker, the spec is rebuilt so that its
    per-layer mapping contains only this worker's layers; ``block_size``
    is carried over unchanged. Any other spec is reused as-is.

    Intended for pipeline parallelism, where every worker owns only a
    subset of the model's layers.

    Args:
        global_kv_cache_groups: The KV cache groups for the whole model.
        worker_spec: The KV cache spec of each layer on this worker, keyed
            by layer name. Only its keys are consulted here.

    Returns:
        One projected group per global group, containing only this
        worker's layers.
    """
    result: list[KVCacheGroupSpec] = []
    for global_group in global_kv_cache_groups:
        # Preserve the group's own layer ordering while filtering.
        local_names = [
            name for name in global_group.layer_names if name in worker_spec
        ]
        spec = global_group.kv_cache_spec

        narrow_uniform_spec = bool(local_names) and isinstance(
            spec, UniformTypeKVCacheSpecs
        )
        if narrow_uniform_spec:
            # Drop per-layer entries that belong to other workers.
            block_size = spec.block_size
            local_layer_specs = {
                name: spec.kv_cache_specs[name] for name in local_names
            }
            spec = UniformTypeKVCacheSpecs(
                block_size=block_size,
                kv_cache_specs=local_layer_specs,
            )

        result.append(KVCacheGroupSpec(local_names, spec))
    return result
def get_kv_cache_configs( def get_kv_cache_configs(
vllm_config: VllmConfig, vllm_config: VllmConfig,
kv_cache_specs: list[dict[str, KVCacheSpec]], kv_cache_specs: list[dict[str, KVCacheSpec]],
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Set as AbstractSet
from dataclasses import replace
from itertools import product from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.utils import get_captured_lora_counts
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -57,6 +61,11 @@ class CudagraphDispatcher: ...@@ -57,6 +61,11 @@ class CudagraphDispatcher:
) )
self.keys_initialized = False self.keys_initialized = False
self.specialize_lora_count = (
self.vllm_config.lora_config.specialize_active_lora
if self.vllm_config.lora_config is not None
else False
)
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE self.cudagraph_mode = CUDAGraphMode.NONE
...@@ -64,6 +73,9 @@ class CudagraphDispatcher: ...@@ -64,6 +73,9 @@ class CudagraphDispatcher:
"""Pre-compute the mapping from batch size to padded graph size.""" """Pre-compute the mapping from batch size to padded graph size."""
max_size = self.compilation_config.max_cudagraph_capture_size max_size = self.compilation_config.max_cudagraph_capture_size
capture_sizes = self.compilation_config.cudagraph_capture_sizes capture_sizes = self.compilation_config.cudagraph_capture_sizes
assert capture_sizes is not None, (
"Cudagraph capture sizes must be set when cudagraphs are enabled."
)
self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1) self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
for end, start in zip( for end, start in zip(
capture_sizes + [max_size + 1], capture_sizes + [max_size + 1],
...@@ -92,8 +104,33 @@ class CudagraphDispatcher: ...@@ -92,8 +104,33 @@ class CudagraphDispatcher:
"Use values from cudagraph_capture_sizes." "Use values from cudagraph_capture_sizes."
) )
def _get_lora_cases(self) -> list[int]:
    """Return the LoRA cases to capture CUDA graphs for.

    This is meant to be the single source of truth for LoRA capture
    cases. Each entry is an integer count; ``0`` stands for "no LoRA
    active".

    Returns:
        - ``[0]`` when no LoRA config is set.
        - ``[0] + get_captured_lora_counts(max_loras, specialize_lora_count)``
          when ``cudagraph_specialize_lora`` is enabled, so graphs are
          captured both without LoRA and for each returned count.
        - ``[max_loras + 1]`` otherwise, i.e. a single LoRA-active case.
    """
    lora_cfg = self.vllm_config.lora_config

    # No LoRA configured: the only case is "no LoRA".
    if lora_cfg is None:
        return [0]

    # LoRA enabled without specialization: capture only the LoRA-active case.
    if not self.compilation_config.cudagraph_specialize_lora:
        return [lora_cfg.max_loras + 1]

    # Specialization: separate graphs for no-LoRA and each captured count.
    specialized_counts = get_captured_lora_counts(
        lora_cfg.max_loras, self.specialize_lora_count
    )
    return [0] + specialized_counts
def _create_padded_batch_descriptor( def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
) -> BatchDescriptor: ) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len uniform_decode_query_len = self.uniform_decode_query_len
...@@ -111,6 +148,7 @@ class CudagraphDispatcher: ...@@ -111,6 +148,7 @@ class CudagraphDispatcher:
num_reqs=num_reqs, num_reqs=num_reqs,
uniform=uniform_decode, uniform=uniform_decode,
has_lora=has_lora, has_lora=has_lora,
num_active_loras=num_active_loras,
) )
def add_cudagraph_key( def add_cudagraph_key(
...@@ -143,18 +181,27 @@ class CudagraphDispatcher: ...@@ -143,18 +181,27 @@ class CudagraphDispatcher:
lora_cases = [True] lora_cases = [True]
else: else:
lora_cases = [False] lora_cases = [False]
# Get LoRA cases to capture
# lora_cases = self._get_lora_cases()
self.captured_lora_counts = [
lora_count for lora_count in lora_cases if lora_count
]
# Note: we create all valid keys for cudagraph here but do not # Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy # guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered. # capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs, has_lora in product( assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when mixed mode is enabled."
)
for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases self.compilation_config.cudagraph_capture_sizes, lora_cases
): ):
self.add_cudagraph_key( self.add_cudagraph_key(
cudagraph_mode.mixed_mode(), cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor( self._create_padded_batch_descriptor(
bs, False, has_lora bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(), ).relax_for_mixed_batch_cudagraphs(),
) )
...@@ -168,15 +215,20 @@ class CudagraphDispatcher: ...@@ -168,15 +215,20 @@ class CudagraphDispatcher:
uniform_decode_query_len uniform_decode_query_len
* self.vllm_config.scheduler_config.max_num_seqs * self.vllm_config.scheduler_config.max_num_seqs
) )
assert self.compilation_config.cudagraph_capture_sizes is not None, (
"Cudagraph capture sizes must be set when full mode is enabled."
)
cudagraph_capture_sizes_for_decode = [ cudagraph_capture_sizes_for_decode = [
x x
for x in self.compilation_config.cudagraph_capture_sizes for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len if x <= max_num_tokens and x >= uniform_decode_query_len
] ]
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): for bs, num_active_loras in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key( self.add_cudagraph_key(
CUDAGraphMode.FULL, CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(bs, True, has_lora), self._create_padded_batch_descriptor(
bs, True, num_active_loras > 0, num_active_loras
),
) )
self.keys_initialized = True self.keys_initialized = True
...@@ -199,14 +251,19 @@ class CudagraphDispatcher: ...@@ -199,14 +251,19 @@ class CudagraphDispatcher:
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len). length is uniform_decode_query_len).
has_lora: Whether LoRA is active. has_lora: Whether LoRA is active.
valid_modes: Set of cudagraph modes that are allowed. None means
all modes are allowed.
disable_full: If True, skip FULL cudagraph checks and disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs) cascade attention that are not supported by full cudagraphs)
""" """
# allowed_modes = valid_modes or CUDAGraphMode.valid_runtime_modes()
if ( if (
not self.keys_initialized not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size or num_tokens > self.compilation_config.max_cudagraph_capture_size
# or allowed_modes <= {CUDAGraphMode.NONE}
): ):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import ast import ast
from dataclasses import replace from dataclasses import replace
from importlib.util import find_spec from importlib.util import find_spec
from typing import Any, cast
import numpy as np import numpy as np
import torch import torch
...@@ -37,17 +38,21 @@ from vllm.v1.attention.backends.tree_attn import ( ...@@ -37,17 +38,21 @@ from vllm.v1.attention.backends.tree_attn import (
) )
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata, SpecDecodeMetadata
from vllm.v1.spec_decode.utils import ( from vllm.v1.spec_decode.utils import (
extend_all_queries_by_N,
compute_new_slot_mapping,
copy_and_expand_eagle_inputs_kernel,
eagle_prepare_inputs_padded_kernel, eagle_prepare_inputs_padded_kernel,
eagle_prepare_next_token_padded_kernel, eagle_prepare_next_token_padded_kernel,
) )
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -75,11 +80,33 @@ class SpecDecodeBaseProposer: ...@@ -75,11 +80,33 @@ class SpecDecodeBaseProposer:
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.enable_multi_layers_mtp = self.speculative_config.enable_multi_layers_mtp
self.layer_num = 1
# Unifying eagle, draft model, and parallel drafting support
self.parallel_drafting: bool = self.speculative_config.parallel_drafting
self.extra_slots_per_request = (
1 if not self.parallel_drafting else self.num_speculative_tokens
)
self.net_num_new_slots_per_request = self.extra_slots_per_request - (
1 if self.pass_hidden_states_to_model else 0
)
self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0
self.parallel_drafting_token_id: int = 0
self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
if self.parallel_drafting:
self._init_parallel_drafting_params()
self.use_local_argmax_reduction: bool = (
self.speculative_config.use_local_argmax_reduction
)
# The drafter can get longer sequences than the target model. # The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = ( # self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size # vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
) # )
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens) self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because # We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's # the draft model's hidden size can be different from the target model's
...@@ -93,6 +120,9 @@ class SpecDecodeBaseProposer: ...@@ -93,6 +120,9 @@ class SpecDecodeBaseProposer:
vllm_config.model_config vllm_config.model_config
) )
self.draft_attn_groups: list[AttentionGroup] = []
self.kv_cache_gid: int = -1
self.attn_metadata_builder: AttentionMetadataBuilder | None = None self.attn_metadata_builder: AttentionMetadataBuilder | None = None
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
self.attn_layer_names: list[str] = [] self.attn_layer_names: list[str] = []
...@@ -116,6 +146,8 @@ class SpecDecodeBaseProposer: ...@@ -116,6 +146,8 @@ class SpecDecodeBaseProposer:
# Use draft model's M-RoPE setting, not target model's # Use draft model's M-RoPE setting, not target model's
# Draft models may be text-only even if target is multimodal # Draft models may be text-only even if target is multimodal
self.uses_mrope = self.draft_model_config.uses_mrope self.uses_mrope = self.draft_model_config.uses_mrope
self.uses_xdrope_dim = self.vllm_config.model_config.uses_xdrope_dim
self.draft_uses_xdrope_dim = self.draft_model_config.uses_xdrope_dim
if self.uses_mrope: if self.uses_mrope:
# NOTE: `mrope_positions` is implemented with one additional dummy # NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work # position on purpose to make it non-contiguous so that it can work
...@@ -139,6 +171,9 @@ class SpecDecodeBaseProposer: ...@@ -139,6 +171,9 @@ class SpecDecodeBaseProposer:
(self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device
) )
# Will be set when we initialize the attention backend
# self.block_size: int = -1
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
...@@ -146,6 +181,26 @@ class SpecDecodeBaseProposer: ...@@ -146,6 +181,26 @@ class SpecDecodeBaseProposer:
max_num_slots_for_arange, device=device, dtype=torch.int32 max_num_slots_for_arange, device=device, dtype=torch.int32
) )
if self.needs_extra_input_slots:
self._raise_if_padded_drafter_batch_disabled()
self._raise_if_multimodal()
self._raise_if_mrope()
self.is_rejected_token_mask: torch.Tensor | None = None
self.is_masked_token_mask: torch.Tensor | None = None
if self.needs_extra_input_slots:
# For draft models and parallel drafting, we need to keep track of
# which tokens are rejected to update the slot mapping with padding slots.
self.is_rejected_token_mask = torch.zeros(
(self.max_num_tokens,), dtype=torch.bool, device=device
)
# For parallel drafting, we also need to keep track of which tokens
# are parallel-padding tokens used to sample at later positions.
# We populate this tensor even when using draft models for simplicity.
self.is_masked_token_mask = torch.zeros(
(self.max_num_tokens,), dtype=torch.bool, device=device
)
self.inputs_embeds = torch.zeros( self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.inputs_embeds_size), (self.max_num_tokens, self.inputs_embeds_size),
dtype=self.dtype, dtype=self.dtype,
...@@ -166,36 +221,6 @@ class SpecDecodeBaseProposer: ...@@ -166,36 +221,6 @@ class SpecDecodeBaseProposer:
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple | None = None self.allowed_attn_types: tuple | None = None
# if current_platform.is_rocm():
# from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata
# rocm_types = [
# TritonAttentionMetadata,
# RocmAttentionMetadata,
# ]
# # ROCM_AITER_FA is an optional backend
# if find_spec(
# AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
# ):
# from vllm.v1.attention.backends.rocm_aiter_fa import (
# AiterFlashAttentionMetadata,
# )
# rocm_types.append(AiterFlashAttentionMetadata)
# # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
# from vllm.model_executor.layers.attention.mla_attention import (
# MLACommonMetadata,
# )
# rocm_types.append(MLACommonMetadata)
# # FlexAttention backend support
# from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata
# rocm_types.append(FlexAttentionMetadata)
# self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree. # Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree spec_token_tree = self.speculative_config.speculative_token_tree
...@@ -251,7 +276,8 @@ class SpecDecodeBaseProposer: ...@@ -251,7 +276,8 @@ class SpecDecodeBaseProposer:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID) self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens] view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names + self.indexer_layer_names} # return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
return {name: view for name in self._draft_attn_layer_names}
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle. """Initialize cudagraph dispatcher keys for eagle.
...@@ -270,6 +296,23 @@ class SpecDecodeBaseProposer: ...@@ -270,6 +296,23 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def adjust_input(
    self,
    batch_size: int,
    target_token_ids: torch.Tensor,
    target_positions: torch.Tensor,
    target_hidden_states: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
    multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
    """Return the drafting inputs exactly as they were received.

    `propose` calls this hook before `set_inputs_first_pass` and rebinds
    its token ids, positions, hidden states and common attention metadata
    to the values returned here. This base implementation is a no-op:
    `batch_size`, `token_indices_to_sample` and
    `multi_layer_eagle_metadata` are accepted but not read.

    NOTE(review): presumably subclasses (e.g. a multi-layer MTP proposer)
    override this to reshape the inputs -- confirm against the subclasses.

    Returns:
        A 4-tuple `(target_token_ids, target_positions,
        target_hidden_states, common_attn_metadata)` holding the very same
        objects that were passed in.
    """
    # No copies are made; callers get back the identical objects.
    passthrough = (
        target_token_ids,
        target_positions,
        target_hidden_states,
        common_attn_metadata,
    )
    return passthrough
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
...@@ -280,9 +323,10 @@ class SpecDecodeBaseProposer: ...@@ -280,9 +323,10 @@ class SpecDecodeBaseProposer:
target_hidden_states: torch.Tensor, target_hidden_states: torch.Tensor,
# [batch_size] # [batch_size]
next_token_ids: torch.Tensor, next_token_ids: torch.Tensor,
last_token_indices: torch.Tensor | None, token_indices_to_sample: torch.Tensor | None,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor] slot_mappings: dict[str, torch.Tensor]
...@@ -298,12 +342,28 @@ class SpecDecodeBaseProposer: ...@@ -298,12 +342,28 @@ class SpecDecodeBaseProposer:
) )
assert target_hidden_states.shape[-1] == self.hidden_size assert target_hidden_states.shape[-1] == self.hidden_size
num_tokens, last_token_indices, common_attn_metadata = ( (
target_token_ids,
target_positions,
target_hidden_states,
common_attn_metadata,
) = self.adjust_input(
batch_size=batch_size,
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
)
num_tokens, token_indices_to_sample, common_attn_metadata = (
self.set_inputs_first_pass( self.set_inputs_first_pass(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
target_positions=target_positions, target_positions=target_positions,
last_token_indices=last_token_indices, target_hidden_states=target_hidden_states,
token_indices_to_sample=token_indices_to_sample,
cad=common_attn_metadata, cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu, num_rejected_tokens_gpu=num_rejected_tokens_gpu,
) )
...@@ -355,76 +415,112 @@ class SpecDecodeBaseProposer: ...@@ -355,76 +415,112 @@ class SpecDecodeBaseProposer:
# hidden dims. E.g. large target model and small draft model. # hidden dims. E.g. large target model and small draft model.
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
if self.supports_mm_inputs: ###### step3.5-mtp3新增
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) draft_token_ids_list = []
for spec_step_idx in range(self.layer_num):
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
self.inputs_embeds[:num_tokens] = self.model.embed_input_ids( self.inputs_embeds[:num_tokens] = self.model.embed_input_ids(
self.input_ids[:num_tokens], self.input_ids[:num_tokens],
multimodal_embeddings=mm_embeds, multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed, is_multimodal=is_mm_embed,
) )
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_input_tokens] inputs_embeds = self.inputs_embeds[:num_input_tokens]
else:
input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
model_kwargs = {
"input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if self.pass_hidden_states_to_model:
model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else: else:
last_hidden_states, hidden_states = ret_hidden_states input_ids = self.input_ids[:num_input_tokens]
inputs_embeds = None
sample_hidden_states = last_hidden_states[last_token_indices] model_kwargs = {
logits = self.model.compute_logits(sample_hidden_states) "input_ids": input_ids,
"positions": self._get_positions(num_input_tokens),
"inputs_embeds": inputs_embeds,
}
if envs.VLLM_REJECT_SAMPLE_OPT: if self.pass_hidden_states_to_model:
draft_prob = logits.softmax(dim=-1, dtype=torch.float32) model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
if self.enable_multi_layers_mtp:
model_kwargs["spec_step_idx"] = spec_step_idx
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
last_hidden_states = ret_hidden_states
hidden_states = last_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[token_indices_to_sample]
if self.enable_multi_layers_mtp:
logits = self.model.compute_logits(
sample_hidden_states, spec_step_idx=spec_step_idx
)
else:
logits = self.model.compute_logits(sample_hidden_states)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
if envs.VLLM_REJECT_SAMPLE_OPT: # Generate the remaining draft tokens.
return draft_token_ids.view(-1, 1), draft_prob.view(-1, 1, logits.shape[-1]) draft_token_ids_list.append(draft_token_ids)
return draft_token_ids.view(-1, 1) if spec_step_idx < self.layer_num - 1:
prev_token_ids = self.input_ids[:num_tokens].clone()
hidden_states = hidden_states[:num_tokens]
next_token_ids = draft_token_ids_list[-1].int()
num_tokens, token_indices_to_sample, common_attn_metadata = (
self.set_inputs_first_pass(
target_token_ids=prev_token_ids,
next_token_ids=next_token_ids,
target_positions=target_positions,
target_hidden_states=hidden_states,
token_indices_to_sample=token_indices_to_sample,
cad=common_attn_metadata,
num_rejected_tokens_gpu=num_rejected_tokens_gpu,
)
)
# Early exit if all draft tokens are generated in one pass
if self.num_speculative_tokens == self.layer_num or self.parallel_drafting:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
##########################################################################
if self.uses_mrope: if self.uses_mrope:
positions = self.mrope_positions[:, last_token_indices] positions = self.mrope_positions[:, token_indices_to_sample]
else: else:
positions = self.positions[last_token_indices] positions = self.positions[token_indices_to_sample]
if self.method in ( if self.method in (
"deepseek_mtp", "deepseek_mtp",
"ernie_mtp", "ernie_mtp",
"longcat_flash_mtp", "longcat_flash_mtp",
"pangu_ultra_moe_mtp", "pangu_ultra_moe_mtp",
"step3p5_mtp", # 新增
): ):
hidden_states = self.hidden_states[last_token_indices] hidden_states = self.hidden_states[token_indices_to_sample]
else: else:
hidden_states = hidden_states[last_token_indices] hidden_states = hidden_states[token_indices_to_sample]
if isinstance(attn_metadata, TreeAttentionMetadata): if isinstance(attn_metadata, TreeAttentionMetadata):
######
if self.enable_multi_layers_mtp:
raise NotImplementedError(
"Speculative Decoding with multi-layer MTP and tree attention "
"is not supported yet."
)
#####
# Draft using tree attention. # Draft using tree attention.
draft_token_ids_list = self.propose_tree( draft_token_ids_list = self.propose_tree(
batch_size=batch_size, batch_size=batch_size,
...@@ -437,32 +533,22 @@ class SpecDecodeBaseProposer: ...@@ -437,32 +533,22 @@ class SpecDecodeBaseProposer:
# [batch_size, num_tree_tokens] # [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1) return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1) # draft_token_ids = logits.argmax(dim=-1)
if self.allowed_attn_types is not None and not isinstance( if self.allowed_attn_types is not None and not isinstance(
attn_metadata, self.allowed_attn_types attn_metadata, self.allowed_attn_types
): ):
raise ValueError( raise ValueError(
f"Unsupported attention metadata type for speculative " f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: " "decoding with num_speculative_tokens > layer_num: "
f"{type(attn_metadata)}. Supported types are: " f"{type(attn_metadata)}. Supported types are: "
f"{self.allowed_attn_types}" f"{self.allowed_attn_types}"
) )
# Generate the remaining draft tokens. cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = (
draft_token_ids_list = [draft_token_ids] self._determine_batch_execution_and_padding(batch_size)
batch_size_dp_padded, batch_size_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=batch_size, num_tokens_padded=batch_size
) )
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
batch_size_dp_padded
)
input_batch_size = batch_desc.num_tokens
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1 common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
...@@ -483,7 +569,7 @@ class SpecDecodeBaseProposer: ...@@ -483,7 +569,7 @@ class SpecDecodeBaseProposer:
if envs.VLLM_REJECT_SAMPLE_OPT: if envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs_list = [draft_prob] draft_probs_list = [draft_prob]
for token_index in range(self.num_speculative_tokens - 1): for token_index in range(self.num_speculative_tokens - self.layer_num):
# Update the inputs. # Update the inputs.
# cast to int32 is crucial when eagle model is compiled. # cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default. # tensor.argmax() returns int64 by default.
...@@ -562,23 +648,9 @@ class SpecDecodeBaseProposer: ...@@ -562,23 +648,9 @@ class SpecDecodeBaseProposer:
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
) )
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = (
self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1,
)
)
else:
draft_indexer_metadata = None
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
# copy inputs to buffer for cudagraph # copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids self.input_ids[:batch_size] = input_ids
self._set_positions(batch_size, clamped_positions) self._set_positions(batch_size, clamped_positions)
...@@ -641,25 +713,138 @@ class SpecDecodeBaseProposer: ...@@ -641,25 +713,138 @@ class SpecDecodeBaseProposer:
target_token_ids: torch.Tensor, target_token_ids: torch.Tensor,
next_token_ids: torch.Tensor, next_token_ids: torch.Tensor,
target_positions: torch.Tensor, target_positions: torch.Tensor,
last_token_indices: torch.Tensor | None, target_hidden_states: torch.Tensor,
token_indices_to_sample: torch.Tensor | None,
cad: CommonAttentionMetadata, cad: CommonAttentionMetadata,
num_rejected_tokens_gpu: torch.Tensor | None, num_rejected_tokens_gpu: torch.Tensor | None,
) -> tuple[int, torch.Tensor, CommonAttentionMetadata]: ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
if last_token_indices is None: if not self.needs_extra_input_slots:
last_token_indices = cad.query_start_loc[1:] - 1 # Default EAGLE pathway: no reshaping of input tensors needed.
# Simply rotate the input ids and leave the positions unchanged,
# Inserting the next token ids at the last slot in each request.
if token_indices_to_sample is None:
token_indices_to_sample = cad.query_start_loc[1:] - 1
num_tokens = target_token_ids.shape[0]
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[token_indices_to_sample] = next_token_ids
# copy inputs to buffer for cudagraph
if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
target_positions = target_positions[0]
self._set_positions(num_tokens, target_positions)
self.hidden_states[:num_tokens] = target_hidden_states
return num_tokens, token_indices_to_sample, cad
else:
assert self.is_rejected_token_mask is not None
assert self.is_masked_token_mask is not None
# 1.
# Call a custom triton kernel to copy input_ids and positions
# into the correct slots in the preallocated buffers self.input_ids,
# self.positions.
batch_size = cad.batch_size()
# Since we might have to copy a lot of data for prefills, we select the
# block size based on the max query length and limit to max 256 slots/block.
max_num_tokens_per_request = (
cad.max_query_len + self.net_num_new_slots_per_request
)
BLOCK_SIZE_TOKENS = min(
256, triton.next_power_of_2(max_num_tokens_per_request)
)
num_blocks = (
max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1
) // BLOCK_SIZE_TOKENS
total_num_input_tokens = target_token_ids.shape[0]
total_num_output_tokens = total_num_input_tokens + (
self.net_num_new_slots_per_request * batch_size
)
token_indices_to_sample = torch.empty(
batch_size * self.extra_slots_per_request,
dtype=torch.int32,
device=self.device,
)
# Destination indices to write target_hidden_states into drafting buffer.
out_hidden_state_mapping = torch.empty(
total_num_input_tokens, dtype=torch.int32, device=self.device
)
# Kernel grid: one program per request (row)
grid = (batch_size, num_blocks)
query_start_loc = cad.query_start_loc
query_end_loc = cad.query_start_loc[1:] - 1
if num_rejected_tokens_gpu is not None:
query_end_loc = query_end_loc - num_rejected_tokens_gpu
copy_and_expand_eagle_inputs_kernel[grid](
# (Padded) Inputs from the target model
target_token_ids_ptr=target_token_ids,
target_positions_ptr=target_positions,
next_token_ids_ptr=next_token_ids, # sampled tokens, one per request
# Outputs to the drafting buffers
out_input_ids_ptr=self.input_ids,
out_positions_ptr=self.positions, # Doesn't support mrope for now
out_is_rejected_token_mask_ptr=self.is_rejected_token_mask,
out_is_masked_token_mask_ptr=self.is_masked_token_mask,
out_new_token_indices_ptr=token_indices_to_sample,
out_hidden_state_mapping_ptr=out_hidden_state_mapping,
# Input metadata
query_start_loc_ptr=query_start_loc,
query_end_loc_ptr=query_end_loc,
padding_token_id=0,
parallel_drafting_token_id=self.parallel_drafting_token_id,
# Sizing info
# Note that we can deduce batch_size for free from the grid size
total_input_tokens=total_num_input_tokens,
num_padding_slots_per_request=self.extra_slots_per_request,
shift_input_ids=self.pass_hidden_states_to_model,
BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
)
if self.pass_hidden_states_to_model:
assert self.parallel_drafting_hidden_state_tensor is not None
self.hidden_states[out_hidden_state_mapping] = target_hidden_states
# Use torch.where to avoid DtoH sync from boolean indexing
mask = self.is_masked_token_mask[:total_num_output_tokens]
torch.where(
mask.unsqueeze(1),
self.parallel_drafting_hidden_state_tensor,
self.hidden_states[:total_num_output_tokens],
out=self.hidden_states[:total_num_output_tokens],
)
num_tokens = target_token_ids.shape[0] # 2.
# Shift the input ids by one token. # Recompute the slot mapping based on the new positions and
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] # rejection mask.
self.input_ids[: num_tokens - 1] = target_token_ids[1:] # Use the first draft attention group's kv_cache_spec for block_size
# Replace the last token with the next token. # (all draft layers share the same kv-cache group)
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] assert len(self.draft_attn_groups) > 0
self.input_ids[last_token_indices] = next_token_ids block_size = self.draft_attn_groups[0].kv_cache_spec.block_size
new_slot_mapping = compute_new_slot_mapping(
cad=cad,
new_positions=self.positions[:total_num_output_tokens],
is_rejected_token_mask=self.is_rejected_token_mask[
:total_num_output_tokens
],
block_size=block_size,
num_new_tokens=self.net_num_new_slots_per_request,
max_model_len=self.max_model_len,
)
# copy inputs to buffer for cudagraph # 3. Update the common attention metadata with the new (meta)data
self._set_positions(num_tokens, target_positions) new_cad = extend_all_queries_by_N(
cad,
N=self.net_num_new_slots_per_request,
arange=self.arange,
new_slot_mapping=new_slot_mapping,
)
return num_tokens, last_token_indices, cad return total_num_output_tokens, token_indices_to_sample, new_cad
def model_returns_tuple(self) -> bool: def model_returns_tuple(self) -> bool:
return self.method not in ("mtp", "draft_model") return self.method not in ("mtp", "draft_model")
...@@ -1096,10 +1281,28 @@ class SpecDecodeBaseProposer: ...@@ -1096,10 +1281,28 @@ class SpecDecodeBaseProposer:
model = model.module model = model.module
return model.__class__.__name__ return model.__class__.__name__
def _get_model(self) -> nn.Module:
    """Load and return the draft model.

    This is the default loader: it calls `get_model()` with this
    proposer's `vllm_config` and the speculative config's
    `draft_model_config`, inside a `set_model_tag("eagle_head")` context.
    Subclasses that need customized model loading can override it.
    """
    # Imported locally, as in the original implementation.
    from vllm.compilation.backends import set_model_tag

    with set_model_tag("eagle_head"):
        # NOTE: no `load_config` argument is passed here; forwarding
        # `self.speculative_config.draft_load_config` is currently disabled.
        loader_kwargs = {
            "vllm_config": self.vllm_config,
            "model_config": self.speculative_config.draft_model_config,
        }
        return get_model(**loader_kwargs)
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
draft_model_config = self.vllm_config.speculative_config.draft_model_config
target_attn_layer_names = set( target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
).keys()
) )
# FIXME: support hybrid kv for draft model # FIXME: support hybrid kv for draft model
target_indexer_layer_names = set( target_indexer_layer_names = set(
...@@ -1107,23 +1310,26 @@ class SpecDecodeBaseProposer: ...@@ -1107,23 +1310,26 @@ class SpecDecodeBaseProposer:
self.vllm_config, DeepseekV32IndexerCache self.vllm_config, DeepseekV32IndexerCache
).keys() ).keys()
) )
self.model = self._get_model()
from vllm.compilation.backends import set_model_tag
# Find draft layers (attention layers added by draft model)
with set_model_tag("eagle_head"): # all_attn_layers = get_layers_from_vllm_config(
self.model = get_model( # self.vllm_config,
vllm_config=self.vllm_config, model_config=draft_model_config # AttentionLayerBase, # type: ignore[type-abstract]
) # )
# self._draft_attn_layer_names = (
draft_attn_layer_names = ( # set(all_attn_layers.keys()) - target_attn_layer_names
# )
self._draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys() get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
- target_attn_layer_names - target_attn_layer_names
) )
indexer_layers = get_layers_from_vllm_config( indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache self.vllm_config, DeepseekV32IndexerCache
) )
draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names) self.attn_layer_names = list(self._draft_attn_layer_names - draft_indexer_layer_names)
self.indexer_layer_names = list(draft_indexer_layer_names) self.indexer_layer_names = list(draft_indexer_layer_names)
if self.indexer_layer_names: if self.indexer_layer_names:
...@@ -1160,6 +1366,7 @@ class SpecDecodeBaseProposer: ...@@ -1160,6 +1366,7 @@ class SpecDecodeBaseProposer:
"Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration", "Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration",
"HunYuanVLForConditionalGeneration",
"GlmOcrForConditionalGeneration", "GlmOcrForConditionalGeneration",
"Qwen3_5ForConditionalGeneration", "Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration",
...@@ -1177,12 +1384,34 @@ class SpecDecodeBaseProposer: ...@@ -1177,12 +1384,34 @@ class SpecDecodeBaseProposer:
else: else:
target_language_model = target_model target_language_model = target_model
# share embed_tokens with the target model if needed self._maybe_share_embeddings(target_language_model)
self._maybe_share_lm_head(target_language_model)
if self.parallel_drafting and self.pass_hidden_states_to_model:
assert self.parallel_drafting_hidden_state_tensor is not None
self.parallel_drafting_hidden_state_tensor.copy_(
self.model.combine_hidden_states(
self.model.mask_hidden.view(3 * self.hidden_size)
)
if self.eagle3_use_aux_hidden_state
else self.model.mask_hidden.view(self.hidden_size)
)
def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
"""
Some draft models may not have their own embedding layers, and some may
have a duplicate copy of the target model's embedding layers. In these cases,
we share the target model's embedding layers with the draft model to save
memory.
"""
if get_pp_group().world_size == 1: if get_pp_group().world_size == 1:
if hasattr(target_language_model.model, "embed_tokens"): inner_model = getattr(target_language_model, "model", None)
target_embed_tokens = target_language_model.model.embed_tokens if inner_model is None:
elif hasattr(target_language_model.model, "embedding"): raise AttributeError("Target model does not have 'model' attribute")
target_embed_tokens = target_language_model.model.embedding if hasattr(inner_model, "embed_tokens"):
target_embed_tokens = inner_model.embed_tokens
elif hasattr(inner_model, "embedding"):
target_embed_tokens = inner_model.embedding
else: else:
raise AttributeError( raise AttributeError(
"Target model does not have 'embed_tokens' or 'embedding' attribute" "Target model does not have 'embed_tokens' or 'embedding' attribute"
...@@ -1237,7 +1466,12 @@ class SpecDecodeBaseProposer: ...@@ -1237,7 +1466,12 @@ class SpecDecodeBaseProposer:
" from the target model." " from the target model."
) )
# share lm_head with the target model if needed def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
"""
Some draft models may not have their own LM head, and some may have a
duplicate copy of the target model's LM head. In these cases, we share
the target model's LM head with the draft model to save memory.
"""
share_lm_head = False share_lm_head = False
if hasattr(self.model, "has_own_lm_head"): if hasattr(self.model, "has_own_lm_head"):
# EAGLE model # EAGLE model
...@@ -1299,6 +1533,32 @@ class SpecDecodeBaseProposer: ...@@ -1299,6 +1533,32 @@ class SpecDecodeBaseProposer:
"Shared target model lm_head with MTP shared_head.head." "Shared target model lm_head with MTP shared_head.head."
) )
if self.use_local_argmax_reduction:
if not hasattr(self.model, "get_top_tokens"):
raise ValueError(
"use_local_argmax_reduction is enabled but draft model "
f"{self.model.__class__.__name__} does not implement "
"get_top_tokens()."
)
# Warn if draft model has vocab remapping, which forces fallback
# to the full-logits path (negating the optimization).
if (
hasattr(self.model, "draft_id_to_target_id")
and self.model.draft_id_to_target_id is not None
):
logger.warning(
"use_local_argmax_reduction is enabled but draft model "
"uses draft_id_to_target_id vocab remapping. The "
"optimization will be bypassed (falling back to full "
"logits gather + argmax)."
)
else:
logger.info(
"Using local argmax reduction for draft token generation "
"(communication: O(2*tp_size) vs O(vocab_size))."
)
@torch.inference_mode() @torch.inference_mode()
def dummy_run( def dummy_run(
self, self,
...@@ -1329,9 +1589,9 @@ class SpecDecodeBaseProposer: ...@@ -1329,9 +1589,9 @@ class SpecDecodeBaseProposer:
# Make sure to use EAGLE's own buffer during cudagraph capture. # Make sure to use EAGLE's own buffer during cudagraph capture.
if ( if (
self.attn_layer_names self._draft_attn_layer_names
and slot_mappings is not None and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings and next(iter(self._draft_attn_layer_names)) in slot_mappings
): ):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens) slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else: else:
...@@ -1425,6 +1685,64 @@ class SpecDecodeBaseProposer: ...@@ -1425,6 +1685,64 @@ class SpecDecodeBaseProposer:
== 1 == 1
), "All drafting layers should belong to the same kv cache group" ), "All drafting layers should belong to the same kv cache group"
# def initialize_attn_backend(
# self,
# kv_cache_config: KVCacheConfig,
# kernel_block_sizes: list[int] | None = None,
# ) -> None:
# """
# Initialize AttentionGroups for draft layers using kv_cache_config.
# Called from the model runner's initialize_metadata_builders.
# """
# all_attn_layers = get_layers_from_vllm_config(
# self.vllm_config,
# AttentionLayerBase, # type: ignore[type-abstract]
# )
# # Find which kv_cache_group the draft layers belong to
# self.validate_same_kv_cache_group(kv_cache_config)
# kv_cache_spec = None
# for gid, group in enumerate(kv_cache_config.kv_cache_groups):
# if self._draft_attn_layer_names & set(group.layer_names):
# self.kv_cache_gid = gid
# kv_cache_spec = group.kv_cache_spec
# break
# attention_groups: dict[tuple[str, str], AttentionGroup] = {}
# if kv_cache_spec is not None:
# for layer_name in self._draft_attn_layer_names:
# attn_backend = all_attn_layers[layer_name].get_attn_backend()
# backend_key = attn_backend.full_cls_name()
# if backend_key not in attention_groups:
# layer_kv_cache_spec = kv_cache_spec
# if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
# layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
# layer_name
# ]
# kernel_block_size = (
# kernel_block_sizes[self.kv_cache_gid]
# if kernel_block_sizes is not None
# and self.kv_cache_gid < len(kernel_block_sizes)
# else None
# )
# attn_group = AttentionGroup(
# backend=attn_backend,
# layer_names=[layer_name],
# kv_cache_spec=layer_kv_cache_spec,
# kv_cache_group_id=self.kv_cache_gid,
# )
# attn_group.create_metadata_builders(
# self.vllm_config,
# self.device,
# kernel_block_size=kernel_block_size,
# )
# attention_groups[backend_key] = attn_group
# else:
# attention_groups[backend_key].layer_names.append(layer_name)
# self.draft_attn_groups = list(attention_groups.values())
def _pad_batch_across_dp( def _pad_batch_across_dp(
self, self,
num_tokens_unpadded: int, num_tokens_unpadded: int,
...@@ -1449,6 +1767,50 @@ class SpecDecodeBaseProposer: ...@@ -1449,6 +1767,50 @@ class SpecDecodeBaseProposer:
return num_tokens_dp_padded, num_toks_across_dp return num_tokens_dp_padded, num_toks_across_dp
def _determine_batch_execution_and_padding(
    self,
    num_tokens: int,
    use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
    """Choose the CUDA graph mode and padded token count for a drafter pass.

    Args:
        num_tokens: Number of tokens in the batch before any padding.
        use_cudagraphs: When False, the dispatcher is restricted to
            ``CUDAGraphMode.NONE`` so the pass runs eagerly.

    Returns:
        A tuple ``(cudagraph_mode, num_tokens_padded, num_tokens_across_dp)``.
        ``num_tokens_across_dp`` is ``None`` unless data parallelism is
        enabled and the cross-rank coordination returned per-rank counts.
    """
    cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
        num_tokens,
        # Honor the caller's eager request; previously the flag was ignored
        # and a CUDA graph mode could be selected anyway.
        valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
    )
    num_tokens_padded = batch_desc.num_tokens

    # Extra coordination when running data-parallel since we need to
    # coordinate across ranks
    # TODO(Flechman): support DBO ubatching
    should_ubatch, num_tokens_across_dp = False, None
    if self.vllm_config.parallel_config.data_parallel_size > 1:
        should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
            coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=self.vllm_config.parallel_config,
                allow_microbatching=False,
                num_tokens_padded=num_tokens_padded,
                cudagraph_mode=cudagraph_mode.value,
            )
        )
        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"

        # Extract DP-synced values
        if num_tokens_across_dp is not None:
            dp_rank = self.dp_rank
            num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
            # Re-dispatch with DP padding so we have the correct
            # batch_descriptor
            cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
                num_tokens_padded,
                valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
            )
            # Assert to make sure the agreed upon token count is correct
            # otherwise num_tokens_across_dp will no-longer be valid
            assert batch_desc.num_tokens == num_tokens_padded
            num_tokens_across_dp[dp_rank] = num_tokens_padded

    return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
class EagleProposer(SpecDecodeBaseProposer): class EagleProposer(SpecDecodeBaseProposer):
def __init__( def __init__(
self, self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.v1.attention.backend import AttentionMetadataBuilder, CommonAttentionMetadata
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
# Value written into the tail of the slot-mapping buffer for positions beyond
# the real tokens when a batch is padded (see _get_slot_mapping).
PADDING_SLOT_ID = -1
class ExtractHiddenStatesProposer:
def __init__(self, vllm_config: VllmConfig, device):
assert vllm_config.speculative_config is not None
assert vllm_config.speculative_config.num_speculative_tokens == 1
if vllm_config.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"disable_padded_drafter_batch is not supported with "
"extract_hidden_states method"
)
self.vllm_config = vllm_config
self.device = device
self.dtype = vllm_config.model_config.dtype
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
# Model and attention layer tracking (initialized in load_model)
self.model: nn.Module | None = None
self.attn_layer_names: list[str] = []
self.attn_metadata_builder: AttentionMetadataBuilder | None = None
# Maximum number of tokens for buffers
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
layer_ids = getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", None)
if not layer_ids:
raise ValueError(
"eagle_aux_hidden_state_layer_ids must be set in the draft "
"model config for extract_hidden_states method"
)
self.num_hidden_states = len(layer_ids)
self.hidden_size = vllm_config.model_config.get_hidden_size()
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.num_hidden_states, self.hidden_size),
dtype=self.dtype,
device=device,
)
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
)
def propose(
    self,
    sampled_token_ids: torch.Tensor,
    target_hidden_states: list[torch.Tensor],
    common_attn_metadata: CommonAttentionMetadata,
    scheduler_output: SchedulerOutput,
    slot_mappings: dict[str, torch.Tensor]
    | list[dict[str, torch.Tensor]]
    | None = None,
) -> tuple[torch.Tensor, KVConnectorOutput | None]:
    """Run the ExtractHiddenStatesModel and echo the sampled tokens as drafts.

    No real speculation happens here. The model is invoked so that the
    target model's hidden states get written into the KV cache (for later
    use such as KV transfer), and the already-sampled tokens are returned
    as the "draft" tokens so that they always verify.

    Args:
        sampled_token_ids: Token IDs sampled by the target model.
        target_hidden_states: One tensor per aux hidden-state layer, each
            of shape [num_tokens, hidden_size].
        common_attn_metadata: Attention metadata for the batch.
        scheduler_output: Scheduler output, forwarded to the KV connector.
        slot_mappings: Unused; accepted for interface compatibility.

    Returns:
        A pair of:
            - the sampled tokens with a trailing dimension of size 1
              (shape [batch_size, 1]);
            - the KV connector output when a KV transfer group is active,
              otherwise None.
    """
    assert self.model is not None and isinstance(target_hidden_states, list)

    # [num_tokens, hidden_size] per layer -> [num_tokens, num_layers, hidden_size]
    stacked = torch.stack(target_hidden_states, dim=1)
    num_tokens = stacked.shape[0]
    self.hidden_states[:num_tokens] = stacked

    builder = self.attn_metadata_builder
    assert builder is not None
    attn_metadata = builder.build_for_drafting(
        common_attn_metadata=common_attn_metadata, draft_index=0
    )

    # Every cache-only layer is assumed to live in the same KV cache group,
    # so they all share one metadata object.
    per_layer_attn_metadata = dict.fromkeys(self.attn_layer_names, attn_metadata)

    cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
        self._determine_batch_execution_and_padding(num_tokens)
    )
    if num_tokens_across_dp is not None:
        num_tokens_across_dp[self.dp_rank] = num_input_tokens

    with (
        set_forward_context(
            per_layer_attn_metadata,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=self._get_slot_mapping(
                num_input_tokens, common_attn_metadata.slot_mapping
            ),
        ),
        (
            KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output)
            if has_kv_transfer_group()
            else nullcontext()
        ) as kv_connector_output,
    ):
        padded_hidden_states = self.hidden_states[:num_input_tokens]
        self.model(hidden_states=padded_hidden_states)

    # The "draft" is simply the sampled token, reshaped to [batch_size, 1]
    # to match num_speculative_tokens == 1.
    draft_token_ids = sampled_token_ids.unsqueeze(-1)
    return draft_token_ids, kv_connector_output
def _get_slot_mapping(
self,
num_tokens: int,
slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return slot_mapping dict for cache-only attention layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if slot_mapping is not None:
num_actual = slot_mapping.shape[0]
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
if num_tokens > num_actual:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names}
def _determine_batch_execution_and_padding(
self,
num_tokens: int,
use_cudagraphs: bool = True,
) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens,
valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None),
)
num_tokens_padded = batch_desc.num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch, num_tokens_across_dp = False, None
if self.vllm_config.parallel_config.data_parallel_size > 1:
should_ubatch, num_tokens_across_dp, synced_cudagraph_mode = (
coordinate_batch_across_dp(
num_tokens_unpadded=num_tokens,
parallel_config=self.vllm_config.parallel_config,
allow_microbatching=False,
num_tokens_padded=num_tokens_padded,
cudagraph_mode=cudagraph_mode.value,
)
)
assert not should_ubatch, (
"DBO ubatching not implemented for extract_hidden_states"
)
# Extract DP-synced values
if num_tokens_across_dp is not None:
dp_rank = self.dp_rank
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_padded,
valid_modes={CUDAGraphMode(synced_cudagraph_mode)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert batch_desc.num_tokens == num_tokens_padded
num_tokens_across_dp[dp_rank] = num_tokens_padded
return cudagraph_mode, num_tokens_padded, num_tokens_across_dp
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
    """Register the proposer's cudagraph keys with its dispatcher.

    The proposer only ever captures PIECEWISE graphs: PIECEWISE is chosen
    when eager execution is not enforced and the given mode's mixed mode
    is PIECEWISE or FULL; otherwise NONE is used.

    Should be called after adjust_cudagraph_sizes_for_spec_decode.
    """
    spec_config = self.vllm_config.speculative_config
    assert spec_config is not None

    graphs_allowed = (
        not spec_config.enforce_eager
        and cudagraph_mode.mixed_mode()
        in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
    )
    proposer_cudagraph_mode = (
        CUDAGraphMode.PIECEWISE if graphs_allowed else CUDAGraphMode.NONE
    )
    self.cudagraph_dispatcher.initialize_cudagraph_keys(proposer_cudagraph_mode)
@torch.inference_mode()
def dummy_run(
    self,
    num_tokens: int,
    use_cudagraphs: bool = True,
    is_graph_capturing: bool = False,
    slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
    """Run one forward pass of the model on the persistent hidden-state buffer.

    Args:
        num_tokens: Unpadded token count for the dummy batch.
        use_cudagraphs: Forwarded to the batch-execution decision.
        is_graph_capturing: Not used by this method.
        slot_mappings: Slot mappings supplied by the caller. When they
            contain this proposer's first attention layer, the proposer's
            own buffer is substituted.
    """
    assert self.model is not None, "Model must be initialized before dummy_run"

    runtime_mode, padded_tokens, tokens_across_dp = (
        self._determine_batch_execution_and_padding(
            num_tokens, use_cudagraphs=use_cudagraphs
        )
    )
    if tokens_across_dp is not None:
        tokens_across_dp[self.dp_rank] = padded_tokens

    # During cudagraph capture the proposer must use its own slot-mapping
    # buffer rather than the caller's tensors.
    layer_names = self.attn_layer_names
    if (
        not layer_names
        or slot_mappings is None
        or layer_names[0] not in slot_mappings
    ):
        slot_mapping_dict = slot_mappings or {}
    else:
        slot_mapping_dict = self._get_slot_mapping(padded_tokens)

    with set_forward_context(
        None,
        self.vllm_config,
        num_tokens=padded_tokens,
        num_tokens_across_dp=tokens_across_dp,
        cudagraph_runtime_mode=runtime_mode,
        slot_mapping=slot_mapping_dict,
    ):
        self.model(hidden_states=self.hidden_states[:padded_tokens])
def _build_attn_metadata_builder(
self, draft_attn_layers: dict[str, AttentionLayerBase]
) -> AttentionMetadataBuilder:
"""Build the attention metadata builder from draft attention layers."""
if not draft_attn_layers:
raise ValueError("No attention layers found for ExtractHiddenStatesModel")
layer = next(iter(draft_attn_layers.values()))
attn_backend = layer.get_attn_backend()
return attn_backend.get_builder_cls()(
layer.get_kv_cache_spec(self.vllm_config),
self.attn_layer_names,
self.vllm_config,
self.device,
)
def prepare_next_token_ids_padded(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    sampled_token_ids: torch.Tensor,
    requests: dict[str, CachedRequestState],
    gpu_input_batch: InputBatch,
    discard_request_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the next token ID for every request in the batch.

    With a single speculative token, ``sampled_token_ids`` has shape
    (batch_size, 1). A request keeps its sampled token when that token
    lies in ``[0, vocab_size)`` and the request is not discarded; otherwise
    a backup token looked up from the request state is used.

    Returns:
        ``(next_token_ids, valid_sampled_tokens_count)``, both int32. The
        count is 1 where the sampled token is inside the vocabulary and 0
        elsewhere.
    """
    num_reqs = gpu_input_batch.num_reqs
    req_ids = gpu_input_batch.req_ids
    seq_lens_cpu = common_attn_metadata.seq_lens_cpu

    # Fallback token per request, taken from the cached request state.
    fallback_ids = []
    for row in range(num_reqs):
        request_state = requests[req_ids[row]]
        fallback_ids.append(request_state.get_token_id(seq_lens_cpu[row].item()))
    backup_tokens_gpu = torch.tensor(
        fallback_ids, dtype=torch.int32, device=sampled_token_ids.device
    )

    assert discard_request_mask.dtype == torch.bool

    # Exactly one sampled token per request.
    sampled = sampled_token_ids[:, 0]
    in_vocab = (sampled >= 0) & (sampled < gpu_input_batch.vocab_size)
    keep_sampled = in_vocab & ~discard_request_mask[:num_reqs]

    next_token_ids = torch.where(
        keep_sampled, sampled.to(torch.int32), backup_tokens_gpu
    )
    return next_token_ids, in_vocab.to(torch.int32)
def load_model(self, target_model: nn.Module) -> None:
    """Instantiate the ExtractHiddenStatesModel and locate its attention layer.

    The attention layers registered before the draft model is loaded are
    treated as the target model's; whatever is new afterwards belongs to
    the draft model. Exactly one such layer is expected.

    Args:
        target_model: Accepted for compatibility with the EagleProposer
            interface; not used here.
    """
    # Snapshot the layer names that exist before the draft model is built.
    layers_before = get_layers_from_vllm_config(
        self.vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    target_attn_layer_names = set(layers_before.keys())

    spec_config = self.vllm_config.speculative_config
    assert spec_config is not None
    draft_model_config = spec_config.draft_model_config

    from vllm.compilation.backends import set_model_tag

    with set_model_tag("extract_hidden_states"):
        self.model = get_model(
            vllm_config=self.vllm_config, model_config=draft_model_config
        )

    # Anything registered since the snapshot was added by the draft model.
    layers_after = get_layers_from_vllm_config(
        self.vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    draft_attn_layers = {}
    for layer_name, layer in layers_after.items():
        if layer_name not in target_attn_layer_names:
            draft_attn_layers[layer_name] = layer
    self.attn_layer_names = list(draft_attn_layers)

    assert len(draft_attn_layers) == 1, (
        "ExtractHiddenStatesModel should have exactly one "
        f"attention layer, found {len(draft_attn_layers)}"
    )

    self.attn_metadata_builder = self._build_attn_metadata_builder(
        draft_attn_layers
    )
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
"""Validate all drafting layers belong to the same KV cache group.
With exactly one attention layer (asserted in load_model), this is
trivially satisfied.
"""
assert len(self.attn_layer_names) == 1
...@@ -67,3 +67,41 @@ class SpecDecodeMetadata: ...@@ -67,3 +67,41 @@ class SpecDecodeMetadata:
bonus_logits_indices=bonus_logits_indices, bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices, logits_indices=logits_indices,
) )
@dataclass
class MultiLayerEagleMetadata:
# [batch_size]
cached_len: torch.Tensor | None = None
# [batch_size, layer_num]
cached_token_ids: torch.Tensor | None = None
# [batch_size, layer_num, hidden_size]
cached_hidden_states: torch.Tensor | None = None
# [batch_size, layer_num]
cached_slot_mappings: torch.Tensor | None = None
# [batch_size, layer_num]
cached_positions: torch.Tensor | None = None
@classmethod
def make_dummy(
cls,
layer_num: int,
hidden_size: int,
device: torch.device,
) -> "MultiLayerEagleMetadata":
cached_len = torch.zeros((1), dtype=torch.int64, device=device)
cached_token_ids = torch.zeros((1, layer_num), dtype=torch.int32, device=device)
cached_hidden_states = torch.zeros(
(1, layer_num, hidden_size), dtype=torch.float32, device=device
)
cached_slot_mappings = torch.zeros(
(1, layer_num), dtype=torch.int64, device=device
)
cached_positions = torch.zeros((1, layer_num), dtype=torch.int64, device=device)
return cls(
cached_len=cached_len,
cached_token_ids=cached_token_ids,
cached_hidden_states=cached_hidden_states,
cached_slot_mappings=cached_slot_mappings,
cached_positions=cached_positions,
)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
CommonAttentionMetadata,
)
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import MultiLayerEagleMetadata
logger = init_logger(__name__)
# Tile width along the hidden dimension: the hidden-state kernel is launched
# with ceil(hidden_size / BLOCK_HIDDEN) blocks on that axis.
BLOCK_HIDDEN = 128
# Tile width along the token dimension: kernels are launched with
# ceil(max_window_len / BLOCK_TOKENS) blocks per request.
BLOCK_TOKENS = 128
class MultiLayerEagleProposer(EagleProposer):
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    runner=None,
):
    """Initialise the base proposer and read the multi-layer settings.

    ``layer_num`` comes from the draft model's ``hf_text_config.n_predict``
    and falls back to 0 when that attribute is absent.
    """
    super().__init__(vllm_config, device, runner)

    spec_config = self.speculative_config
    draft_text_config = spec_config.draft_model_config.hf_text_config
    self.layer_num: int = getattr(draft_text_config, "n_predict", 0)
    self.num_speculative_tokens: int = spec_config.num_speculative_tokens
def adjust_input(
    self,
    batch_size: int,
    target_token_ids: torch.Tensor,
    target_positions: torch.Tensor,
    target_hidden_states: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
    multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
    """Right-shift each request's inputs and refill the prefix from the cache.

    Copies of the token ids, positions and hidden states are shifted per
    request by the shift-and-cache helper. ``token_indices_to_sample`` and
    ``common_attn_metadata.seq_lens`` are updated in place, as are the slot
    mapping and the cache tensors inside ``multi_layer_eagle_metadata``.

    Args:
        batch_size: Number of requests in the batch.
        target_token_ids: Packed token ids for the batch.
        target_positions: Packed positions; when 2-D, row 0 is used to
            compute the shift.
        target_hidden_states: Packed hidden states for the batch.
        token_indices_to_sample: Index of the token to sample per request;
            when None, the last token of each request is used.
        common_attn_metadata: Attention metadata (mutated in place).
        multi_layer_eagle_metadata: Cache state; must not be None.

    Returns:
        ``(token_ids, positions, hidden_states, common_attn_metadata)``
        where the first three are the shifted copies.
    """
    assert multi_layer_eagle_metadata is not None

    query_start_loc = common_attn_metadata.query_start_loc
    if token_indices_to_sample is None:
        token_indices_to_sample = query_start_loc[1:] - 1

    MAX_SHIFT = self.layer_num
    assert MAX_SHIFT > 0

    # Destination buffers; the originals remain the read-only sources.
    prev_token_ids = target_token_ids.clone()
    prev_positions = target_positions.clone()
    prev_hidden_states = target_hidden_states.clone()
    slot_mapping = common_attn_metadata.slot_mapping

    # Inclusive [start, end] window of each request in the packed batch.
    start_token_indices = query_start_loc[:-1]
    end_token_indices = query_start_loc[1:] - 1

    if target_positions.dim() == 2:
        pos_for_shift = target_positions[0]
    else:
        pos_for_shift = target_positions
    start_token_pos = pos_for_shift[start_token_indices]

    # Shift is bounded by the tokens after the sample index and by the
    # position of the window's first token, and is never negative.
    shift = torch.minimum(
        end_token_indices - token_indices_to_sample,
        start_token_pos,
    ).clamp(min=0)

    # In-place metadata updates use the shift BEFORE it is capped by the
    # cache length (this mirrors the reference implementation).
    token_indices_to_sample.add_(shift)
    common_attn_metadata.seq_lens.sub_(shift)

    cached_lens = multi_layer_eagle_metadata.cached_len
    shift = torch.minimum(shift, cached_lens)

    _multi_layer_eagle_shift_and_cache(
        batch_size=batch_size,
        max_shift=MAX_SHIFT,
        src_token_ids=target_token_ids,
        dst_token_ids=prev_token_ids,
        src_positions=target_positions,
        dst_positions=prev_positions,
        src_hidden_states=target_hidden_states,
        dst_hidden_states=prev_hidden_states,
        src_slot_mapping=slot_mapping,
        dst_slot_mapping=slot_mapping,
        start_token_indices=start_token_indices,
        end_token_indices=end_token_indices,
        token_indices_to_sample=token_indices_to_sample,
        shift=shift,
        cached_lens=cached_lens,
        cached_prev_token_ids=multi_layer_eagle_metadata.cached_token_ids,
        cached_prev_positions=multi_layer_eagle_metadata.cached_positions,
        cached_prev_hidden_states=multi_layer_eagle_metadata.cached_hidden_states,
        cached_slot_mappings=multi_layer_eagle_metadata.cached_slot_mappings,
        common_attn_metadata=common_attn_metadata,
    )

    return prev_token_ids, prev_positions, prev_hidden_states, common_attn_metadata
def prepare_inputs(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
    """Unsupported entry point for the non-padded drafter batch path.

    This signature is the one used to prepare speculative-decoding inputs
    when ``speculative_config.disable_padded_drafter_batch`` is set.
    MultiLayerEagleProposer does not support that mode, so the call always
    fails.

    Raises:
        NotImplementedError: Always. It is a subclass of ``Exception``, so
            callers that caught the previous generic ``Exception`` still
            work.
    """
    raise NotImplementedError(
        "speculative_config.disable_padded_drafter_batch"
        " is not supported now for MultiLayerEagleProposer."
    )
@torch.inference_mode()
def dummy_run(
    self,
    num_tokens: int,
    use_cudagraphs: bool = True,
    is_graph_capturing: bool = False,
    slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
    """Exercise ``adjust_input`` and every draft layer on dummy inputs.

    ``adjust_input`` is called once on a synthetic single-request batch so
    that its JIT kernels get compiled, then the model is run once per
    layer (``spec_step_idx`` from 0 to ``layer_num - 1``).

    Args:
        num_tokens: Unpadded token count for the dummy batch.
        use_cudagraphs: Forwarded to the batch-execution decision.
        is_graph_capturing: Not used by this method.
        slot_mappings: Slot mappings supplied by the caller. When they
            contain the first draft attention layer, the proposer's own
            buffer is substituted.
    """
    cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
        self._determine_batch_execution_and_padding(
            num_tokens, use_cudagraphs=use_cudagraphs
        )
    )

    # During cudagraph capture EAGLE must use its own slot-mapping buffer.
    draft_layer_names = self._draft_attn_layer_names
    if (
        not draft_layer_names
        or slot_mappings is None
        or next(iter(draft_layer_names)) not in slot_mappings
    ):
        slot_mapping_dict = slot_mappings or {}
    else:
        slot_mapping_dict = self._get_slot_mapping(num_input_tokens)

    def int32_tensor(values, device):
        return torch.tensor(values, dtype=torch.int32, device=device)

    # NOTE ensure the jit kernel in _adjust_input can be compiled
    self.adjust_input(
        batch_size=1,
        target_token_ids=self.input_ids[:num_input_tokens],
        target_positions=self._get_positions(num_input_tokens),
        target_hidden_states=self.hidden_states[:num_input_tokens],
        token_indices_to_sample=int32_tensor([num_input_tokens - 1], self.device),
        common_attn_metadata=CommonAttentionMetadata(
            query_start_loc=int32_tensor([0, num_input_tokens], self.device),
            query_start_loc_cpu=int32_tensor([0, num_input_tokens], "cpu"),
            seq_lens=int32_tensor([num_input_tokens], self.device),
            num_reqs=1,
            num_actual_tokens=num_input_tokens,
            max_query_len=num_input_tokens,
            max_seq_len=self.max_model_len,
            block_table_tensor=int32_tensor([], self.device),
            slot_mapping=self.arange[:num_input_tokens],
            logits_indices_padded=None,
            num_logits_indices=None,
            causal=True,
            encoder_seq_lens=None,
        ),
        multi_layer_eagle_metadata=MultiLayerEagleMetadata.make_dummy(
            layer_num=self.layer_num,
            hidden_size=self.hidden_size,
            device=self.device,
        ),
    )

    for fwd_idx in range(self.layer_num):
        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            # Multimodal-capable drafts take embeddings instead of ids.
            input_ids, inputs_embeds = (
                (None, self.inputs_embeds[:num_input_tokens])
                if self.supports_mm_inputs
                else (self.input_ids[:num_input_tokens], None)
            )
            self.model(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
                spec_step_idx=fwd_idx,
            )
def _multi_layer_eagle_shift_and_cache(
    *,
    batch_size: int,
    max_shift: int,
    src_token_ids: torch.Tensor,
    dst_token_ids: torch.Tensor,
    src_positions: torch.Tensor,
    dst_positions: torch.Tensor,
    src_hidden_states: torch.Tensor,
    dst_hidden_states: torch.Tensor,
    src_slot_mapping: torch.Tensor,
    dst_slot_mapping: torch.Tensor,
    start_token_indices: torch.Tensor,
    end_token_indices: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    shift: torch.Tensor,
    cached_lens: torch.Tensor,
    cached_prev_token_ids: torch.Tensor,
    cached_prev_positions: torch.Tensor,
    cached_prev_hidden_states: torch.Tensor,
    cached_slot_mappings: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
):
    """Launch the shift-and-gather kernels for every per-token tensor.

    The 1-D kernel is run, in order, over token ids, slot mappings and
    positions; the hidden-state kernel is run afterwards. Finally
    ``cached_lens`` is overwritten in place with the newly stored lengths.
    Does nothing when ``batch_size`` is 0.
    """
    if batch_size == 0:
        return

    assert max_shift > 0
    for tensor in (
        cached_prev_positions,
        cached_prev_token_ids,
        cached_prev_hidden_states,
        cached_slot_mappings,
        src_hidden_states,
        dst_hidden_states,
    ):
        assert tensor.is_contiguous()

    # Shifting in place is unsafe, so read from a copy when src and dst
    # slot mappings alias the same memory.
    if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
        src_slot_mapping = src_slot_mapping.clone()

    # Range of each window to save into the cache for the next call.
    store_start = torch.maximum(
        start_token_indices,
        token_indices_to_sample + 1 - max_shift,
    )
    store_lens = torch.clamp(
        token_indices_to_sample - store_start + 1,
        min=0,
        max=max_shift,
    )

    # Longest query length, computed from the CPU copy of query_start_loc
    # to avoid a device sync.
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    max_window_len = int(
        (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]).max().item()
    )
    num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)

    padded_shift = triton.next_power_of_2(max_shift)
    per_request_args = (
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
    )

    one_dim_jobs = (
        (src_token_ids, dst_token_ids, cached_prev_token_ids),
        (src_slot_mapping, dst_slot_mapping, cached_slot_mappings),
        (src_positions, dst_positions, cached_prev_positions),
    )
    for src, dst, cache in one_dim_jobs:
        _shift_and_gather_cache_1d_kernel[(batch_size, num_blocks)](
            src,
            dst,
            cache,
            *per_request_args,
            MAX_SHIFT=max_shift,
            PADDED_SHIFT=padded_shift,
            BLOCK_TOKENS=BLOCK_TOKENS,
        )

    hidden_size = int(dst_hidden_states.shape[1])
    # Blocking over the hidden dimension avoids extremely large Triton
    # tiles (and huge cubins) when hidden_size is large.
    num_hidden_blocks = max(1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN)
    _shift_and_gather_hidden_kernel[(batch_size, num_blocks, num_hidden_blocks)](
        src_hidden_states,
        dst_hidden_states,
        cached_prev_hidden_states,
        *per_request_args,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=padded_shift,
        HIDDEN_SIZE=hidden_size,
        BLOCK_TOKENS=BLOCK_TOKENS,
        BLOCK_HIDDEN=BLOCK_HIDDEN,
        num_warps=4,
    )

    cached_lens.copy_(store_lens)
@triton.jit
def _shift_and_gather_cache_1d_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Right-shift every sequence window of a packed 1D array (token ids,
    # positions, slot mappings, ...) and fill the vacated prefix from a small
    # per-sequence cache.
    #
    # Grid: axis 0 selects the sequence, axis 1 selects a tile of
    # BLOCK_TOKENS consecutive positions inside that sequence's window.
    #
    # Each sequence owns the inclusive window [start, end] of the flattened
    # src/dst arrays and has its own shift amount. With i the 0-based offset
    # inside the window:
    #   i <  shift : dst[start + i] = cached[seq, cached_len - shift + i]
    #   i >= shift : dst[start + i] = src[start + i - shift]
    #
    # Example:
    #   cached tail = [a3, a4], src window = [b0, b1, b2, b3, b4], shift = 2
    #   -> dst window = [a3, a4, b0, b1, b2]
    #
    # Afterwards cached[seq, 0:MAX_SHIFT] is rewritten from
    # dst[store_start : store_start + store_len], zero-padded up to MAX_SHIFT,
    # so that the next call can source its prefix from the cache.
    #
    # NOTE(review): the cache rewrite at the bottom is executed by every token
    # tile of a sequence (it is not gated on the tile index) and reads dst
    # entries that may have been produced by another tile. Confirm that no
    # ordering between tiles is required when a window spans several tiles.
    seq_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Per-sequence scalars.
    win_start = tl.load(start_ptr + seq_id).to(tl.int32)
    win_end = tl.load(end_ptr + seq_id).to(tl.int32)
    n_shift = tl.load(shift_ptr + seq_id).to(tl.int32)
    n_cached = tl.load(cached_len_ptr + seq_id).to(tl.int32)
    assert n_cached >= n_shift

    # Offsets (relative to the window start) handled by this tile.
    rel = tile_id * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
    in_window = rel < (win_end - win_start + 1)
    from_cache = rel < n_shift

    # Prefix values come from the last `shift` valid cache entries.
    seq_cache_ptr = cached_ptr + seq_id * MAX_SHIFT
    prefix_vals = tl.load(
        seq_cache_ptr + (n_cached - n_shift + rel),
        mask=in_window & from_cache,
        other=0,
    )
    # Body values come from src, shifted right by `shift`.
    body_vals = tl.load(
        src_ptr + (win_start + rel - n_shift),
        mask=in_window & ~from_cache,
        other=0,
    )
    tl.store(
        dst_ptr + (win_start + rel),
        tl.where(from_cache, prefix_vals, body_vals),
        mask=in_window,
    )

    # Refresh the per-sequence cache.
    #
    # Cache layout is [batch_size, MAX_SHIFT], flattened. The whole MAX_SHIFT
    # row is always written; slots at or beyond store_len receive 0.
    tail_start = tl.load(store_start_ptr + seq_id).to(tl.int32)
    tail_len = tl.load(store_len_ptr + seq_id).to(tl.int32)
    slot = tl.arange(0, PADDED_SHIFT)
    slot_in_row = slot < MAX_SHIFT
    tail_vals = tl.load(
        dst_ptr + (tail_start + slot),
        mask=slot_in_row & (slot < tail_len),
        other=0,
    )
    tl.store(seq_cache_ptr + slot, tail_vals, mask=slot_in_row)
@triton.jit
def _shift_and_gather_hidden_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    # Per-sequence "shift + gather" for hidden states.
    #
    # Same logical transformation as _shift_and_gather_cache_1d_kernel, applied
    # row-wise to 2D data:
    #   src_ptr / dst_ptr : packed hidden states, addressed as
    #                       [token, HIDDEN_SIZE] with row stride HIDDEN_SIZE
    #   cached_ptr        : per-sequence cache, addressed as
    #                       [seq, MAX_SHIFT, HIDDEN_SIZE]
    #
    # Grid: axis 0 = sequence, axis 1 = tile of BLOCK_TOKENS rows inside the
    # sequence window, axis 2 = tile of BLOCK_HIDDEN columns. Tiling the
    # hidden dim keeps individual Triton tiles small for large hidden sizes.
    #
    # For window [start, end] (inclusive) and 0-based row offset i:
    #   i <  shift : dst[start + i, :] = cached[seq, cached_len - shift + i, :]
    #   i >= shift : dst[start + i, :] = src[start + i - shift, :]
    #
    # Afterwards the first store_len cache rows of this sequence are rewritten
    # from dst[store_start : store_start + store_len, :]. Unlike the 1D kernel,
    # cache rows at or beyond store_len are left untouched (the store is
    # masked with m < store_len).
    #
    # NOTE(review): the cache rewrite is executed by every token tile of a
    # sequence and reads dst rows that may have been produced by another
    # tile. Confirm that no ordering between tiles is required when a window
    # spans several tiles.
    seq_id = tl.program_id(0)
    tile_id = tl.program_id(1)
    hid_id = tl.program_id(2)

    # Per-sequence scalars.
    win_start = tl.load(start_ptr + seq_id).to(tl.int32)
    win_end = tl.load(end_ptr + seq_id).to(tl.int32)
    n_shift = tl.load(shift_ptr + seq_id).to(tl.int32)
    n_cached = tl.load(cached_len_ptr + seq_id).to(tl.int32)
    assert n_cached >= n_shift

    # Row offsets (relative to the window start) and hidden columns covered
    # by this program.
    rel_tok = tile_id * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
    col = hid_id * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)
    col_ok = col < HIDDEN_SIZE
    row_ok = rel_tok < (win_end - win_start + 1)
    tile_ok = row_ok[:, None] & col_ok[None, :]
    from_cache = rel_tok < n_shift

    # Prefix rows: last `shift` valid rows of this sequence's cache.
    seq_cache_ptr = cached_ptr + seq_id * HIDDEN_SIZE * MAX_SHIFT
    cache_row = n_cached - n_shift + rel_tok
    prefix_vals = tl.load(
        seq_cache_ptr + cache_row[:, None] * HIDDEN_SIZE + col[None, :],
        mask=tile_ok & from_cache[:, None],
        other=0,
    )

    # Body rows: src rows shifted right by `shift`.
    src_row = win_start + rel_tok - n_shift
    body_vals = tl.load(
        src_ptr + src_row[:, None] * HIDDEN_SIZE + col[None, :],
        mask=tile_ok & ~from_cache[:, None],
        other=0,
    )

    # Write the shifted tile.
    dst_row = win_start + rel_tok
    tl.store(
        dst_ptr + dst_row[:, None] * HIDDEN_SIZE + col[None, :],
        tl.where(from_cache[:, None], prefix_vals, body_vals),
        mask=tile_ok,
    )

    # Refresh the cache rows [0, store_len) from the freshly written dst.
    tail_start = tl.load(store_start_ptr + seq_id).to(tl.int32)
    tail_len = tl.load(store_len_ptr + seq_id).to(tl.int32)
    slot = tl.arange(0, PADDED_SHIFT)
    slot_ok = (slot < MAX_SHIFT) & (slot < tail_len)
    refresh_ok = slot_ok[:, None] & col_ok[None, :]
    tail_row = tail_start + slot
    tail_vals = tl.load(
        dst_ptr + tail_row[:, None] * HIDDEN_SIZE + col[None, :],
        mask=refresh_ok,
        other=0,
    )
    tl.store(
        seq_cache_ptr + slot[:, None] * HIDDEN_SIZE + col[None, :],
        tail_vals,
        mask=refresh_ok,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.
This version uses a fully vectorized approach with unfold and argmax for
finding the first match across all sequences in parallel.
"""
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
)
from vllm.forward_context import set_forward_context
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@support_torch_compile()
class NgramGPUKernel(nn.Module):
    """GPU-accelerated N-gram proposer using fully async tensor operations.

    The module holds no parameters; it only carries the n-gram search
    configuration read from ``vllm_config`` and implements the search as
    branch-free tensor operations.
    """

    def __init__(
        self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda"
    ):
        """Read the n-gram search settings from ``vllm_config``.

        Args:
            vllm_config: Must carry a speculative config with
                ``prompt_lookup_min`` and ``prompt_lookup_max`` set.
            prefix: Accepted for interface compatibility; not used here.
            device: Stored on the instance as ``self.device``.
        """
        super().__init__()
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

    def _find_first_and_extract_all_n_parallel(
        self,
        token_ids: torch.Tensor,
        seq_lengths: torch.Tensor,
        min_ngram_len: int,
        max_ngram_len: int,
        num_draft_tokens: int,
    ) -> torch.Tensor:
        """
        Find suffix n-gram matches and extract following tokens.

        Searches for the earliest prior occurrence of the trailing n-gram,
        tries multiple lengths, and picks the longest valid match.

        Args:
            token_ids: Token IDs for each sequence, [batch_size, max_seq_len]
            seq_lengths: Actual length of each sequence (excluding padding)
            min_ngram_len: Minimum n-gram size to search for (e.g., 2)
            max_ngram_len: Maximum n-gram size to search for (e.g., 5)
            num_draft_tokens: Number of tokens to extract after match (k)

        Returns:
            Draft token predictions of shape [batch_size, num_draft_tokens];
            -1 means invalid/no match.
        """
        batch_size = token_ids.shape[0]
        max_seq_len = token_ids.shape[1]
        device = token_ids.device
        num_ngram_sizes = max_ngram_len - min_ngram_len + 1

        # All n-gram sizes to try.
        ngram_lengths = torch.arange(min_ngram_len, max_ngram_len + 1, device=device)
        batch_indices = torch.arange(batch_size, device=device)

        # Earliest match per (sequence, ngram_len); -1 means no match.
        first_match_positions = torch.full(
            (batch_size, num_ngram_sizes), -1, dtype=torch.long, device=device
        )

        # The loop runs over plain Python ints, so it is unrolled statically
        # and introduces no data-dependent control flow.
        for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)):
            # Sliding windows of size ngram_len; unfold returns a view.
            search_windows = token_ids.unfold(1, ngram_len, 1)
            num_windows = search_windows.shape[1]

            # Trailing suffix (last ngram_len tokens) for each sequence.
            # Indices are clamped so that sequences shorter than ngram_len
            # still gather in-bounds; such rows are rejected by valid_mask
            # below because their max_valid_suffix_start is negative.
            suffix_starts = seq_lengths - ngram_len
            suffix_indices = suffix_starts.unsqueeze(1) + torch.arange(
                ngram_len, device=device
            )
            suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0))

            # Window matches for each sequence.
            matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)

            # Match must leave room for at least one draft token (this also
            # excludes the suffix matching itself).
            max_valid_suffix_start = seq_lengths - ngram_len - 1
            window_positions = torch.arange(num_windows, device=device)
            valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1)
            final_matches = matches & valid_mask

            # Find earliest match (argmax=0 when empty; verify with has_match).
            first_match_idx = torch.argmax(final_matches.int(), dim=1)
            has_match = final_matches[batch_indices, first_match_idx]

            # Store valid match positions (window index = position).
            first_match_positions[:, i] = torch.where(has_match, first_match_idx, -1)

        # Select the longest n-gram with a match: argmax over the flipped
        # columns finds the last column holding a match.
        best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1)
        best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx  # Flip back

        # Match position for the best n-gram.
        best_match_pos = first_match_positions[batch_indices, best_ngram_idx]

        # Avoid data-dependent branching.
        has_any_match = best_match_pos >= 0

        # Length of the best matching n-gram.
        best_ngram_lengths = ngram_lengths[best_ngram_idx]

        # Start position right after the matched n-gram occurrence.
        draft_start = torch.where(
            has_any_match,
            best_match_pos + best_ngram_lengths,
            torch.zeros_like(best_match_pos),
        )
        tokens_available = seq_lengths - draft_start

        # Gather indices for draft tokens, clamped into the tensor; entries
        # past the real sequence end are masked out afterwards.
        draft_indices = draft_start.unsqueeze(1) + torch.arange(
            num_draft_tokens, device=device
        )
        draft_indices = draft_indices.clamp(min=0, max=max_seq_len - 1)

        # Extract draft tokens; gather always runs.
        draft_tokens = torch.gather(token_ids, 1, draft_indices)

        # Mask positions beyond available tokens.
        position_indices = torch.arange(num_draft_tokens, device=device).unsqueeze(0)
        valid_positions = position_indices < tokens_available.unsqueeze(1)
        draft_tokens = torch.where(
            valid_positions,
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        # If no match, mask all positions.
        draft_tokens = torch.where(
            has_any_match.unsqueeze(1),
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )
        return draft_tokens

    def forward(
        self,
        num_tokens_no_spec: torch.Tensor,
        token_ids_gpu: torch.Tensor,
        combined_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for N-gram proposal using GPU tensor operations.

        Args:
            num_tokens_no_spec: Number of tokens for each sequence [batch_size]
            token_ids_gpu: Token IDs [batch_size, max_len]
            combined_mask: Whether each sequence is valid for spec decode [batch_size]

        Returns:
            draft_tokens: [batch_size, k] on GPU
            num_valid_draft_tokens: [batch_size] int32 on GPU, count of
                leading valid (non -1) tokens per request.
        """
        device = token_ids_gpu.device

        # The result tensors are produced fresh on every call by the ops
        # below. The previous standalone torch.full(...) allocation of
        # draft_tokens was overwritten before being read and has been dropped.
        # NOTE(patchy): Do NOT pre-allocate the output as a module buffer;
        # it breaks torch.compile.
        results = self._find_first_and_extract_all_n_parallel(
            token_ids_gpu,
            num_tokens_no_spec,
            min_ngram_len=self.min_n,
            max_ngram_len=self.max_n,
            num_draft_tokens=self.k,
        )

        # Rows not eligible for speculation get all -1.
        draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)

        # Count leading contiguous valid (non -1) tokens per request: column j
        # belongs to the valid prefix exactly when the running count of valid
        # entries equals j + 1.
        is_valid = draft_tokens != -1  # [batch, k]
        cum_valid = is_valid.int().cumsum(dim=1)  # [batch, k]
        positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
        num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)
        return draft_tokens, num_valid_draft_tokens

    def load_model(self, *args, **kwargs):
        """No model to load for N-gram proposer."""
        pass
class NgramProposerGPU:
    """Host-side driver around ``NgramGPUKernel``.

    Builds a dedicated ``VllmConfig`` for the kernel, runs it a few times on
    synthetic data at construction, and exposes helpers that prepare the
    kernel inputs from sampled tokens.
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
        """Create the kernel and run it on dummy data.

        Args:
            vllm_config: Source of the model, scheduler and speculative
                settings. ``prompt_lookup_min`` / ``prompt_lookup_max`` must
                be set on its speculative config.
            device: Device the kernel is moved to and inputs must live on.
            runner: Accepted but not used by this constructor.
        """
        spec_config = vllm_config.speculative_config
        assert spec_config is not None
        assert spec_config.prompt_lookup_min is not None
        assert spec_config.prompt_lookup_max is not None

        # Compilation settings private to the n-gram kernel; the caller's
        # compilation config is not reused.
        inductor_options = {
            "enable_auto_functionalized_v2": False,
            "max_autotune": True,
            "aggressive_fusion": True,
            "triton.autotune_pointwise": True,
            "coordinate_descent_tuning": True,
            "use_mixed_mm": False,
        }
        kernel_compilation_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["none"],
            splitting_ops=[],
            compile_sizes=[],
            inductor_compile_config=inductor_options,
            cudagraph_mode=CUDAGraphMode.NONE,
        )
        self.vllm_config = VllmConfig(
            compilation_config=kernel_compilation_config,
            model_config=vllm_config.model_config,
            speculative_config=spec_config,
            scheduler_config=vllm_config.scheduler_config,
        )

        self.min_n = spec_config.prompt_lookup_min
        self.max_n = spec_config.prompt_lookup_max
        self.k = spec_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

        kernel = NgramGPUKernel(
            vllm_config=self.vllm_config, prefix="ngram_gpu_kernel", device=device
        )
        self.kernel = kernel
        kernel.to(device)
        kernel.eval()
        self._dummy_run()

    def _dummy_run(self):
        """Invoke the kernel three times on synthetic data.

        The batch is sized ``max_num_seqs`` x ``max_model_len``; presumably
        this triggers compilation/warm-up ahead of real requests.
        """
        dummy_ids, dummy_lens, dummy_sampled, dummy_valid = self._generate_dummy_data(
            batch_size=self.max_num_seqs,
            max_seq_len=self.max_model_len,
            pattern_len=self.k,
            device=self.device,
        )
        eligible = dummy_sampled & dummy_valid & (dummy_lens >= self.min_n)
        for _ in range(3):
            with set_forward_context(None, self.vllm_config):
                _, _ = self.kernel(dummy_lens, dummy_ids, eligible)

    def _generate_dummy_data(
        self,
        batch_size: int,
        max_seq_len: int,
        pattern_len: int,
        device: str = "cuda",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Build placeholder kernel inputs.

        Token ids are all zero; per-row lengths are drawn uniformly from
        ``[pattern_len, max_seq_len)``.

        Args:
            batch_size: Number of sequences in the batch
            max_seq_len: Maximum sequence length
            pattern_len: Lower bound (inclusive) for the random lengths
            device: Device to place tensors on

        Returns:
            token_ids: [batch_size, max_seq_len] int32 tensor of zeros
            num_tokens: [batch_size] int32 tensor of random lengths
            sampled_flags: [batch_size] bool tensor, all True
            valid_mask: [batch_size] bool tensor, all True
        """
        zero_ids = torch.zeros(
            batch_size, max_seq_len, dtype=torch.int32, device=device
        )
        random_lengths = torch.randint(
            pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
        )
        all_sampled = torch.ones(batch_size, dtype=torch.bool, device=device)
        all_valid = torch.ones(batch_size, dtype=torch.bool, device=device)
        return zero_ids, random_lengths, all_sampled, all_valid

    def propose(
        self,
        num_tokens_no_spec: torch.Tensor,  # [batch_size]
        token_ids_gpu: torch.Tensor,  # [batch_size, max_len]
        valid_sampled_token_ids_gpu: torch.Tensor,  # [batch_size, num_spec_tokens + 1]
        valid_sampled_tokens_count: torch.Tensor,  # [batch_size]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Propose draft tokens using GPU-accelerated n-gram matching.

        The newly sampled tokens are first written into ``token_ids_gpu``
        right after each row's current end, then the kernel is run with the
        temporarily extended lengths.

        Args:
            num_tokens_no_spec: Number of tokens per sequence (read-only)
            token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
            valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
            valid_sampled_tokens_count: Count of valid tokens per sequence

        Returns:
            draft_tokens: Proposed draft token IDs [batch_size, k]
            num_valid_draft_tokens: Count of leading valid draft tokens
                per request [batch_size]
        """
        assert token_ids_gpu.device == self.device
        assert num_tokens_no_spec.device == self.device

        num_rows = num_tokens_no_spec.shape[0]
        row_capacity = token_ids_gpu.shape[1]
        slots_per_row = valid_sampled_token_ids_gpu.shape[1]  # num_spec_tokens + 1

        # Column each sampled slot would land in: current length + slot offset.
        slot_offsets = torch.arange(slots_per_row, device=self.device)
        target_cols = num_tokens_no_spec.unsqueeze(1) + slot_offsets.unsqueeze(0)

        # A slot is written only if it is within the row's valid count, holds
        # a real token (not -1) and fits inside the tensor.
        within_count = slot_offsets.unsqueeze(0) < valid_sampled_tokens_count.unsqueeze(
            1
        )
        fits = target_cols < row_capacity
        should_write = within_count & (valid_sampled_token_ids_gpu != -1) & fits

        # Slots that must not be written re-store the value already present
        # at their (clamped) column, which keeps the scatter unconditional.
        # NOTE(review): columns past the end are clamped onto the last column;
        # if a genuine write also targets the last column, scatter_ receives
        # duplicate indices for it. Confirm this cannot occur or is harmless.
        target_cols_safe = target_cols.clamp(max=row_capacity - 1).long()
        current_values = token_ids_gpu.gather(1, target_cols_safe)
        incoming = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype)
        token_ids_gpu.scatter_(
            1,
            target_cols_safe,
            torch.where(should_write, incoming, current_values),
        )

        # Lengths including the tokens just appended (input tensor untouched).
        lengths_with_new = num_tokens_no_spec + valid_sampled_tokens_count

        # Per-row eligibility inputs.
        got_tokens = valid_sampled_tokens_count > 0
        all_rows = torch.ones(num_rows, dtype=torch.bool, device=self.device)

        with set_forward_context(None, self.vllm_config):
            eligible = got_tokens & all_rows & (lengths_with_new >= self.min_n)

            with record_function_or_nullcontext("ngram_proposer_gpu: kernel"):
                draft_tokens, num_valid_draft_tokens = self.kernel(
                    lengths_with_new,
                    token_ids_gpu,
                    eligible,
                )

        return draft_tokens, num_valid_draft_tokens

    def update_token_ids_ngram(
        self,
        sampled_token_ids: torch.Tensor | list[list[int]],
        gpu_input_batch: InputBatch,
        token_ids_gpu: torch.Tensor,
        num_tokens_no_spec: torch.Tensor,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare speculative decoding inputs on device.

        Computes, per request, the next token id and the number of valid
        sampled tokens, honoring discarded requests and rejected (-1) tokens,
        using tensor operations only.

        Returns:
            next_token_ids: last valid sampled token per request, or the
                request's last stored token when nothing valid was sampled.
            valid_sampled_tokens_count: number of valid sampled tokens per
                request.
            valid_sampled_token_ids_gpu: copy of the sampled tokens with rows
                of discarded requests set to -1.
        """
        num_reqs = gpu_input_batch.num_reqs

        if isinstance(sampled_token_ids, list):
            # When disable_padded_drafter_batch=True, sampled_token_ids is a
            # ragged list[list[int]] (possibly with empty sublists for
            # discarded requests). Right-pad every row with -1 to a common
            # width, using width 1 at minimum so a tensor can be created.
            longest = max((len(row) for row in sampled_token_ids), default=0)
            width = max(longest, 1)
            rectangular = [
                row + [-1] * (width - len(row)) for row in sampled_token_ids
            ]
            sampled_token_ids = torch.tensor(
                rectangular, dtype=torch.int32, device=self.device
            )

        assert isinstance(sampled_token_ids, torch.Tensor), (
            "sampled_token_ids should be a torch.Tensor for ngram_gpu"
        )

        # Fallback token: the last token already stored for each request.
        last_stored_col = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
        fallback_tokens = torch.gather(
            token_ids_gpu[:num_reqs], dim=1, index=last_stored_col.unsqueeze(1)
        ).squeeze(1)

        # Work on a copy and blank out rows of discarded requests.
        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.masked_fill_(
            discard_request_mask[:num_reqs].unsqueeze(1), -1
        )

        # A token is valid when it is not -1 and below the vocabulary size.
        is_valid_token = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )
        valid_sampled_tokens_count = is_valid_token.sum(dim=1)

        # Column of the last valid token, clamped to 0 for rows without any
        # (the gathered value is ignored for those rows).
        last_valid_col = torch.clamp(valid_sampled_tokens_count - 1, min=0)
        last_valid_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_col.unsqueeze(1)
        ).squeeze(1)

        # Prefer the last valid sampled token, else the stored fallback.
        has_valid_token = valid_sampled_tokens_count != 0
        next_token_ids = torch.where(
            has_valid_token, last_valid_tokens, fallback_tokens
        )
        return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu

    def load_model(self, *args, **kwargs):
        """Forward to the kernel's ``load_model``."""
        self.kernel.load_model(*args, **kwargs)
def update_scheduler_for_invalid_drafts(
    num_valid_draft_tokens_event: torch.cuda.Event,
    num_valid_draft_tokens_cpu: torch.Tensor,
    scheduler_output: "SchedulerOutput",
    req_id_to_index: dict[str, int],
) -> None:
    """Trim invalid speculative slots using per-request valid draft counts.

    Only requests listed in ``scheduler_output.scheduled_cached_reqs`` that
    have both a batch index and scheduled speculative tokens are touched.
    For each of them the scheduled speculative tokens are cut down to the
    valid count, and the per-request and total scheduled-token counters are
    reduced by the number of removed tokens.

    Args:
        num_valid_draft_tokens_event: Event for async D2H completion; it is
            synchronized before the CPU buffer is read.
        num_valid_draft_tokens_cpu: CPU buffer of valid draft counts.
        scheduler_output: Scheduler metadata to update in-place.
        req_id_to_index: Request-id to batch-index mapping.
    """
    cached_reqs = scheduler_output.scheduled_cached_reqs

    # The counts are only safe to read once the async copy has finished.
    num_valid_draft_tokens_event.synchronize()

    spec_tokens_by_req = scheduler_output.scheduled_spec_decode_tokens
    for req_id in cached_reqs.req_ids:
        if (batch_slot := req_id_to_index.get(req_id)) is None:
            continue
        if (drafts := spec_tokens_by_req.get(req_id)) is None:
            continue

        num_scheduled = len(drafts)
        reported = int(num_valid_draft_tokens_cpu[batch_slot].item())
        # Keep the count inside [0, num_scheduled].
        num_kept = max(0, min(reported, num_scheduled))
        num_dropped = num_scheduled - num_kept

        scheduler_output.total_num_scheduled_tokens -= num_dropped
        scheduler_output.num_scheduled_tokens[req_id] -= num_dropped

        if num_kept == 0:
            # No usable draft left: remove the request's entry entirely.
            spec_tokens_by_req.pop(req_id, None)
        else:
            spec_tokens_by_req[req_id] = drafts[:num_kept]
def update_ngram_gpu_tensors_incremental(
    input_batch: InputBatch,
    token_ids_gpu_tensor: torch.Tensor,
    num_tokens_no_spec_gpu: torch.Tensor,
    new_reqs: list[CachedRequestState],
    device: torch.device,
    _pinned_idx_buf: torch.Tensor,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
    for ngram GPU proposer.

    Steps:
      1. If there is no previous request-to-index mapping, upload the token
         rows of all active requests and sync their lengths.
      2. Otherwise move the rows of requests whose batch index changed,
         upload full rows for the requests in ``new_reqs``, and sync the
         lengths of all active requests from the CPU tensors.

    Args:
        input_batch: Source of the current/previous request-to-index maps and
            of the CPU token and length data.
        token_ids_gpu_tensor: GPU token-id matrix, updated in place.
        num_tokens_no_spec_gpu: GPU per-request lengths, updated in place.
        new_reqs: Requests whose rows must be uploaded in full.
        device: Device the index tensors are created on / copied to.
        _pinned_idx_buf: Caller-owned staging buffer for the active indices.
        _pinned_val_buf: Caller-owned staging buffer for the lengths.
    """
    index_before = input_batch.prev_req_id_to_index
    index_now = input_batch.req_id_to_index
    if not index_now:
        return

    active_rows = list(index_now.values())
    num_active = len(active_rows)

    # Stage the active indices in the caller-provided buffer instead of
    # allocating a new tensor on every call.
    active_rows_cpu = _pinned_idx_buf[:num_active]
    active_rows_cpu.copy_(torch.as_tensor(active_rows, dtype=torch.long))
    active_rows_gpu = active_rows_cpu.to(device=device, non_blocking=True)

    fresh_req_ids = {req.req_id for req in new_reqs}

    def upload_row(row: int) -> None:
        # Copy the first num_tokens_no_spec[row] token ids of a row CPU -> GPU.
        length = input_batch.num_tokens_no_spec[row]
        if length > 0:
            token_ids_gpu_tensor[row, :length].copy_(
                input_batch.token_ids_cpu_tensor[row, :length],
                non_blocking=True,
            )

    def sync_lengths() -> None:
        # Refresh the GPU lengths of all active rows from the CPU tensor.
        _sync_num_tokens(
            input_batch,
            num_tokens_no_spec_gpu,
            active_rows_cpu,
            active_rows_gpu,
            num_active,
            device,
            _pinned_val_buf,
        )

    # First run, no previous state: everything active is uploaded.
    if index_before is None:
        for row in active_rows:
            upload_row(row)
        sync_lengths()
        return

    # Requests that were already present but now sit at another batch index.
    moves: list[tuple[int, int]] = []
    for req_id, row_now in index_now.items():
        if req_id in fresh_req_ids:
            continue
        row_before = index_before.get(req_id)
        if row_before is not None and row_before != row_now:
            moves.append((row_before, row_now))

    if moves:
        src_rows = torch.tensor(
            [old for old, _ in moves], dtype=torch.long, device=device
        )
        dst_rows = torch.tensor(
            [new for _, new in moves], dtype=torch.long, device=device
        )
        # Read all sources before writing any destination, since source and
        # destination rows may overlap.
        moved_token_ids = token_ids_gpu_tensor[src_rows].clone()
        moved_lengths = num_tokens_no_spec_gpu[src_rows].clone()
        token_ids_gpu_tensor[dst_rows] = moved_token_ids
        num_tokens_no_spec_gpu[dst_rows] = moved_lengths

    # Full copy for new/resumed requests that are part of the current batch.
    for req_state in new_reqs:
        row = index_now.get(req_state.req_id)
        if row is not None:
            upload_row(row)

    # Always batch-sync sequence lengths from CPU for ALL active requests.
    sync_lengths()
def _sync_num_tokens(
    input_batch: InputBatch,
    num_tokens_no_spec_gpu: torch.Tensor,
    active_idx_cpu: torch.Tensor,
    active_idx_gpu: torch.Tensor,
    n_active: int,
    device: torch.device,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Batch-sync GPU sequence lengths from the CPU source of truth.

    The lengths of the active requests are gathered from
    ``input_batch.num_tokens_no_spec_cpu_tensor`` into the first ``n_active``
    slots of the staging buffer, transferred to ``device`` and written to the
    matching rows of ``num_tokens_no_spec_gpu``.

    Args:
        input_batch: Batch container with the CPU length tensor.
        num_tokens_no_spec_gpu: Destination GPU length tensor (updated
            in place).
        active_idx_cpu: Active request indices on CPU.
        active_idx_gpu: Active request indices on GPU.
        n_active: Number of active requests.
        device: Target device for the transfer.
        _pinned_val_buf: Resident pinned int32 staging buffer.
    """
    staging = _pinned_val_buf[:n_active]
    lengths_cpu = input_batch.num_tokens_no_spec_cpu_tensor

    # Gather on the host into the staging buffer ...
    staging.copy_(torch.index_select(lengths_cpu, 0, active_idx_cpu))

    # ... then ship to the device and scatter into the active rows.
    lengths_on_device = staging.to(device=device, non_blocking=True)
    num_tokens_no_spec_gpu.index_copy_(0, active_idx_gpu, lengths_on_device)
def copy_num_valid_draft_tokens(
num_valid_draft_tokens_cpu: torch.Tensor,
num_valid_draft_tokens_copy_stream: torch.cuda.Stream,
num_valid_draft_tokens_event: torch.cuda.Event,
num_valid_draft_tokens: torch.Tensor | None,
batch_size: int,
) -> None:
"""
Async D2H copy of per-request valid draft counts.
"""
if num_valid_draft_tokens is None:
return
num_reqs_to_copy = min(batch_size, num_valid_draft_tokens.shape[0])
if num_reqs_to_copy <= 0:
return
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(num_valid_draft_tokens_copy_stream):
num_valid_draft_tokens_copy_stream.wait_stream(default_stream)
num_valid_draft_tokens_cpu[:num_reqs_to_copy].copy_(
num_valid_draft_tokens[:num_reqs_to_copy], non_blocking=True
)
num_valid_draft_tokens_event.record()
...@@ -5,7 +5,11 @@ import torch ...@@ -5,7 +5,11 @@ import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import async_tensor_h2d from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
)
PADDING_SLOT_ID = -1
@triton.jit @triton.jit
def eagle_prepare_inputs_padded_kernel( def eagle_prepare_inputs_padded_kernel(
...@@ -116,7 +120,7 @@ def eagle_prepare_next_token_padded_kernel( ...@@ -116,7 +120,7 @@ def eagle_prepare_next_token_padded_kernel(
class DraftProbs(ABC): # type: ignore[call-arg] class DraftProbs(ABC): # type: ignore[call-arg]
"""Draft probs corresponding to in-progress sequences.""" """Draft probs corresponding to in-progress sequences."""
# spec tokens probs. # spec tokens probs.
draft_probs: torch.Tensor draft_probs: torch.Tensor
# The request id list. # The request id list.
...@@ -182,3 +186,219 @@ class DraftProbs(ABC): # type: ignore[call-arg] ...@@ -182,3 +186,219 @@ class DraftProbs(ABC): # type: ignore[call-arg]
target_device=self.draft_probs.device, target_device=self.draft_probs.device,
pin_memory=True) pin_memory=True)
return self.draft_probs[index_tensor] return self.draft_probs[index_tensor]
def compute_new_slot_mapping(
    cad: CommonAttentionMetadata,
    new_positions: torch.Tensor,
    is_rejected_token_mask: torch.Tensor,
    block_size: int,
    num_new_tokens: int,
    max_model_len: int,
):
    """Compute KV-cache slot ids for the extended token positions.

    Each request contributes ``cad.naive_query_lens() + num_new_tokens``
    consecutive entries of ``new_positions``. For every entry the block is
    looked up in the request's row of ``cad.block_table_tensor`` and combined
    with the in-block offset.

    Args:
        cad: Attention metadata providing the block table, the query start
            locations (for the device) and the per-request query lengths.
        new_positions: Position id of every token in the extended batch.
        is_rejected_token_mask: True for tokens that must not be written to
            the KV cache.
        block_size: Number of slots per KV-cache block.
        num_new_tokens: Tokens added to every request.
        max_model_len: Positions at or beyond this value get a padding slot.

    Returns:
        Slot id per token; ``PADDING_SLOT_ID`` for rejected tokens and for
        positions that exceed the max model length.
    """
    num_reqs, blocks_per_req = cad.block_table_tensor.shape

    # Owning request of every token in the extended, flattened batch.
    tokens_per_req = cad.naive_query_lens() + num_new_tokens
    req_of_token = torch.repeat_interleave(
        torch.arange(num_reqs, device=cad.query_start_loc.device),
        tokens_per_req,
        output_size=len(new_positions),
    )

    # Clamp the positions to prevent an out-of-bounds error when indexing
    # into the block table; the affected entries are masked out below.
    safe_positions = torch.clamp(new_positions, max=max_model_len - 1)

    flat_block_table = cad.block_table_tensor.view(-1)
    flat_index = req_of_token * blocks_per_req + safe_positions // block_size
    block_ids = flat_block_table[flat_index]
    new_slot_mapping = block_ids * block_size + safe_positions % block_size

    # Mask out the position ids that exceed the max model length.
    new_slot_mapping.masked_fill_(new_positions >= max_model_len, PADDING_SLOT_ID)
    # Mask out rejected tokens to prevent saves to the KV cache.
    new_slot_mapping.masked_fill_(is_rejected_token_mask, PADDING_SLOT_ID)
    return new_slot_mapping
def extend_all_queries_by_N(
    common_attn_metadata: CommonAttentionMetadata,
    N: int,
    arange: torch.Tensor,
    new_slot_mapping: torch.Tensor,
) -> CommonAttentionMetadata:
    """
    Return a copy of the attention metadata with every query extended by N.

    Query lengths and sequence lengths both grow by N. This is useful e.g. in
    speculative decoding with parallel drafting, where each sequence is
    extended by N tokens and all tokens are predicted in one pass.

    Args:
        common_attn_metadata: Metadata to extend (not modified).
        N: Number of tokens added to every request.
        arange: Tensor whose leading entries are used as 0, 1, 2, ... to
            offset the device-side query start locations.
        new_slot_mapping: Slot mapping for the extended batch; it is computed
            externally because it requires more information.
    """
    cad = common_attn_metadata

    # Request i starts i * N tokens later than before, so the cumulative
    # offsets grow by [+0, +N, +2N, ..., +batch_size * N].
    device_offsets = N * arange[: len(cad.query_start_loc)]
    host_offsets = N * torch.arange(
        len(cad.query_start_loc_cpu), dtype=torch.int32
    )

    return cad.replace(
        query_start_loc=cad.query_start_loc + device_offsets,
        query_start_loc_cpu=cad.query_start_loc_cpu + host_offsets,
        seq_lens=cad.seq_lens + N,
        # Each request gains N tokens -> batch_size * N tokens in total.
        num_actual_tokens=cad.num_actual_tokens + cad.batch_size() * N,
        # Every query grows by N, and so do the maxima.
        max_query_len=cad.max_query_len + N,
        max_seq_len=cad.max_seq_len + N,
        slot_mapping=new_slot_mapping,
    )
# Unified copy/expand kernel
@triton.jit
def copy_and_expand_eagle_inputs_kernel(
    # (Padded) Inputs from the target model
    target_token_ids_ptr,  # [total_tokens_in_batch]
    target_positions_ptr,  # [total_tokens_in_batch]
    next_token_ids_ptr,  # [num_reqs]
    # Outputs to the drafting buffers
    out_input_ids_ptr,  # [total_draft_tokens_in_batch] (output)
    out_positions_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_rejected_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_is_masked_token_mask_ptr,  # [total_draft_tokens_in_batch] (output)
    out_new_token_indices_ptr,  # [num_padding_slots_per_request * num_reqs] (output)
    out_hidden_state_mapping_ptr,  # [total_tokens_in_batch]
    # Input metadata
    query_start_loc_ptr,  # [num_reqs + 1], last value is the total num input tokens
    query_end_loc_ptr,  # [num_reqs]
    padding_token_id,  # tl.int32
    parallel_drafting_token_id,  # tl.int32
    # Sizing info
    total_input_tokens,  # tl.int32
    num_padding_slots_per_request,  # tl.int32
    shift_input_ids,  # tl.bool
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Blocks along token dim to handle prefills
):
    """
    Fill the Eagle drafting buffers from the target model's (padded) inputs,
    inserting the bonus token, the parallel-drafting placeholder tokens and the
    padding for rejected slots.

    Each program instance processes one tile of BLOCK_SIZE_TOKENS output slots
    of a single request: grid axis 0 selects the request, grid axis 1 selects
    the tile within that request.
    """
    req = tl.program_id(axis=0)
    tile = tl.program_id(axis=1)

    # Boundaries of this request inside the flattened input.
    req_begin = tl.load(query_start_loc_ptr + req)
    next_req_begin = tl.load(query_start_loc_ptr + req + 1)
    req_end = tl.load(query_end_loc_ptr + req)

    # Every request occupies (input_len + num_padding_slots_per_request) output
    # slots. When the input ids are shifted the first token is skipped, so one
    # token per request is lost and the output base moves accordingly.
    if shift_input_ids:
        n_valid = req_end - req_begin
        src_shift = 1
        dst_base = req_begin + req * (num_padding_slots_per_request - 1)
    else:
        n_valid = req_end - req_begin + 1
        src_shift = 0
        dst_base = req_begin + req * num_padding_slots_per_request

    # Tokens rejected by the previous speculation step.
    n_rejected = next_req_begin - req_end - 1
    # Number of output slots owned by this request.
    n_out = n_valid + num_padding_slots_per_request + n_rejected
    # One past the last "new token" (bonus + parallel drafting) slot.
    new_tokens_end = n_valid + num_padding_slots_per_request

    # Request-local output slot handled by each lane of this tile.
    offs = tile * BLOCK_SIZE_TOKENS + tl.arange(0, BLOCK_SIZE_TOKENS)

    # Layout of a request's output slots:
    #   [0, n_valid)               tokens copied from the input
    #   [n_valid]                  bonus token from next_token_ids
    #   (n_valid, new_tokens_end)  parallel drafting slots
    #   [new_tokens_end, n_out)    rejected slots
    in_bounds = offs < n_out
    is_valid = offs < n_valid
    is_bonus = offs == n_valid
    is_draft = (offs > n_valid) & (offs < new_tokens_end)
    is_rejected = offs >= new_tokens_end
    # Bonus + parallel drafting slots are the ones that get sampled.
    is_new_token = (offs >= n_valid) & (offs < new_tokens_end)

    dst = dst_base + offs
    # Clamp the source index: masked loads still need valid addresses.
    src = tl.minimum(req_begin + src_shift + offs, total_input_tokens - 1)

    tokens = tl.load(
        target_token_ids_ptr + src, mask=is_valid & in_bounds, other=0
    )
    # Position of the first token of this request.
    first_pos = tl.load(target_positions_ptr + req_begin)
    bonus = tl.load(next_token_ids_ptr + req)

    # Overlay the non-copied regions; later overlays take precedence.
    tokens = tl.where(is_bonus, bonus, tokens)
    tokens = tl.where(is_draft, parallel_drafting_token_id, tokens)
    tokens = tl.where(is_rejected, padding_token_id, tokens)

    # Positions are NOT shifted: slot j gets first_pos + j
    # (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...]).
    # Rejected slots are don't-care and are written as 0.
    pos = tl.where(is_rejected, 0, first_pos + offs)

    rejected_flag = is_rejected & in_bounds
    masked_flag = is_draft & in_bounds

    # Slot in out_new_token_indices: 0 for the bonus token, 1, 2, ... for the
    # parallel drafting tokens, offset by the request.
    new_token_slot = req * num_padding_slots_per_request + (offs - n_valid)

    # Hidden-state mapping (source index -> destination index). Hidden states
    # are not shifted, so every input token of the request is mapped, rejected
    # ones included.
    if shift_input_ids:
        n_inputs = next_req_begin - req_begin
        src_pos = req_begin + offs
        tl.store(out_hidden_state_mapping_ptr + src_pos, dst, mask=offs < n_inputs)

    tl.store(out_input_ids_ptr + dst, tokens, mask=in_bounds)
    tl.store(out_positions_ptr + dst, pos, mask=in_bounds)
    tl.store(out_is_rejected_token_mask_ptr + dst, rejected_flag, mask=in_bounds)
    tl.store(out_is_masked_token_mask_ptr + dst, masked_flag, mask=in_bounds)
    tl.store(
        out_new_token_indices_ptr + new_token_slot,
        dst,
        mask=is_new_token & in_bounds,
    )
\ No newline at end of file
...@@ -61,6 +61,13 @@ class CachedRequestState: ...@@ -61,6 +61,13 @@ class CachedRequestState:
pooling_params: PoolingParams | None = None pooling_params: PoolingParams | None = None
pooling_states: PoolingStates | None = None pooling_states: PoolingStates | None = None
# for multi layer eagle proposer
cached_len: torch.Tensor | None = None
cached_token_ids: torch.Tensor | None = None
cached_hidden_states: torch.Tensor | None = None
cached_slot_mappings: torch.Tensor | None = None
cached_positions: torch.Tensor | None = None
def __post_init__(self): def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds self.prompt_token_ids, self.prompt_embeds
...@@ -103,6 +110,8 @@ class InputBatch: ...@@ -103,6 +110,8 @@ class InputBatch:
is_spec_decode: bool = False, is_spec_decode: bool = False,
is_pooling_model: bool = False, is_pooling_model: bool = False,
cp_kv_cache_interleave_size: int = 1, cp_kv_cache_interleave_size: int = 1,
multi_layer_eagle_num: int = 0,
hidden_size: int | None = None,
): ):
ori_max_num_reqs = max_num_reqs ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT: if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
...@@ -223,7 +232,45 @@ class InputBatch: ...@@ -223,7 +232,45 @@ class InputBatch:
(max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
) )
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
# Multi layer eagle
self.multi_layer_eagle_num = multi_layer_eagle_num
if multi_layer_eagle_num > 0:
self.cached_len = torch.zeros(
(max_num_reqs,), dtype=torch.int64, device=device
)
self.cached_token_ids = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int32,
device=device,
)
self.cached_hidden_states = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
hidden_size,
),
dtype=torch.float,
device=device,
)
self.cached_slot_mappings = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
self.cached_positions = torch.zeros(
(
max_num_reqs,
multi_layer_eagle_num,
),
dtype=torch.int64,
device=device,
)
# lora related # lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64) self.request_lora_mapping = np.zeros((self.max_num_reqs,), dtype=np.int64)
self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_request_ids: dict[int, set[str]] = {}
...@@ -464,6 +511,13 @@ class InputBatch: ...@@ -464,6 +511,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated. # Speculative decoding: by default 1 token is generated.
self.num_accepted_tokens_cpu[req_index] = 1 self.num_accepted_tokens_cpu[req_index] = 1
if self.multi_layer_eagle_num > 0:
self.cached_len[req_index] = request.cached_len
self.cached_token_ids[req_index] = request.cached_token_ids
self.cached_hidden_states[req_index] = request.cached_hidden_states
self.cached_slot_mappings[req_index] = request.cached_slot_mappings
self.cached_positions[req_index] = request.cached_positions
# Add request lora ID # Add request lora ID
if request.lora_request: if request.lora_request:
lora_id = request.lora_request.lora_int_id lora_id = request.lora_request.lora_int_id
...@@ -662,6 +716,20 @@ class InputBatch: ...@@ -662,6 +716,20 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1], self.allowed_token_ids_mask_cpu_tensor[i1],
) )
if self.multi_layer_eagle_num > 0:
    # Exchange the cached multi-layer eagle state of rows i1 and i2.
    #
    # All buffers are torch tensors, so the swap must go through list
    # (advanced) indexing: the right-hand side is gathered into a fresh
    # tensor before anything is written. Do NOT use the tuple-swap idiom
    # `t[i1], t[i2] = t[i2], t[i1]` here: integer indexing of a tensor
    # returns views into the same storage, so the second assignment would
    # read the row that was just overwritten and both rows would end up
    # holding the value from i2.
    dst_rows = [i1, i2]
    src_rows = [i2, i1]
    self.cached_len[dst_rows] = self.cached_len[src_rows]
    self.cached_token_ids[dst_rows, ...] = self.cached_token_ids[src_rows, ...]
    self.cached_hidden_states[dst_rows, ...] = self.cached_hidden_states[
        src_rows, ...
    ]
    self.cached_slot_mappings[dst_rows, ...] = self.cached_slot_mappings[
        src_rows, ...
    ]
    self.cached_positions[dst_rows, ...] = self.cached_positions[src_rows, ...]
def condense(self) -> None: def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices. """Slide non-empty requests down into lower, empty indices.
...@@ -784,6 +852,21 @@ class InputBatch: ...@@ -784,6 +852,21 @@ class InputBatch:
if bad_words_token_ids is not None: if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids self.bad_words_token_ids[empty_index] = bad_words_token_ids
if self.multi_layer_eagle_num > 0:
self.cached_len[empty_index] = self.cached_len[last_req_index]
self.cached_token_ids[empty_index] = self.cached_token_ids[
last_req_index
]
self.cached_hidden_states[empty_index] = self.cached_hidden_states[
last_req_index
]
self.cached_slot_mappings[empty_index] = self.cached_slot_mappings[
last_req_index
]
self.cached_positions[empty_index] = self.cached_positions[
last_req_index
]
# Decrement last_req_index since it is now empty. # Decrement last_req_index since it is now empty.
last_req_index -= 1 last_req_index -= 1
...@@ -824,7 +907,7 @@ class InputBatch: ...@@ -824,7 +907,7 @@ class InputBatch:
if not self.all_greedy: if not self.all_greedy:
temperature = copy_slice( temperature = copy_slice(
self.temperature_cpu_tensor, self.temperature, self.temperature_cpu_tensor, self.temperature,
num_reqs, repeat_counts num_reqs, repeat_counts
) )
else: else:
......
...@@ -149,8 +149,15 @@ from vllm.v1.sample.rejection_sampler_opt import OptRejectionSampler ...@@ -149,8 +149,15 @@ from vllm.v1.sample.rejection_sampler_opt import OptRejectionSampler
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.draft_model import DraftModelProposer from vllm.v1.spec_decode.draft_model import DraftModelProposer
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata, MultiLayerEagleMetadata
from vllm.v1.spec_decode.ngram_proposer_gpu import (
copy_num_valid_draft_tokens,
# update_ngram_gpu_tensors_incremental,
# update_scheduler_for_invalid_drafts,
)
from vllm.v1.spec_decode.multi_layer_eagle import MultiLayerEagleProposer
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.structured_output.utils import apply_grammar_bitmask
...@@ -316,6 +323,7 @@ class ExecuteModelState(NamedTuple): ...@@ -316,6 +323,7 @@ class ExecuteModelState(NamedTuple):
scheduler_output: "SchedulerOutput" scheduler_output: "SchedulerOutput"
logits: torch.Tensor logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None spec_decode_metadata: SpecDecodeMetadata | None
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor sample_hidden_states: torch.Tensor
...@@ -336,6 +344,7 @@ class GPUModelRunner( ...@@ -336,6 +344,7 @@ class GPUModelRunner(
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
# self.offload_config = vllm_config.offload_config
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config self.load_config = vllm_config.load_config
...@@ -417,6 +426,9 @@ class GPUModelRunner( ...@@ -417,6 +426,9 @@ class GPUModelRunner(
# Sampler # Sampler
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
# multi layer eagle
self.enable_multi_layer_eagle = False
self.eplb_state: EplbState | None = None self.eplb_state: EplbState | None = None
""" """
State of the expert parallelism load balancer. State of the expert parallelism load balancer.
...@@ -439,6 +451,9 @@ class GPUModelRunner( ...@@ -439,6 +451,9 @@ class GPUModelRunner(
self.encoder_cache: dict[str, torch.Tensor] = {} self.encoder_cache: dict[str, torch.Tensor] = {}
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
self.multi_layer_eagle_num = 0
# Set up speculative decoding. # Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on # NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many # the last PP rank. This is not ideal if there are many
...@@ -450,6 +465,7 @@ class GPUModelRunner( ...@@ -450,6 +465,7 @@ class GPUModelRunner(
| EagleProposer | EagleProposer
| DraftModelProposer | DraftModelProposer
| MedusaProposer | MedusaProposer
| ExtractHiddenStatesProposer
) )
if self.speculative_config.method == "ngram": if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config) self.drafter = NgramProposer(self.vllm_config)
...@@ -462,7 +478,19 @@ class GPUModelRunner( ...@@ -462,7 +478,19 @@ class GPUModelRunner(
elif self.speculative_config.method == "suffix": elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config) self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle(): elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.device, self) if (
self.speculative_config.enable_multi_layers_mtp
and self.speculative_config.method == "mtp"
):
self.enable_multi_layer_eagle = True
self.drafter = MultiLayerEagleProposer(
self.vllm_config, self.device, self
)
self.multi_layer_eagle_num = self.drafter.layer_num
else:
self.drafter = EagleProposer(self.vllm_config, self.device, self)
# self.drafter = EagleProposer(self.vllm_config, self.device, self)
if self.speculative_config.method == "eagle3": if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = ( self.use_aux_hidden_state_outputs = (
self.drafter.eagle3_use_aux_hidden_state self.drafter.eagle3_use_aux_hidden_state
...@@ -471,12 +499,17 @@ class GPUModelRunner( ...@@ -471,12 +499,17 @@ class GPUModelRunner(
self.drafter = MedusaProposer( self.drafter = MedusaProposer(
vllm_config=self.vllm_config, device=self.device vllm_config=self.vllm_config, device=self.device
) )
elif self.speculative_config.method == "extract_hidden_states":
self.drafter = ExtractHiddenStatesProposer(
vllm_config=self.vllm_config, device=self.device
)
self.use_aux_hidden_state_outputs = True
else: else:
raise ValueError( raise ValueError(
"Unknown speculative decoding method: " "Unknown speculative decoding method: "
f"{self.speculative_config.method}" f"{self.speculative_config.method}"
) )
if not envs.VLLM_REJECT_SAMPLE_OPT: if not envs.VLLM_REJECT_SAMPLE_OPT:
self.rejection_sampler = RejectionSampler(self.sampler) self.rejection_sampler = RejectionSampler(self.sampler)
else: else:
...@@ -535,6 +568,10 @@ class GPUModelRunner( ...@@ -535,6 +568,10 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids=bool(custom_logitsprocs), logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
multi_layer_eagle_num=self.multi_layer_eagle_num
if self.enable_multi_layer_eagle
else 0,
hidden_size=self.model_config.get_hidden_size(),
) )
# Separate cuda stream for overlapping transfer of sampled token ids from # Separate cuda stream for overlapping transfer of sampled token ids from
...@@ -623,6 +660,7 @@ class GPUModelRunner( ...@@ -623,6 +660,7 @@ class GPUModelRunner(
(3, self.max_num_tokens + 1), dtype=torch.int64 (3, self.max_num_tokens + 1), dtype=torch.int64
) )
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL) # Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
# Similar to mrope but use assigned dimension number for RoPE, 4 as default. # Similar to mrope but use assigned dimension number for RoPE, 4 as default.
...@@ -805,7 +843,6 @@ class GPUModelRunner( ...@@ -805,7 +843,6 @@ class GPUModelRunner(
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
with_numpy=numpy, with_numpy=numpy,
) )
def _copy_mrope_positions_to_gpu(self, num_tokens: int) -> None: def _copy_mrope_positions_to_gpu(self, num_tokens: int) -> None:
if not self.uses_mrope: if not self.uses_mrope:
return return
...@@ -816,6 +853,7 @@ class GPUModelRunner( ...@@ -816,6 +853,7 @@ class GPUModelRunner(
non_blocking=True, non_blocking=True,
) )
return return
self.mrope_positions.gpu[:, :num_tokens].copy_( self.mrope_positions.gpu[:, :num_tokens].copy_(
self.mrope_positions.cpu[:, :num_tokens], self.mrope_positions.cpu[:, :num_tokens],
non_blocking=True, non_blocking=True,
...@@ -1014,6 +1052,9 @@ class GPUModelRunner( ...@@ -1014,6 +1052,9 @@ class GPUModelRunner(
if self.uses_xdrope_dim > 0: if self.uses_xdrope_dim > 0:
self._init_xdrope_positions(req_state) self._init_xdrope_positions(req_state)
if self.enable_multi_layer_eagle:
self._init_multi_layer_eagle_cache(req_state)
reqs_to_add.append(req_state) reqs_to_add.append(req_state)
# Update the states of the running/resumed requests. # Update the states of the running/resumed requests.
...@@ -1265,6 +1306,24 @@ class GPUModelRunner( ...@@ -1265,6 +1306,24 @@ class GPUModelRunner(
req_state.mm_features, req_state.mm_features,
) )
def _init_multi_layer_eagle_cache(self, req_state: CachedRequestState):
    """Attach zero-filled multi-layer eagle cache tensors to ``req_state``.

    All tensors are allocated on ``self.device``:

    * ``cached_len``: shape ``(1,)``, int64.
    * ``cached_hidden_states``: shape
      ``(multi_layer_eagle_num, hidden_size)``, dtype ``self.dtype``.
    * ``cached_token_ids``: shape ``(multi_layer_eagle_num,)``, int32.
    * ``cached_positions`` and ``cached_slot_mappings``: shape
      ``(multi_layer_eagle_num,)``, int64.
    """
    device = self.device
    num_layers = self.multi_layer_eagle_num

    def _per_layer_zeros(dtype):
        # One zero entry per eagle layer.
        return torch.zeros(num_layers, dtype=dtype, device=device)

    req_state.cached_len = torch.zeros(1, dtype=torch.int64, device=device)
    req_state.cached_hidden_states = torch.zeros(
        (num_layers, self.model_config.get_hidden_size()),
        dtype=self.dtype,
        device=device,
    )
    req_state.cached_token_ids = _per_layer_zeros(torch.int32)
    req_state.cached_positions = _per_layer_zeros(torch.int64)
    req_state.cached_slot_mappings = _per_layer_zeros(torch.int64)
def _extract_mm_kwargs( def _extract_mm_kwargs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
...@@ -1689,6 +1748,17 @@ class GPUModelRunner( ...@@ -1689,6 +1748,17 @@ class GPUModelRunner(
self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
self.num_decode_draft_tokens.copy_to_gpu() self.num_decode_draft_tokens.copy_to_gpu()
if self.enable_multi_layer_eagle:
multi_layer_eagle_metadata = MultiLayerEagleMetadata(
cached_len=self.input_batch.cached_len[:num_reqs],
cached_token_ids=self.input_batch.cached_token_ids[:num_reqs],
cached_hidden_states=self.input_batch.cached_hidden_states[:num_reqs],
cached_slot_mappings=self.input_batch.cached_slot_mappings[:num_reqs],
cached_positions=self.input_batch.cached_positions[:num_reqs],
)
else:
multi_layer_eagle_metadata = None
# Hot-Swap lora model # Hot-Swap lora model
if self.lora_config: if self.lora_config:
assert ( assert (
...@@ -1699,10 +1769,11 @@ class GPUModelRunner( ...@@ -1699,10 +1769,11 @@ class GPUModelRunner(
self.input_batch, num_scheduled_tokens, num_sampled_tokens self.input_batch, num_scheduled_tokens, num_sampled_tokens
) )
return ( # return (
logits_indices, # logits_indices,
spec_decode_metadata, # spec_decode_metadata,
) # )
return (logits_indices, spec_decode_metadata, multi_layer_eagle_metadata)
def _build_attention_metadata( def _build_attention_metadata(
self, self,
...@@ -2127,7 +2198,7 @@ class GPUModelRunner( ...@@ -2127,7 +2198,7 @@ class GPUModelRunner(
) )
xdrope_pos_ptr += completion_part_len xdrope_pos_ptr += completion_part_len
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
mrope_pos_ptr = 0 mrope_pos_ptr = 0
if self.use_1d_mrope: if self.use_1d_mrope:
...@@ -2168,9 +2239,9 @@ class GPUModelRunner( ...@@ -2168,9 +2239,9 @@ class GPUModelRunner(
req.mrope_positions[:, src_start:src_end].transpose(0, 1) req.mrope_positions[:, src_start:src_end].transpose(0, 1)
) )
else: else:
self.mrope_positions.cpu[:, dst_start:dst_end] = ( self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[
req.mrope_positions[:, src_start:src_end] :, src_start:src_end
) ]
mrope_pos_ptr += prompt_part_len mrope_pos_ptr += prompt_part_len
if completion_part_len > 0: if completion_part_len > 0:
...@@ -2181,9 +2252,7 @@ class GPUModelRunner( ...@@ -2181,9 +2252,7 @@ class GPUModelRunner(
assert req.mrope_position_delta is not None assert req.mrope_position_delta is not None
if self.use_1d_mrope: if self.use_1d_mrope:
values = np.arange( values = np.arange(
req.mrope_position_delta req.mrope_position_delta + num_computed_tokens + prompt_part_len,
+ num_computed_tokens
+ prompt_part_len,
req.mrope_position_delta req.mrope_position_delta
+ num_computed_tokens + num_computed_tokens
+ prompt_part_len + prompt_part_len
...@@ -2279,7 +2348,7 @@ class GPUModelRunner( ...@@ -2279,7 +2348,7 @@ class GPUModelRunner(
fused_meta_data = cu_num_draft_tokens.tolist() + cu_num_sampled_tokens.tolist()\ fused_meta_data = cu_num_draft_tokens.tolist() + cu_num_sampled_tokens.tolist()\
+ logits_indices.tolist() + target_logits_indices.tolist() + bonus_logits_indices.tolist()\ + logits_indices.tolist() + target_logits_indices.tolist() + bonus_logits_indices.tolist()\
+ draft_token_indices.tolist() + draft_token_indices.tolist()
fused_meta_data_len = np.array([len(cu_num_draft_tokens), len(cu_num_sampled_tokens),\ fused_meta_data_len = np.array([len(cu_num_draft_tokens), len(cu_num_sampled_tokens),\
len(logits_indices), len(target_logits_indices),\ len(logits_indices), len(target_logits_indices),\
len(bonus_logits_indices), len(draft_token_indices)], dtype=np.int32) len(bonus_logits_indices), len(draft_token_indices)], dtype=np.int32)
...@@ -2287,7 +2356,7 @@ class GPUModelRunner( ...@@ -2287,7 +2356,7 @@ class GPUModelRunner(
fused_meta_data = torch.tensor( fused_meta_data = torch.tensor(
fused_meta_data, dtype=torch.int32, pin_memory=self.pin_memory fused_meta_data, dtype=torch.int32, pin_memory=self.pin_memory
).to(self.device, non_blocking=True) ).to(self.device, non_blocking=True)
cu_num_draft_tokens = fused_meta_data[:cu_fused_meta_data_len[0]] cu_num_draft_tokens = fused_meta_data[:cu_fused_meta_data_len[0]]
cu_num_sampled_tokens = fused_meta_data[cu_fused_meta_data_len[0]:cu_fused_meta_data_len[1]] cu_num_sampled_tokens = fused_meta_data[cu_fused_meta_data_len[0]:cu_fused_meta_data_len[1]]
logits_indices = fused_meta_data[cu_fused_meta_data_len[1]:cu_fused_meta_data_len[2]] logits_indices = fused_meta_data[cu_fused_meta_data_len[1]:cu_fused_meta_data_len[2]]
...@@ -2900,7 +2969,7 @@ class GPUModelRunner( ...@@ -2900,7 +2969,7 @@ class GPUModelRunner(
inputs_embeds = None inputs_embeds = None
model_kwargs = self._init_model_kwargs() model_kwargs = self._init_model_kwargs()
positions = self._get_positions(num_input_tokens) positions = self._get_positions(num_input_tokens)
if is_first_rank: if is_first_rank:
intermediate_tensors = None intermediate_tensors = None
...@@ -3457,9 +3526,15 @@ class GPUModelRunner( ...@@ -3457,9 +3526,15 @@ class GPUModelRunner(
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max()) max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
logits_indices, spec_decode_metadata = self._prepare_inputs( # logits_indices, spec_decode_metadata = self._prepare_inputs(
scheduler_output, # scheduler_output,
num_scheduled_tokens_np, # num_scheduled_tokens_np,
# )
logits_indices, spec_decode_metadata, multi_layer_eagle_metadata = (
self._prepare_inputs(
scheduler_output,
num_scheduled_tokens_np,
)
) )
cascade_attn_prefix_lens = None cascade_attn_prefix_lens = None
...@@ -3683,6 +3758,7 @@ class GPUModelRunner( ...@@ -3683,6 +3758,7 @@ class GPUModelRunner(
scheduler_output, scheduler_output,
logits, logits,
spec_decode_metadata, spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
hidden_states, hidden_states,
sample_hidden_states, sample_hidden_states,
...@@ -3720,6 +3796,7 @@ class GPUModelRunner( ...@@ -3720,6 +3796,7 @@ class GPUModelRunner(
scheduler_output, scheduler_output,
logits, logits,
spec_decode_metadata, spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
hidden_states, hidden_states,
sample_hidden_states, sample_hidden_states,
...@@ -3759,6 +3836,7 @@ class GPUModelRunner( ...@@ -3759,6 +3836,7 @@ class GPUModelRunner(
sample_hidden_states, sample_hidden_states,
aux_hidden_states, aux_hidden_states,
spec_decode_metadata, spec_decode_metadata,
multi_layer_eagle_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
slot_mappings, slot_mappings,
) )
...@@ -3959,6 +4037,233 @@ class GPUModelRunner( ...@@ -3959,6 +4037,233 @@ class GPUModelRunner(
sampled_count_event.synchronize() sampled_count_event.synchronize()
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist() return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
# def propose_draft_token_ids(
# self,
# scheduler_output: "SchedulerOutput",
# sampled_token_ids: torch.Tensor | list[list[int]],
# sampling_metadata: SamplingMetadata,
# hidden_states: torch.Tensor,
# sample_hidden_states: torch.Tensor,
# aux_hidden_states: list[torch.Tensor] | None,
# spec_decode_metadata: SpecDecodeMetadata | None,
# # multi_layer_eagle_metadata: MultiLayerEagleMetadata | None,
# common_attn_metadata: CommonAttentionMetadata,
# slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
# ) -> list[list[int]] | torch.Tensor:
# num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# spec_config = self.speculative_config
# assert spec_config is not None
# if spec_config.method == "ngram":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, NgramProposer)
# draft_token_ids = self.drafter.propose(
# sampled_token_ids,
# self.input_batch.num_tokens_no_spec,
# self.input_batch.token_ids_cpu,
# slot_mappings=slot_mappings,
# )
# elif spec_config.method == "suffix":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, SuffixDecodingProposer)
# draft_token_ids = self.drafter.propose(
# self.input_batch, sampled_token_ids, slot_mappings=slot_mappings
# )
# elif spec_config.method == "medusa":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, MedusaProposer)
# if sample_hidden_states.shape[0] == len(sampled_token_ids):
# # The input to the target model does not include draft tokens.
# hidden_states = sample_hidden_states
# else:
# indices = []
# offset = 0
# assert spec_decode_metadata is not None, (
# "No spec decode metadata for medusa"
# )
# for num_draft, tokens in zip(
# spec_decode_metadata.num_draft_tokens, sampled_token_ids
# ):
# indices.append(offset + len(tokens) - 1)
# offset += num_draft + 1
# indices = torch.tensor(indices, device=self.device)
# hidden_states = sample_hidden_states[indices]
# draft_token_ids = self.drafter.propose(
# target_hidden_states=hidden_states,
# sampling_metadata=sampling_metadata,
# slot_mappings=slot_mappings,
# )
# elif spec_config.uses_extract_hidden_states():
# assert isinstance(self.drafter, ExtractHiddenStatesProposer)
# assert isinstance(sampled_token_ids, torch.Tensor), (
# "sampled_token_ids should be a torch.Tensor for "
# "extract_hidden_states method."
# )
# if not self.use_aux_hidden_state_outputs or aux_hidden_states is None:
# raise ValueError(
# "aux_hidden_states are required when using `extract_hidden_states`"
# )
# target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
# draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
# sampled_token_ids=sampled_token_ids,
# target_hidden_states=target_hidden_states,
# common_attn_metadata=common_attn_metadata,
# scheduler_output=scheduler_output,
# slot_mappings=slot_mappings,
# )
# # Combine KVConnectorOutputs or select the non-empty one
# if self.kv_connector_output and drafter_kv_connector_output:
# self.kv_connector_output = KVConnectorOutput.merge(
# self.kv_connector_output, drafter_kv_connector_output
# )
# else:
# self.kv_connector_output = (
# self.kv_connector_output or drafter_kv_connector_output
# )
# next_token_ids, valid_sampled_tokens_count = (
# self.drafter.prepare_next_token_ids_padded(
# common_attn_metadata,
# sampled_token_ids,
# self.requests,
# self.input_batch,
# self.discard_request_mask.gpu,
# )
# )
# self._copy_valid_sampled_token_count(
# next_token_ids, valid_sampled_tokens_count
# )
# elif spec_config.use_eagle() or spec_config.uses_draft_model():
# assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
# if spec_config.disable_padded_drafter_batch:
# # When padded-batch is disabled, the sampled_token_ids should be
# # the cpu-side list[list[int]] of valid sampled tokens for each
# # request, with invalid requests having empty lists.
# assert isinstance(sampled_token_ids, list), (
# "sampled_token_ids should be a python list when"
# "padded-batch is disabled."
# )
# next_token_ids = self.drafter.prepare_next_token_ids_cpu(
# sampled_token_ids,
# self.requests,
# self.input_batch,
# scheduler_output.num_scheduled_tokens,
# )
# else:
# # When using padded-batch, the sampled_token_ids should be
# # the gpu tensor of sampled tokens for each request, of shape
# # (num_reqs, num_spec_tokens + 1) with rejected tokens having
# # value -1.
# assert isinstance(sampled_token_ids, torch.Tensor), (
# "sampled_token_ids should be a torch.Tensor when"
# "padded-batch is enabled."
# )
# next_token_ids, valid_sampled_tokens_count = (
# self.drafter.prepare_next_token_ids_padded(
# common_attn_metadata,
# sampled_token_ids,
# self.requests,
# self.input_batch,
# self.discard_request_mask.gpu,
# )
# )
# self._copy_valid_sampled_token_count(
# next_token_ids, valid_sampled_tokens_count
# )
# num_rejected_tokens_gpu = None
# if spec_decode_metadata is None:
# token_indices_to_sample = None
# # input_ids can be None for multimodal models.
# target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
# target_positions = self._get_positions(num_scheduled_tokens)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[:num_scheduled_tokens]
# else:
# if spec_config.disable_padded_drafter_batch:
# token_indices_to_sample = None
# common_attn_metadata, token_indices = self.drafter.prepare_inputs(
# common_attn_metadata,
# sampled_token_ids,
# spec_decode_metadata.num_draft_tokens,
# )
# target_token_ids = self.input_ids.gpu[token_indices]
# target_positions = self._get_positions(token_indices)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[token_indices] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[token_indices]
# else:
# (
# common_attn_metadata,
# token_indices_to_sample,
# num_rejected_tokens_gpu,
# ) = self.drafter.prepare_inputs_padded(
# common_attn_metadata,
# spec_decode_metadata,
# valid_sampled_tokens_count,
# )
# total_num_tokens = common_attn_metadata.num_actual_tokens
# # When padding the batch, token_indices is just a range
# target_token_ids = self.input_ids.gpu[:total_num_tokens]
# target_positions = self._get_positions(total_num_tokens)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[:total_num_tokens] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[:total_num_tokens]
# # if self.supports_mm_inputs:
# if self.supports_mm_inputs and self.drafter.supports_mm_inputs:
# mm_embed_inputs = self._gather_mm_embeddings(
# scheduler_output,
# shift_computed_tokens=1,
# )
# else:
# mm_embed_inputs = None
# draft_result = self.drafter.propose(
# target_token_ids=target_token_ids,
# target_positions=target_positions,
# target_hidden_states=target_hidden_states,
# next_token_ids=next_token_ids,
# token_indices_to_sample=token_indices_to_sample,
# sampling_metadata=sampling_metadata,
# common_attn_metadata=common_attn_metadata,
# mm_embed_inputs=mm_embed_inputs,
# num_rejected_tokens_gpu=num_rejected_tokens_gpu,
# slot_mappings=slot_mappings,
# # multi_layer_eagle_metadata=multi_layer_eagle_metadata,
# )
# if not envs.VLLM_REJECT_SAMPLE_OPT:
# draft_token_ids = draft_result
# else:
# draft_token_ids, draft_probs = draft_result
# if envs.VLLM_REJECT_SAMPLE_OPT:
# draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
# if self.draft_probs is None:
# self.draft_probs = DraftProbs(
# draft_probs, draft_req_ids)
# else:
# self.draft_probs.update(draft_probs, draft_req_ids)
# return draft_token_ids
def propose_draft_token_ids( def propose_draft_token_ids(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
...@@ -3968,6 +4273,7 @@ class GPUModelRunner( ...@@ -3968,6 +4273,7 @@ class GPUModelRunner(
sample_hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None, aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None, spec_decode_metadata: SpecDecodeMetadata | None,
multi_layer_eagle_metadata: MultiLayerEagleMetadata | None,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]] | torch.Tensor: ) -> list[list[int]] | torch.Tensor:
...@@ -3975,6 +4281,8 @@ class GPUModelRunner( ...@@ -3975,6 +4281,8 @@ class GPUModelRunner(
spec_config = self.speculative_config spec_config = self.speculative_config
assert spec_config is not None assert spec_config is not None
if spec_config.method == "ngram": if spec_config.method == "ngram":
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer) assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
...@@ -3983,6 +4291,15 @@ class GPUModelRunner( ...@@ -3983,6 +4291,15 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu, self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
) )
if isinstance(self.drafter, NgramProposer):
assert isinstance(sampled_token_ids, list), (
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids = self.drafter.propose(
sampled_token_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
)
elif spec_config.method == "suffix": elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer) assert isinstance(self.drafter, SuffixDecodingProposer)
...@@ -4015,6 +4332,48 @@ class GPUModelRunner( ...@@ -4015,6 +4332,48 @@ class GPUModelRunner(
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
) )
elif spec_config.uses_extract_hidden_states():
assert isinstance(self.drafter, ExtractHiddenStatesProposer)
assert isinstance(sampled_token_ids, torch.Tensor), (
"sampled_token_ids should be a torch.Tensor for "
"extract_hidden_states method."
)
if not self.use_aux_hidden_state_outputs or aux_hidden_states is None:
raise ValueError(
"aux_hidden_states are required when using `extract_hidden_states`"
)
target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
sampled_token_ids=sampled_token_ids,
target_hidden_states=target_hidden_states,
common_attn_metadata=common_attn_metadata,
scheduler_output=scheduler_output,
slot_mappings=slot_mappings,
)
# Combine KVConnectorOutputs or select the non-empty one
if self.kv_connector_output and drafter_kv_connector_output:
self.kv_connector_output = KVConnectorOutput.merge(
self.kv_connector_output, drafter_kv_connector_output
)
else:
self.kv_connector_output = (
self.kv_connector_output or drafter_kv_connector_output
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_mask.gpu,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
elif spec_config.use_eagle() or spec_config.uses_draft_model(): elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer | DraftModelProposer) assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
...@@ -4106,7 +4465,7 @@ class GPUModelRunner( ...@@ -4106,7 +4465,7 @@ class GPUModelRunner(
else: else:
target_hidden_states = hidden_states[:total_num_tokens] target_hidden_states = hidden_states[:total_num_tokens]
if self.supports_mm_inputs: if self.supports_mm_inputs and self.drafter.supports_mm_inputs:
mm_embed_inputs = self._gather_mm_embeddings( mm_embed_inputs = self._gather_mm_embeddings(
scheduler_output, scheduler_output,
shift_computed_tokens=1, shift_computed_tokens=1,
...@@ -4119,28 +4478,16 @@ class GPUModelRunner( ...@@ -4119,28 +4478,16 @@ class GPUModelRunner(
target_positions=target_positions, target_positions=target_positions,
target_hidden_states=target_hidden_states, target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids, next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample, token_indices_to_sample=token_indices_to_sample,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs, mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu, num_rejected_tokens_gpu=num_rejected_tokens_gpu,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
multi_layer_eagle_metadata=multi_layer_eagle_metadata,
) )
if not envs.VLLM_REJECT_SAMPLE_OPT: return draft_result
draft_token_ids = draft_result
else:
draft_token_ids, draft_probs = draft_result
if envs.VLLM_REJECT_SAMPLE_OPT:
draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
if self.draft_probs is None:
self.draft_probs = DraftProbs(
draft_probs, draft_req_ids)
else:
self.draft_probs.update(draft_probs, draft_req_ids)
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None: def update_config(self, overrides: dict[str, Any]) -> None:
allowed_config_names = {"load_config", "model_config"} allowed_config_names = {"load_config", "model_config"}
...@@ -4963,7 +5310,7 @@ class GPUModelRunner( ...@@ -4963,7 +5310,7 @@ class GPUModelRunner(
# draft_probs = torch.randn( # draft_probs = torch.randn(
# num_tokens, logits.shape[-1], device=self.device, # num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype) # dtype=logits.dtype)
if not envs.VLLM_REJECT_SAMPLE_OPT: if not envs.VLLM_REJECT_SAMPLE_OPT:
draft_probs = None draft_probs = None
else: else:
...@@ -5709,6 +6056,8 @@ class GPUModelRunner( ...@@ -5709,6 +6056,8 @@ class GPUModelRunner(
logitsprocs=self.input_batch.logitsprocs, logitsprocs=self.input_batch.logitsprocs,
logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids, logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
multi_layer_eagle_num=self.multi_layer_eagle_num if self.enable_multi_layer_eagle else 0,
hidden_size=self.model_config.get_hidden_size(),
) )
def _allocate_kv_cache_tensors( def _allocate_kv_cache_tensors(
...@@ -5869,10 +6218,10 @@ class GPUModelRunner( ...@@ -5869,10 +6218,10 @@ class GPUModelRunner(
value_stride_order.index(i) value_stride_order.index(i)
for i in range(len(value_stride_order)) for i in range(len(value_stride_order))
] ]
raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype) raw_tensor = kv_cache_raw_tensors[layer_name].view(dtype)
total_elements = raw_tensor.numel() total_elements = raw_tensor.numel()
key_elements = (key_cache_shape[0] * key_cache_shape[1] * key_elements = (key_cache_shape[0] * key_cache_shape[1] *
key_cache_shape[2] * key_cache_shape[3]) key_cache_shape[2] * key_cache_shape[3])
value_elements = (value_cache_shape[0] * value_cache_shape[1] * value_elements = (value_cache_shape[0] * value_cache_shape[1] *
value_cache_shape[2] * value_cache_shape[3]) value_cache_shape[2] * value_cache_shape[3])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment