Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcc9c9ea
"vllm/entrypoints/api_server.py" did not exist on "4298374265a4379a2bd378373c7252b7a7b2b34f"
Commit
fcc9c9ea
authored
Mar 25, 2026
by
luopl
Browse files
feat:新增step3.5-mtp3功能
parent
9dc40d38
Changes
20
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4484 additions
and
1645 deletions
+4484
-1645
vllm/config/speculative.py
vllm/config/speculative.py
+55
-1
vllm/entrypoints/anthropic/protocol.py
vllm/entrypoints/anthropic/protocol.py
+29
-0
vllm/entrypoints/anthropic/serving.py
vllm/entrypoints/anthropic/serving.py
+366
-69
vllm/entrypoints/openai/chat_completion/serving.py
vllm/entrypoints/openai/chat_completion/serving.py
+32
-10
vllm/forward_context.py
vllm/forward_context.py
+9
-1
vllm/lora/utils.py
vllm/lora/utils.py
+17
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-0
vllm/model_executor/models/step3p5_mtp.py
vllm/model_executor/models/step3p5_mtp.py
+17
-7
vllm/tool_parsers/abstract_tool_parser.py
vllm/tool_parsers/abstract_tool_parser.py
+5
-0
vllm/tool_parsers/step3p5_tool_parser.py
vllm/tool_parsers/step3p5_tool_parser.py
+997
-1341
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+67
-18
vllm/v1/cudagraph_dispatcher.py
vllm/v1/cudagraph_dispatcher.py
+62
-5
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+512
-150
vllm/v1/spec_decode/extract_hidden_states.py
vllm/v1/spec_decode/extract_hidden_states.py
+395
-0
vllm/v1/spec_decode/metadata.py
vllm/v1/spec_decode/metadata.py
+38
-0
vllm/v1/spec_decode/multi_layer_eagle.py
vllm/v1/spec_decode/multi_layer_eagle.py
+526
-0
vllm/v1/spec_decode/ngram_proposer_gpu.py
vllm/v1/spec_decode/ngram_proposer_gpu.py
+660
-0
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+221
-1
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+85
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+389
-40
No files found.
vllm/config/speculative.py
View file @
fcc9c9ea
...
...
@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, model_validator
from
pydantic.dataclasses
import
dataclass
from
typing_extensions
import
Self
from
vllm.config
import
LoadConfig
from
vllm.config.model
import
ModelConfig
from
vllm.config.parallel
import
ParallelConfig
from
vllm.config.utils
import
config
...
...
@@ -76,6 +77,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
enable_multi_layers_mtp
:
bool
=
False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size
:
int
|
None
=
Field
(
default
=
None
,
ge
=
1
)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
...
...
@@ -110,6 +115,11 @@ class SpeculativeConfig:
which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation."""
use_local_argmax_reduction
:
bool
=
False
"""Use vocab-parallel local argmax instead of all-gathering full logits
for draft token generation. Reduces communication from O(vocab_size) to
O(2 * tp_size) per token. Only applies to greedy draft selection in
non-tree speculation."""
# Ngram proposer configuration
prompt_lookup_max
:
int
|
None
=
Field
(
default
=
None
,
ge
=
1
)
"""Maximum size of ngram token window when using Ngram proposer, required
...
...
@@ -121,6 +131,12 @@ class SpeculativeConfig:
speculative_token_tree
:
str
|
None
=
None
"""Specifies the tree structure for speculative token generation.
"""
parallel_drafting
:
bool
=
False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
# required configuration params passed from engine
target_model_config
:
SkipValidation
[
ModelConfig
]
=
None
# type: ignore
"""The configuration of the target model."""
...
...
@@ -154,6 +170,10 @@ class SpeculativeConfig:
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""
draft_load_config
:
LoadConfig
|
None
=
None
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
def
compute_hash
(
self
)
->
str
:
"""
WARNING: Whenever a new field is added to this config,
...
...
@@ -401,7 +421,11 @@ class SpeculativeConfig:
MTPModelTypes
):
self
.
method
=
"mtp"
if
self
.
num_speculative_tokens
>
1
:
# if self.num_speculative_tokens > 1:
if
(
self
.
enable_multi_layers_mtp
is
False
and
self
.
num_speculative_tokens
>
1
):
logger
.
warning
(
"Enabling num_speculative_tokens > 1 will run"
"multiple times of forward on same MTP layer"
...
...
@@ -472,6 +496,17 @@ class SpeculativeConfig:
if
self
.
num_speculative_tokens
is
None
:
# Default to max value defined in draft model config.
self
.
num_speculative_tokens
=
n_predict
elif
(
self
.
method
==
"mtp"
and
self
.
enable_multi_layers_mtp
and
self
.
num_speculative_tokens
>
n_predict
):
logger
.
warning_once
(
"For multi_layer_eagle, num_speculative_tokens "
"is greater than the layer_num, adjusting to "
"layer_num"
)
self
.
num_speculative_tokens
=
n_predict
elif
(
self
.
num_speculative_tokens
>
n_predict
and
self
.
num_speculative_tokens
%
n_predict
!=
0
...
...
@@ -713,12 +748,31 @@ class SpeculativeConfig:
f
"errors during speculative decoding."
)
@
property
def
max_num_new_slots_for_drafting
(
self
)
->
int
:
"""
Calculate the maximum number of new slots that might be added to the batch
when drafting.
"""
slots_per_req
=
0
# for serial non-draft-model methods, no change needed
if
self
.
parallel_drafting
:
# For parallel drafting, we need one new slot per 'masked' token
slots_per_req
=
self
.
num_speculative_tokens
-
1
if
self
.
uses_draft_model
():
# For draft model-based speculation, we need one new slot per request
# Since we do not slice the draft tokens
slots_per_req
+=
1
return
slots_per_req
def
use_eagle
(
self
)
->
bool
:
return
self
.
method
in
(
"eagle"
,
"eagle3"
,
"mtp"
)
def
uses_draft_model
(
self
)
->
bool
:
return
self
.
method
==
"draft_model"
def
uses_extract_hidden_states
(
self
)
->
bool
:
return
self
.
method
==
"extract_hidden_states"
def
__repr__
(
self
)
->
str
:
method
=
self
.
method
model
=
None
if
method
in
(
"ngram"
,
"suffix"
)
else
self
.
draft_model_config
.
model
...
...
vllm/entrypoints/anthropic/protocol.py
View file @
fcc9c9ea
...
...
@@ -160,3 +160,32 @@ class AnthropicMessagesResponse(BaseModel):
def
model_post_init
(
self
,
__context
):
if
not
self
.
id
:
self
.
id
=
f
"msg_
{
int
(
time
.
time
()
*
1000
)
}
"
class
AnthropicContextManagement
(
BaseModel
):
"""Context management information for token counting."""
original_input_tokens
:
int
class
AnthropicCountTokensRequest
(
BaseModel
):
"""Anthropic messages.count_tokens request"""
model
:
str
messages
:
list
[
AnthropicMessage
]
system
:
str
|
list
[
AnthropicContentBlock
]
|
None
=
None
tool_choice
:
AnthropicToolChoice
|
None
=
None
tools
:
list
[
AnthropicTool
]
|
None
=
None
@
field_validator
(
"model"
)
@
classmethod
def
validate_model
(
cls
,
v
):
if
not
v
:
raise
ValueError
(
"Model is required"
)
return
v
class
AnthropicCountTokensResponse
(
BaseModel
):
"""Anthropic messages.count_tokens response"""
input_tokens
:
int
context_management
:
AnthropicContextManagement
|
None
=
None
\ No newline at end of file
vllm/entrypoints/anthropic/serving.py
View file @
fcc9c9ea
...
...
@@ -15,6 +15,9 @@ from fastapi import Request
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.anthropic.protocol
import
(
AnthropicContextManagement
,
AnthropicCountTokensRequest
,
AnthropicCountTokensResponse
,
AnthropicContentBlock
,
AnthropicDelta
,
AnthropicError
,
...
...
@@ -112,6 +115,7 @@ class AnthropicServingMessages(OpenAIServingChat):
# Handle complex content blocks
content_parts
:
list
[
dict
[
str
,
Any
]]
=
[]
tool_calls
:
list
[
dict
[
str
,
Any
]]
=
[]
reasoning_parts
:
list
[
str
]
=
[]
for
block
in
msg
.
content
:
if
block
.
type
==
"text"
and
block
.
text
:
...
...
@@ -123,6 +127,8 @@ class AnthropicServingMessages(OpenAIServingChat):
"image_url"
:
{
"url"
:
block
.
source
.
get
(
"data"
,
""
)},
}
)
elif
block
.
type
==
"thinking"
and
block
.
thinking
is
not
None
:
reasoning_parts
.
append
(
block
.
thinking
)
elif
block
.
type
==
"tool_use"
:
# Convert tool use to function call format
tool_call
=
{
...
...
@@ -157,6 +163,9 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
if
reasoning_parts
:
openai_msg
[
"reasoning"
]
=
""
.
join
(
reasoning_parts
)
# Add tool calls to the message if any
if
tool_calls
:
openai_msg
[
"tool_calls"
]
=
tool_calls
# type: ignore
...
...
@@ -297,10 +306,116 @@ class AnthropicServingMessages(OpenAIServingChat):
generator
:
AsyncGenerator
[
str
,
None
],
)
->
AsyncGenerator
[
str
,
None
]:
try
:
class
_ActiveBlockState
:
def
__init__
(
self
)
->
None
:
self
.
content_block_index
=
0
self
.
block_type
:
str
|
None
=
None
self
.
block_index
:
int
|
None
=
None
self
.
block_signature
:
str
|
None
=
None
self
.
signature_emitted
:
bool
=
False
self
.
tool_use_id
:
str
|
None
=
None
def
reset
(
self
)
->
None
:
self
.
block_type
=
None
self
.
block_index
=
None
self
.
block_signature
=
None
self
.
signature_emitted
=
False
self
.
tool_use_id
=
None
def
start
(
self
,
block
:
AnthropicContentBlock
)
->
None
:
self
.
block_type
=
block
.
type
self
.
block_index
=
self
.
content_block_index
if
block
.
type
==
"thinking"
:
self
.
block_signature
=
uuid
.
uuid4
().
hex
self
.
signature_emitted
=
False
self
.
tool_use_id
=
None
elif
block
.
type
==
"tool_use"
:
self
.
block_signature
=
None
self
.
signature_emitted
=
True
self
.
tool_use_id
=
block
.
id
else
:
self
.
block_signature
=
None
self
.
signature_emitted
=
True
self
.
tool_use_id
=
None
first_item
=
True
finish_reason
=
None
# content_block_index = 0
# content_block_started = False
content_block_index
=
0
content_block_started
=
False
active_block_type
:
str
|
None
=
None
active_block_index
:
int
|
None
=
None
active_block_signature
:
str
|
None
=
None
signature_emitted
=
False
active_tool_use_id
:
str
|
None
=
None
# Map from tool call index to tool_use_id
tool_index_to_id
:
dict
[
int
,
str
]
=
{}
def
stop_active_block
():
nonlocal
active_block_type
,
active_block_index
,
content_block_index
nonlocal
active_block_signature
,
signature_emitted
,
active_tool_use_id
events
:
list
[
str
]
=
[]
if
active_block_type
is
None
:
return
events
if
(
active_block_type
==
"thinking"
and
active_block_signature
is
not
None
and
not
signature_emitted
):
chunk
=
AnthropicStreamEvent
(
index
=
active_block_index
,
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
type
=
"signature_delta"
,
signature
=
active_block_signature
,
),
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
events
.
append
(
wrap_data_with_event
(
data
,
"content_block_delta"
))
signature_emitted
=
True
stop_chunk
=
AnthropicStreamEvent
(
index
=
active_block_index
,
type
=
"content_block_stop"
,
)
data
=
stop_chunk
.
model_dump_json
(
exclude_unset
=
True
)
events
.
append
(
wrap_data_with_event
(
data
,
"content_block_stop"
))
active_block_type
=
None
active_block_index
=
None
active_block_signature
=
None
signature_emitted
=
False
active_tool_use_id
=
None
content_block_index
+=
1
return
events
def
start_block
(
block
:
AnthropicContentBlock
):
nonlocal
active_block_type
,
active_block_index
,
content_block_index
nonlocal
active_block_signature
,
signature_emitted
,
active_tool_use_id
chunk
=
AnthropicStreamEvent
(
index
=
content_block_index
,
type
=
"content_block_start"
,
content_block
=
block
,
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
event
=
wrap_data_with_event
(
data
,
"content_block_start"
)
active_block_type
=
block
.
type
active_block_index
=
content_block_index
if
block
.
type
==
"thinking"
:
active_block_signature
=
uuid
.
uuid4
().
hex
signature_emitted
=
False
active_tool_use_id
=
None
elif
block
.
type
==
"tool_use"
:
active_block_signature
=
None
signature_emitted
=
True
active_tool_use_id
=
block
.
id
else
:
active_block_signature
=
None
signature_emitted
=
True
active_tool_use_id
=
None
return
event
async
for
item
in
generator
:
if
item
.
startswith
(
"data:"
):
...
...
@@ -326,6 +441,8 @@ class AnthropicServingMessages(OpenAIServingChat):
id
=
origin_chunk
.
id
,
content
=
[],
model
=
origin_chunk
.
model
,
stop_reason
=
None
,
stop_sequence
=
None
,
usage
=
AnthropicUsage
(
input_tokens
=
origin_chunk
.
usage
.
prompt_tokens
if
origin_chunk
.
usage
...
...
@@ -341,13 +458,33 @@ class AnthropicServingMessages(OpenAIServingChat):
# last chunk including usage info
if
len
(
origin_chunk
.
choices
)
==
0
:
if
content_block_started
:
stop_chunk
=
AnthropicStreamEvent
(
index
=
content_block_index
,
type
=
"content_block_stop"
,
)
data
=
stop_chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
wrap_data_with_event
(
data
,
"content_block_stop"
)
# if content_block_started:
# stop_chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_stop",
# )
# data = stop_chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_stop")
# stop_reason = self.stop_reason_map.get(
# finish_reason or "stop"
# )
# chunk = AnthropicStreamEvent(
# type="message_delta",
# delta=AnthropicDelta(stop_reason=stop_reason),
# usage=AnthropicUsage(
# input_tokens=origin_chunk.usage.prompt_tokens
# if origin_chunk.usage
# else 0,
# output_tokens=origin_chunk.usage.completion_tokens
# if origin_chunk.usage
# else 0,
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "message_delta")
# continue
for
event
in
stop_active_block
():
yield
event
stop_reason
=
self
.
stop_reason_map
.
get
(
finish_reason
or
"stop"
)
...
...
@@ -366,29 +503,134 @@ class AnthropicServingMessages(OpenAIServingChat):
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
wrap_data_with_event
(
data
,
"message_delta"
)
continue
# =========================================================
if
origin_chunk
.
choices
[
0
].
finish_reason
is
not
None
:
finish_reason
=
origin_chunk
.
choices
[
0
].
finish_reason
continue
#
continue
# content
if
origin_chunk
.
choices
[
0
].
delta
.
content
is
not
None
:
if
not
content_block_started
:
# if origin_chunk.choices[0].delta.content is not None:
# if not content_block_started:
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_start",
# content_block=AnthropicContentBlock(
# type="text", text=""
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_start")
# content_block_started = True
# if origin_chunk.choices[0].delta.content == "":
# continue
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_delta",
# delta=AnthropicDelta(
# type="text_delta",
# text=origin_chunk.choices[0].delta.content,
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_delta")
# continue
# tool calls
# elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
# elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
# tool_call = origin_chunk.choices[0].delta.tool_calls[0]
# if tool_call.id is not None:
# if content_block_started:
# stop_chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_stop",
# )
# data = stop_chunk.model_dump_json(
# exclude_unset=True
# )
# yield wrap_data_with_event(
# data, "content_block_stop"
# )
# content_block_started = False
# content_block_index += 1
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_start",
# content_block=AnthropicContentBlock(
# type="tool_use",
# id=tool_call.id,
# name=tool_call.function.name
# if tool_call.function
# else None,
# input={},
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_start")
# content_block_started = True
# else:
# chunk = AnthropicStreamEvent(
# index=content_block_index,
# type="content_block_delta",
# delta=AnthropicDelta(
# type="input_json_delta",
# partial_json=tool_call.function.arguments
# if tool_call.function
# else None,
# ),
# )
# data = chunk.model_dump_json(exclude_unset=True)
# yield wrap_data_with_event(data, "content_block_delta")
# continue
# thinking / text content
reasoning_delta
=
origin_chunk
.
choices
[
0
].
delta
.
reasoning
if
reasoning_delta
is
not
None
:
if
reasoning_delta
==
""
:
pass
else
:
if
active_block_type
!=
"thinking"
:
for
event
in
stop_active_block
():
yield
event
start_event
=
start_block
(
AnthropicContentBlock
(
type
=
"thinking"
,
thinking
=
""
)
)
yield
start_event
chunk
=
AnthropicStreamEvent
(
index
=
content_block_index
,
type
=
"content_block_start"
,
content_block
=
AnthropicContentBlock
(
type
=
"text"
,
text
=
""
index
=
(
active_block_index
if
active_block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
type
=
"thinking_delta"
,
thinking
=
reasoning_delta
,
),
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
wrap_data_with_event
(
data
,
"content_block_start"
)
content_block_started
=
True
yield
wrap_data_with_event
(
data
,
"content_block_delta"
)
if
origin_chunk
.
choices
[
0
].
delta
.
content
is
not
None
:
if
origin_chunk
.
choices
[
0
].
delta
.
content
==
""
:
continue
pass
else
:
if
active_block_type
!=
"text"
:
for
event
in
stop_active_block
():
yield
event
start_event
=
start_block
(
AnthropicContentBlock
(
type
=
"text"
,
text
=
""
)
)
yield
start_event
chunk
=
AnthropicStreamEvent
(
index
=
content_block_index
,
index
=
(
active_block_index
if
active_block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
type
=
"text_delta"
,
...
...
@@ -397,55 +639,82 @@ class AnthropicServingMessages(OpenAIServingChat):
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
wrap_data_with_event
(
data
,
"content_block_delta"
)
continue
# tool calls
el
if
len
(
origin_chunk
.
choices
[
0
].
delta
.
tool_calls
)
>
0
:
tool_call
=
origin_chunk
.
choices
[
0
].
delta
.
tool_calls
[
0
]
# tool calls
- process all tool calls in the delta
if
len
(
origin_chunk
.
choices
[
0
].
delta
.
tool_calls
)
>
0
:
for
tool_call
in
origin_chunk
.
choices
[
0
].
delta
.
tool_calls
:
if
tool_call
.
id
is
not
None
:
if
content_block_started
:
stop_chunk
=
AnthropicStreamEvent
(
index
=
content_block_index
,
type
=
"content_block_stop"
,
)
data
=
stop_chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
wrap_data_with_event
(
data
,
"content_block_stop"
# Update mapping for incremental updates
tool_index_to_id
[
tool_call
.
index
]
=
tool_call
.
id
# Only create new block if different tool call
# AND has a name
tool_name
=
(
tool_call
.
function
.
name
if
tool_call
.
function
else
None
)
content_block_started
=
False
content_block_index
+=
1
chunk
=
AnthropicStreamEvent
(
index
=
content_block_index
,
type
=
"content_block_start"
,
content_block
=
AnthropicContentBlock
(
if
(
active_tool_use_id
!=
tool_call
.
id
and
tool_name
is
not
None
):
for
event
in
stop_active_block
():
yield
event
start_event
=
start_block
(
AnthropicContentBlock
(
type
=
"tool_use"
,
id
=
tool_call
.
id
,
name
=
tool_call
.
function
.
name
if
tool_call
.
function
else
None
,
name
=
tool_name
,
input
=
{},
)
)
yield
start_event
# Handle initial arguments if present
if
(
tool_call
.
function
and
tool_call
.
function
.
arguments
and
active_tool_use_id
==
tool_call
.
id
):
chunk
=
AnthropicStreamEvent
(
index
=
(
active_block_index
if
active_block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
type
=
"input_json_delta"
,
partial_json
=
tool_call
.
function
.
arguments
,
),
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
wrap_data_with_event
(
data
,
"content_block_start"
)
content_block_started
=
True
yield
wrap_data_with_event
(
data
,
"content_block_delta"
)
else
:
# Incremental update - use index to find tool_use_id
tool_use_id
=
tool_index_to_id
.
get
(
tool_call
.
index
)
if
(
tool_use_id
is
not
None
and
tool_call
.
function
and
tool_call
.
function
.
arguments
and
active_tool_use_id
==
tool_use_id
):
chunk
=
AnthropicStreamEvent
(
index
=
content_block_index
,
index
=
(
active_block_index
if
active_block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
type
=
"input_json_delta"
,
partial_json
=
tool_call
.
function
.
arguments
if
tool_call
.
function
else
None
,
partial_json
=
tool_call
.
function
.
arguments
,
),
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
wrap_data_with_event
(
data
,
"content_block_delta"
)
yield
wrap_data_with_event
(
data
,
"content_block_delta"
)
continue
else
:
error_response
=
AnthropicStreamEvent
(
...
...
@@ -468,3 +737,31 @@ class AnthropicServingMessages(OpenAIServingChat):
data
=
error_response
.
model_dump_json
(
exclude_unset
=
True
)
yield
wrap_data_with_event
(
data
,
"error"
)
yield
"data: [DONE]
\n\n
"
async
def
count_tokens
(
self
,
request
:
AnthropicCountTokensRequest
,
raw_request
:
Request
|
None
=
None
,
)
->
AnthropicCountTokensResponse
|
ErrorResponse
:
"""Implements Anthropic's messages.count_tokens endpoint."""
chat_req
=
self
.
_convert_anthropic_to_openai_request
(
request
)
result
=
await
self
.
render_chat_request
(
chat_req
)
if
isinstance
(
result
,
ErrorResponse
):
return
result
_
,
engine_prompts
=
result
input_tokens
=
sum
(
# type: ignore
len
(
prompt
[
"prompt_token_ids"
])
# type: ignore[typeddict-item, misc]
for
prompt
in
engine_prompts
if
"prompt_token_ids"
in
prompt
)
response
=
AnthropicCountTokensResponse
(
input_tokens
=
input_tokens
,
context_management
=
AnthropicContextManagement
(
original_input_tokens
=
input_tokens
),
)
return
response
\ No newline at end of file
vllm/entrypoints/openai/chat_completion/serving.py
View file @
fcc9c9ea
...
...
@@ -1239,10 +1239,13 @@ class OpenAIServingChat(OpenAIServing):
index
=
0
if
(
self
.
_should_check_for_unstreamed_tool_arg_tokens
(
delta_message
,
output
# self._should_check_for_unstreamed_tool_arg_tokens(
# delta_message, output
tool_parser
and
self
.
_should_check_for_unstreamed_tool_arg_tokens
(
delta_message
,
output
,
tool_parser
)
and
tool_parser
#
and tool_parser
):
latest_delta_len
=
0
if
(
...
...
@@ -1256,15 +1259,31 @@ class OpenAIServingChat(OpenAIServing):
latest_delta_len
=
len
(
delta_message
.
tool_calls
[
0
].
function
.
arguments
)
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call
=
json
.
dumps
(
tool_parser
.
prev_tool_call_arr
[
index
].
get
(
# parsing which "autocompletes" the JSON.
# Tool parsers (e.g. Qwen3Coder) store
# arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# expected_call = json.dumps(
# tool_parser.prev_tool_call_arr[index].get(
# "arguments", {}
# ),
# ensure_ascii=False,
# )
args
=
tool_parser
.
prev_tool_call_arr
[
index
].
get
(
"arguments"
,
{}
),
ensure_ascii
=
False
,
)
if
isinstance
(
args
,
str
):
expected_call
=
args
else
:
expected_call
=
json
.
dumps
(
args
,
ensure_ascii
=
False
)
# get what we've streamed so far for arguments
# for the current tool
...
...
@@ -1848,6 +1867,7 @@ class OpenAIServingChat(OpenAIServing):
self
,
delta_message
:
DeltaMessage
|
None
,
output
:
CompletionOutput
,
tool_parser
:
ToolParser
|
None
=
None
,
)
->
bool
:
"""
Check to see if we should check for unstreamed tool arguments tokens.
...
...
@@ -1866,6 +1886,8 @@ class OpenAIServingChat(OpenAIServing):
and
delta_message
.
tool_calls
[
0
]
and
delta_message
.
tool_calls
[
0
].
function
and
delta_message
.
tool_calls
[
0
].
function
.
arguments
is
not
None
and
tool_parser
is
not
None
and
tool_parser
.
parser_should_check_for_unstreamed_tool_arg_tokens
()
)
@
staticmethod
...
...
vllm/forward_context.py
View file @
fcc9c9ea
...
...
@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple):
"""
Whether this batch has active LoRA adapters.
"""
num_active_loras
:
int
=
0
"""
Number of distinct active LoRA adapters in this batch.
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
are captured for each num_active_loras value. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""
def
relax_for_mixed_batch_cudagraphs
(
self
)
->
"BatchDescriptor"
:
"""
...
...
vllm/lora/utils.py
View file @
fcc9c9ea
...
...
@@ -44,6 +44,23 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
def
get_captured_lora_counts
(
max_loras
:
int
,
specialize
:
bool
)
->
list
[
int
]:
"""
Returns num_active_loras values for cudagraph capture.
When specialize=True: powers of 2 up to max_loras, plus max_loras + 1.
When specialize=False: just [max_loras + 1].
This is the single source of truth for LoRA capture cases, used by both
CudagraphDispatcher and PunicaWrapperGPU.
"""
if
not
specialize
:
return
[
max_loras
+
1
]
return
[
n
for
n
in
range
(
1
,
max_loras
+
2
)
if
(
n
&
(
n
-
1
))
==
0
or
n
==
max_loras
+
1
]
_GLOBAL_LORA_ID
=
0
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
fcc9c9ea
...
...
@@ -1028,6 +1028,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
shared_output
=
None
,
routed_scaling_factor
=
None
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
kernel
is
not
None
assert
not
self
.
is_monolithic
...
...
vllm/model_executor/models/step3p5_mtp.py
View file @
fcc9c9ea
...
...
@@ -52,7 +52,8 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self
.
enorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
config
.
rms_norm_eps
)
self
.
hnorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
config
.
rms_norm_eps
)
self
.
eh_proj
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
config
.
hidden_size
,
bias
=
False
)
self
.
shared_head
=
SharedHead
(
config
=
config
,
quant_config
=
quant_config
)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
self
.
lm_head
=
SharedHead
(
config
=
config
,
quant_config
=
quant_config
)
self
.
mtp_block
=
Step3p5DecoderLayer
(
vllm_config
,
prefix
=
f
"
{
prefix
}
.mtp_block"
,
...
...
@@ -64,9 +65,13 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
embed_tokens
:
VocabParallelEmbedding
|
None
=
None
,
spec_step_index
:
int
=
0
,
)
->
torch
.
Tensor
:
assert
inputs_embeds
is
not
None
if
inputs_embeds
is
None
:
assert
embed_tokens
is
not
None
inputs_embeds
=
embed_tokens
(
input_ids
)
# assert inputs_embeds is not None
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
...
...
@@ -92,8 +97,10 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self
.
layers
=
torch
.
nn
.
ModuleDict
(
{
str
(
idx
):
Step3p5AMultiTokenPredictorLayer
(
vllm_config
,
f
"
{
prefix
}
.layers.
{
idx
}
"
,
# vllm_config,
# f"{prefix}.layers.{idx}",
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
idx
}
"
,
)
for
idx
in
range
(
self
.
mtp_start_layer_idx
,
...
...
@@ -112,14 +119,15 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
#
if inputs_embeds is None:
#
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
positions
,
previous_hidden_states
,
inputs_embeds
,
self
.
embed_tokens
,
current_step_idx
,
)
...
...
@@ -131,7 +139,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
mtp_layer
=
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)]
logits
=
self
.
logits_processor
(
mtp_layer
.
shared_head
.
head
,
mtp_layer
.
shared_head
(
hidden_states
)
# mtp_layer.shared_head.head, mtp_layer.shared_head(hidden_states)
mtp_layer
.
lm_head
.
head
,
mtp_layer
.
lm_head
(
hidden_states
)
)
return
logits
...
...
@@ -257,6 +266,7 @@ class Step3p5MTP(nn.Module):
name
=
name
.
replace
(
".transformer."
,
"."
)
if
"shared_head"
in
name
:
name
=
name
.
replace
(
"shared_head.output"
,
"shared_head.head"
)
name
=
name
.
replace
(
"shared_head"
,
"lm_head"
)
if
"embed_tokens"
in
name
:
assert
(
hasattr
(
self
.
config
,
"num_nextn_predict_layers"
)
...
...
vllm/tool_parsers/abstract_tool_parser.py
View file @
fcc9c9ea
...
...
@@ -118,6 +118,11 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
def
parser_should_check_for_unstreamed_tool_arg_tokens
(
self
)
->
bool
:
"""
Whether to check for unstreamed tool-argument tokens in serving
"""
return
True
class
ToolParserManager
:
"""
...
...
vllm/tool_parsers/step3p5_tool_parser.py
View file @
fcc9c9ea
...
...
@@ -2,13 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
import
json
import
uuid
from
collections.abc
import
Sequence
from
typing
import
Any
from
xml.parsers.expat
import
ParserCreate
#
from xml.parsers.expat import ParserCreate
import
regex
as
re
from
vllm.entrypoints.chat_utils
import
make_tool_call_id
#
from vllm.entrypoints.chat_utils import make_tool_call_id
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
ChatCompletionToolsParam
,
...
...
@@ -25,1487 +26,1142 @@ from vllm.logger import init_logger
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
ToolParserManager
,
)
logger
=
init_logger
(
__name__
)
class
StreamingXMLToolCallParser
:
"""
Simplified streaming XML tool call parser
Supports streaming input, parsing, and output
"""
class
Step3p5ToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
TokenizerLike
):
super
().
__init__
(
tokenizer
)
def
__init__
(
self
):
self
.
reset_streaming_state
()
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
# Override base class type - we use string IDs for tool calls
self
.
current_tool_id
:
str
|
None
=
None
# type: ignore
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
# Tool configuration information
self
.
tools
:
list
[
ChatCompletionToolsParam
]
|
None
=
None
# Sentinel tokens for streaming mode
self
.
tool_call_start_token
:
str
=
"<tool_call>"
self
.
tool_call_end_token
:
str
=
"</tool_call>"
self
.
function_start_token
:
str
=
"<function="
self
.
tool_call_prefix
:
str
=
"<function="
self
.
function_end_token
:
str
=
"</function>"
self
.
parameter_
start_token
:
str
=
"<parameter="
self
.
parameter_
prefix
:
str
=
"<parameter="
self
.
parameter_end_token
:
str
=
"</parameter>"
self
.
is_tool_call_started
:
bool
=
False
self
.
failed_count
:
int
=
0
def
reset_streaming_state
(
self
):
"""Reset streaming parsing state"""
self
.
deltas
=
[]
# state for streaming
self
.
tool_call_index
=
0
self
.
current_call_id
=
None
self
.
last_completed_call_id
=
None
self
.
current_function_name
=
None
self
.
current_function_open
=
False
self
.
parameters
=
{}
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
current_param_is_first
=
False
self
.
should_emit_end_newline
=
False
self
.
start_quote_emitted
=
False
self
.
streaming_buffer
=
""
self
.
last_processed_pos
=
0
self
.
text_content_buffer
=
""
# state for preprocessing and deferred parsing
self
.
_pre_inside_parameter
=
False
self
.
_pre_param_buffer
=
""
self
.
_pre_current_param_name
=
None
self
.
defer_current_parameter
=
False
self
.
deferred_param_raw_value
=
""
# recreate parser
self
.
parser
=
ParserCreate
()
self
.
setup_parser
()
def
parse_single_streaming_chunks
(
self
,
xml_chunk
:
str
)
->
DeltaMessage
:
"""
Parse single streaming XML chunk and return Delta response
This is the actual streaming interface that receives chunks
one by one and maintains internal state
# Enhanced streaming state - reset for each new message
self
.
_reset_streaming_state
()
Args:
xml_chunk: Single XML chunk string
Returns:
DeltaMessage: Contains delta information generated by this chunk,
returns empty response if no complete elements
# Regex patterns
self
.
tool_call_complete_regex
=
re
.
compile
(
r
"<tool_call>(.*?)</tool_call>"
,
re
.
DOTALL
)
self
.
tool_call_function_regex
=
re
.
compile
(
r
"<function(?:=|\s+)?(.*?)</function>"
,
re
.
DOTALL
)
self
.
tool_call_parameter_regex
=
re
.
compile
(
r
"<parameter=(.*?)</parameter>"
,
re
.
DOTALL
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self
.
tool_call_start_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_start_token
)
self
.
tool_call_end_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_end_token
)
if
self
.
tool_call_start_token_id
is
None
or
self
.
tool_call_end_token_id
is
None
:
raise
RuntimeError
(
"Step3p5 RL Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
# Get EOS token ID for EOS detection
self
.
eos_token_id
=
getattr
(
self
.
model_tokenizer
,
"eos_token_id"
,
None
)
logger
.
info
(
"vLLM Successfully import tool parser %s !"
,
self
.
__class__
.
__name__
)
def
_generate_tool_call_id
(
self
)
->
str
:
"""Generate a unique tool call ID."""
return
f
"call_
{
uuid
.
uuid4
().
hex
[:
24
]
}
"
def
parser_should_check_for_unstreamed_tool_arg_tokens
(
self
)
->
bool
:
"""
# Record delta count before processing
initial_delta_count
=
len
(
self
.
deltas
)
Skip the remaining_call calculation in serving
"""
return
False
self
.
streaming_buffer
+=
xml_chunk
def
_reset_streaming_state
(
self
):
"""Reset all streaming state for a new request."""
self
.
_processed_length
:
int
=
0
# Position of last processed character
self
.
_tool_call_index
:
int
=
0
# Number of tool calls processed so far
self
.
streaming_request
=
None
# Current request being processed
def
_get_arguments_config
(
self
,
func_name
:
str
,
tools
:
list
[
ChatCompletionToolsParam
]
|
None
)
->
dict
:
"""Extract argument configuration for a function."""
if
tools
is
None
:
return
{}
for
config
in
tools
:
if
not
hasattr
(
config
,
"type"
)
or
not
(
hasattr
(
config
,
"function"
)
and
hasattr
(
config
.
function
,
"name"
)
):
continue
if
config
.
type
==
"function"
and
config
.
function
.
name
==
func_name
:
if
not
hasattr
(
config
.
function
,
"parameters"
):
return
{}
params
=
config
.
function
.
parameters
if
isinstance
(
params
,
dict
)
and
"properties"
in
params
:
return
params
[
"properties"
]
elif
isinstance
(
params
,
dict
):
return
params
else
:
return
{}
logger
.
warning
(
"Tool '%s' is not defined in the tools list."
,
func_name
)
return
{}
def
_convert_param_value
(
self
,
param_value
:
str
,
param_name
:
str
,
param_config
:
dict
,
func_name
:
str
)
->
Any
:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if
param_value
.
lower
()
==
"null"
:
return
None
found_elements
=
self
.
_process_complete_xml_elements
()
if
param_name
not
in
param_config
:
if
param_config
!=
{}:
logger
.
warning
(
"Parsed parameter '%s' is not defined in the tool "
"parameters for tool '%s', directly returning the "
"string value."
,
param_name
,
func_name
,
)
return
param_value
if
found_elements
:
# If complete elements found, check if end events were missed
# some tags may not have been triggered
try
:
new_deltas
=
self
.
deltas
[
initial_delta_count
:]
# If this chunk contains </function>
# but didn't generate '}', then complete it
if
(
self
.
current_call_id
is
not
None
and
self
.
function_end_token
in
xml_chunk
isinstance
(
param_config
[
param_name
],
dict
)
and
"type"
in
param_config
[
param_name
]
):
param_type
=
str
(
param_config
[
param_name
][
"type"
]).
strip
().
lower
()
else
:
param_type
=
"string"
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
return
param_value
elif
(
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)
):
# - Added '}' (non-empty parameter ending)
# - Added '{}' (empty parameter function)
has_function_close
=
any
(
(
td
.
tool_calls
and
any
(
(
tc
.
function
and
tc
.
id
==
self
.
current_call_id
and
isinstance
(
tc
.
function
.
arguments
,
str
)
and
(
tc
.
function
.
arguments
in
(
"}"
,
"{}"
))
try
:
return
int
(
param_value
)
except
(
ValueError
,
TypeError
):
try
:
float_value
=
float
(
param_value
)
if
float_value
.
is_integer
():
return
int
(
float_value
)
except
(
ValueError
,
TypeError
):
pass
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
if
isinstance
(
literal_value
,
bool
):
return
int
(
literal_value
)
if
isinstance
(
literal_value
,
(
int
,
float
)):
return
(
int
(
literal_value
)
if
float
(
literal_value
).
is_integer
()
else
literal_value
)
for
tc
in
td
.
tool_calls
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not an integer "
"in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
return
param_value
elif
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
):
try
:
float_param_value
=
float
(
param_value
)
return
(
float_param_value
if
float_param_value
-
int
(
float_param_value
)
!=
0
else
int
(
float_param_value
)
)
for
td
in
new_deltas
except
(
ValueError
,
TypeError
):
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
if
isinstance
(
literal_value
,
(
int
,
float
)):
return
(
float
(
literal_value
)
if
float
(
literal_value
)
-
int
(
float
(
literal_value
))
!=
0
else
int
(
float
(
literal_value
))
)
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
return
param_value
elif
param_type
in
[
"boolean"
,
"bool"
,
"binary"
]:
normalized_value
=
param_value
.
strip
().
lower
()
if
normalized_value
in
[
"true"
,
"false"
]:
return
normalized_value
==
"true"
if
normalized_value
in
[
"1"
,
"0"
]:
return
normalized_value
==
"1"
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
if
isinstance
(
literal_value
,
bool
):
return
literal_value
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a boolean "
"in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
if
not
has_function_close
:
# Close potentially unclosed element
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
if
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
# If this chunk contains </tool_call>
# but didn't generate final empty delta, then complete it
return
param_value
else
:
if
(
self
.
current_call_id
is
not
None
and
self
.
tool_call_end_token
in
xml_chunk
param_type
in
[
"object"
,
"array"
,
"arr"
]
or
param_type
.
startswith
(
"dict"
)
or
param_type
.
startswith
(
"list"
)
):
has_toolcall_close
=
any
(
(
td
.
tool_calls
and
any
(
(
tc
.
type
==
"function"
and
tc
.
function
and
tc
.
function
.
arguments
==
""
and
tc
.
id
==
self
.
current_call_id
)
for
tc
in
td
.
tool_calls
try
:
param_value
=
json
.
loads
(
param_value
)
return
param_value
except
(
json
.
JSONDecodeError
,
TypeError
,
ValueError
):
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
if
isinstance
(
literal_value
,
(
list
,
dict
)):
return
literal_value
if
isinstance
(
literal_value
,
(
tuple
,
set
)):
return
list
(
literal_value
)
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' cannot be parsed "
"as JSON in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
return
param_value
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
# safer
if
isinstance
(
literal_value
,
(
tuple
,
set
)):
return
list
(
literal_value
)
if
(
isinstance
(
literal_value
,
(
list
,
dict
,
str
,
int
,
float
,
bool
))
or
literal_value
is
None
):
return
literal_value
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
for
td
in
new_deltas
return
param_value
def
_parse_parameters_fallback
(
self
,
parameters
:
str
,
allowed_param_names
:
set
[
str
]
|
None
=
None
,
)
->
list
[
tuple
[
str
,
str
]]:
"""Fallback parser for malformed parameter tags."""
param_pairs
:
list
[
tuple
[
str
,
str
]]
=
[]
pos
=
0
while
True
:
start
=
parameters
.
find
(
self
.
parameter_prefix
,
pos
)
if
start
==
-
1
:
break
name_start
=
start
+
len
(
self
.
parameter_prefix
)
name_end
=
parameters
.
find
(
">"
,
name_start
)
if
name_end
==
-
1
:
newline_idx
=
parameters
.
find
(
"
\n
"
,
name_start
)
end_tag
=
parameters
.
find
(
self
.
parameter_end_token
,
name_start
)
next_param
=
parameters
.
find
(
self
.
parameter_prefix
,
name_start
)
candidates
=
[
idx
for
idx
in
[
newline_idx
,
end_tag
,
next_param
]
if
idx
!=
-
1
]
if
not
candidates
:
break
name_end
=
min
(
candidates
)
value_start
=
name_end
else
:
value_start
=
name_end
+
1
param_name
=
parameters
[
name_start
:
name_end
].
strip
()
next_param
=
parameters
.
find
(
self
.
parameter_prefix
,
value_start
)
end_tag
=
parameters
.
find
(
self
.
parameter_end_token
,
value_start
)
if
end_tag
==
-
1
or
(
next_param
!=
-
1
and
next_param
<
end_tag
):
end
=
next_param
if
next_param
!=
-
1
else
len
(
parameters
)
pos
=
end
else
:
end
=
end_tag
pos
=
end
+
len
(
self
.
parameter_end_token
)
param_value
=
parameters
[
value_start
:
end
]
if
allowed_param_names
is
None
or
param_name
in
allowed_param_names
:
param_pairs
.
append
((
param_name
,
param_value
))
return
param_pairs
def
_is_valid_json_arguments
(
self
,
arguments
:
str
)
->
bool
:
"""Check if arguments can be loaded as JSON."""
try
:
json
.
loads
(
arguments
)
except
Exception
:
return
False
return
True
def
_parse_xml_function_call
(
self
,
function_call_str
:
str
,
tools
:
list
[
ChatCompletionToolsParam
]
|
None
)
->
ToolCall
|
None
:
# Extract function name
end_index
=
function_call_str
.
index
(
">"
)
# check empty function name
function_name
=
function_call_str
[:
end_index
].
strip
()
if
function_name
.
startswith
(
"="
):
function_name
=
function_name
.
lstrip
(
"="
).
strip
()
if
not
function_name
or
function_name
.
strip
(
"'
\"
"
)
==
""
:
logger
.
warning
(
"Empty function name in tool call."
)
return
None
if
function_name
[
0
]
in
"
\"
'"
and
function_name
[
-
1
]
==
function_name
[
0
]:
function_name
=
function_name
[
1
:
-
1
].
strip
()
if
not
function_name
:
logger
.
warning
(
"Empty function name in tool call."
)
return
None
param_config
=
self
.
_get_arguments_config
(
function_name
,
tools
)
parameters
=
function_call_str
[
end_index
+
1
:]
param_dict
=
{}
match_texts
=
self
.
tool_call_parameter_regex
.
findall
(
parameters
)
use_fallback
=
False
if
match_texts
:
for
match_text
in
match_texts
:
if
self
.
parameter_prefix
in
match_text
or
">"
not
in
match_text
:
use_fallback
=
True
break
else
:
use_fallback
=
self
.
parameter_prefix
in
parameters
if
use_fallback
:
allowed_param_names
=
(
set
(
param_config
.
keys
())
if
isinstance
(
param_config
,
dict
)
and
param_config
else
None
)
if
not
has_toolcall_close
:
# Close potentially unclosed element
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
if
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
self
.
_end_element
(
"tool_call"
)
except
Exception
as
e
:
logger
.
warning
(
"Error with fallback parsing: %s"
,
e
)
# Merge newly generated deltas into single response
result_delta
=
self
.
_merge_new_deltas_to_single_response
(
initial_delta_count
param_pairs
=
self
.
_parse_parameters_fallback
(
parameters
,
allowed_param_names
)
return
result_delta
else
:
# No complete elements, check if there's unoutput text content
if
self
.
text_content_buffer
and
self
.
tool_call_index
==
0
:
# Has text content but no tool_call yet, output text content
text_delta
=
DeltaMessage
(
content
=
self
.
text_content_buffer
)
self
.
_emit_delta
(
text_delta
)
# Clear buffer to avoid duplicate output
self
.
text_content_buffer
=
""
return
text_delta
# If this chunk contains end tags but wasn't triggered by parser,
# manually complete end events
# Only execute when still on the same call as when entered,
# to prevent accidentally closing new calls
# in multi <tool_call> scenarios
if
self
.
current_call_id
is
not
None
and
(
self
.
function_end_token
in
xml_chunk
or
self
.
tool_call_end_token
in
xml_chunk
):
# Close potentially unclosed element
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
if
self
.
function_end_token
in
xml_chunk
and
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
if
self
.
tool_call_end_token
in
xml_chunk
:
self
.
_end_element
(
"tool_call"
)
# Return the merged delta result generated by this fallback
result_delta
=
self
.
_merge_new_deltas_to_single_response
(
initial_delta_count
param_pairs
=
[]
for
match_text
in
match_texts
:
idx
=
match_text
.
index
(
">"
)
param_name
=
match_text
[:
idx
]
param_value
=
str
(
match_text
[
idx
+
1
:])
param_pairs
.
append
((
param_name
,
param_value
))
for
param_name
,
param_value
in
param_pairs
:
# Remove prefix and trailing \n
if
param_value
.
startswith
(
"
\n
"
):
param_value
=
param_value
[
1
:]
if
param_value
.
endswith
(
"
\n
"
):
param_value
=
param_value
[:
-
1
]
param_dict
[
param_name
]
=
self
.
_convert_param_value
(
param_value
,
param_name
,
param_config
,
function_name
)
return
result_delta
# No complete elements, return empty response
return
DeltaMessage
(
content
=
None
)
try
:
arguments
=
json
.
dumps
(
param_dict
,
ensure_ascii
=
False
)
except
Exception
as
e
:
logger
.
warning
(
"Error in converting parameter value: %s"
,
e
)
return
None
return
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
function_name
,
arguments
=
arguments
),
)
def
_escape_xml_special_chars
(
self
,
text
:
str
)
->
str
:
"""
Escape XML special characters
Args:
text: Original text
Returns:
Escaped text
"""
xml_escapes
=
{
"&"
:
"&"
,
"<"
:
"<"
,
">"
:
">"
,
'"'
:
"""
,
"'"
:
"'"
,
}
def
_get_function_calls
(
self
,
model_output
:
str
)
->
list
[
str
]:
# Find all tool calls
raw_tool_calls
=
self
.
tool_call_complete_regex
.
findall
(
model_output
)
for
char
,
escape
in
xml_escapes
.
items
():
text
=
text
.
replace
(
char
,
escape
)
# if no closed tool_call tags found, return empty list
if
len
(
raw_tool_calls
)
==
0
:
return
[]
return
text
raw_function_calls
=
[]
for
tool_call
in
raw_tool_calls
:
function_matches
=
self
.
tool_call_function_regex
.
findall
(
tool_call
)
raw_function_calls
.
extend
(
function_matches
)
def
_process_complete_xml_elements
(
self
)
->
bool
:
"""
Process complete XML elements in buffer
return
raw_function_calls
Returns:
bool: Whether complete elements were found and processed
"""
found_any
=
False
def
_check_format
(
self
,
model_output
:
str
)
->
bool
:
"""Check if model output contains properly formatted tool call.
while
self
.
last_processed_pos
<
len
(
self
.
streaming_buffer
):
# Find next complete xml element
element
,
end_pos
=
self
.
_find_next_complete_element
(
self
.
last_processed_pos
)
if
element
is
None
:
# No complete element found, wait for more data
break
Requirements:
1. Must have closed tool_call tags (<tool_call>...</tool_call>)
2. Must have closed function tags (<function=...</function>)
3. If parameter tags exist, they must be closed and correct
# Check if this element should be skipped
if
self
.
_should_skip_element
(
element
):
self
.
last_processed_pos
=
end_pos
continue
Returns True if the format is valid, False otherwise.
"""
# Check 1: Must have closed tool_call tags
tool_call_matches
=
self
.
tool_call_complete_regex
.
findall
(
model_output
)
if
len
(
tool_call_matches
)
==
0
:
return
False
# Found complete XML element, process it
try
:
preprocessed_element
=
self
.
_preprocess_xml_chunk
(
element
)
# Check if this is the first tool_call start
# Check 2: Must have closed function tags within tool_call
has_valid_function
=
False
for
tool_call_content
in
tool_call_matches
:
function_matches
=
self
.
tool_call_function_regex
.
findall
(
tool_call_content
)
if
len
(
function_matches
)
>
0
:
has_valid_function
=
True
# Check if there's an unclosed function tag
if
(
(
preprocessed_element
.
strip
().
startswith
(
"<tool_call>"
)
or
preprocessed_element
.
strip
().
startswith
(
"<function name="
)
self
.
tool_call_prefix
in
tool_call_content
and
self
.
function_end_token
not
in
tool_call_content
):
return
False
if
not
has_valid_function
:
return
False
# Check 3: If parameter tags exist, they must be closed and correct
for
tool_call_content
in
tool_call_matches
:
# Count opening and closing parameter tags
param_open_count
=
tool_call_content
.
count
(
self
.
parameter_prefix
)
param_close_count
=
tool_call_content
.
count
(
self
.
parameter_end_token
)
# If there are parameter tags, they must be balanced
if
param_open_count
>
0
:
if
param_open_count
!=
param_close_count
:
return
False
# Check if all parameter tags are properly closed using regex
param_matches
=
self
.
tool_call_parameter_regex
.
findall
(
tool_call_content
)
and
self
.
tool_call_index
==
0
)
and
self
.
text_content_buffer
:
# First tool_call starts,
# output previously collected text content first
text_delta
=
DeltaMessage
(
content
=
self
.
text_content_buffer
)
self
.
_emit_delta
(
text_delta
)
# Clear buffer for potential subsequent text content
self
.
text_content_buffer
=
""
# If a new tool_call starts and
# there are already completed tool_calls with function name
if
len
(
param_matches
)
!=
param_open_count
:
return
False
return
True
def
_wrap_missing_tool_call_tags
(
self
,
model_output
:
str
)
->
str
:
"""Wrap bare <function=...></function> blocks with <tool_call> tags."""
if
(
preprocessed_element
.
strip
().
startswith
(
"<tool_call>"
)
and
self
.
tool_call_index
>
0
and
self
.
current_call_id
and
self
.
current_function_name
self
.
tool_call_prefix
not
in
model_output
or
self
.
function_end_token
not
in
model_output
):
# Reset parser state but preserve generated deltas
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
if
self
.
current_function_open
:
self
.
_end_element
(
"function"
)
# Output final tool_call tail delta
final_delta
=
DeltaMessage
(
role
=
None
,
content
=
None
,
reasoning_content
=
None
,
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
""
),
)
],
)
self
.
_emit_delta
(
final_delta
)
# Reset XML parser and current call state
self
.
_reset_xml_parser_after_tool_call
()
# Parse preprocessed element
self
.
parser
.
Parse
(
preprocessed_element
,
False
)
found_any
=
True
return
model_output
def
_wrap_bare_functions
(
text
:
str
)
->
str
:
pos
=
0
wrapped_parts
:
list
[
str
]
=
[]
while
True
:
func_idx
=
text
.
find
(
self
.
tool_call_prefix
,
pos
)
if
func_idx
==
-
1
:
wrapped_parts
.
append
(
text
[
pos
:])
break
end_idx
=
text
.
find
(
self
.
function_end_token
,
func_idx
)
if
end_idx
==
-
1
:
wrapped_parts
.
append
(
text
[
pos
:])
break
end_idx
+=
len
(
self
.
function_end_token
)
wrapped_parts
.
append
(
text
[
pos
:
func_idx
])
wrapped_parts
.
append
(
self
.
tool_call_start_token
)
wrapped_parts
.
append
(
text
[
func_idx
:
end_idx
])
wrapped_parts
.
append
(
self
.
tool_call_end_token
)
ws_idx
=
end_idx
while
ws_idx
<
len
(
text
)
and
text
[
ws_idx
].
isspace
():
ws_idx
+=
1
if
text
.
startswith
(
self
.
tool_call_end_token
,
ws_idx
):
if
ws_idx
>
end_idx
:
wrapped_parts
.
append
(
text
[
end_idx
:
ws_idx
])
pos
=
ws_idx
+
len
(
self
.
tool_call_end_token
)
else
:
pos
=
end_idx
return
""
.
join
(
wrapped_parts
)
except
Exception
as
e
:
logger
.
warning
(
"Error when parsing XML elements: %s"
,
e
)
tool_call_ranges
=
[
match
.
span
()
for
match
in
self
.
tool_call_complete_regex
.
finditer
(
model_output
)
]
if
not
tool_call_ranges
:
return
_wrap_bare_functions
(
model_output
)
# Update processed position
self
.
last_processed_pos
=
end_pos
wrapped_parts
:
list
[
str
]
=
[]
pos
=
0
for
start
,
end
in
tool_call_ranges
:
if
start
<
pos
:
continue
wrapped_parts
.
append
(
_wrap_bare_functions
(
model_output
[
pos
:
start
]))
wrapped_parts
.
append
(
model_output
[
start
:
end
])
pos
=
end
wrapped_parts
.
append
(
_wrap_bare_functions
(
model_output
[
pos
:]))
return
""
.
join
(
wrapped_parts
)
def
_normalize_prev_arguments
(
self
,
args_value
:
Any
)
->
Any
:
if
isinstance
(
args_value
,
str
):
try
:
return
json
.
loads
(
args_value
)
except
(
TypeError
,
ValueError
,
json
.
JSONDecodeError
):
return
args_value
return
args_value
def
_update_prev_tool_call_state
(
self
,
tool_calls
:
list
[
ToolCall
])
->
None
:
self
.
prev_tool_call_arr
.
clear
()
self
.
streamed_args_for_tool
.
clear
()
for
tool_call
in
tool_calls
:
if
not
tool_call
or
not
tool_call
.
function
:
continue
args_value
=
tool_call
.
function
.
arguments
if
isinstance
(
args_value
,
str
):
args_json
=
args_value
elif
args_value
is
None
:
args_json
=
""
else
:
try
:
args_json
=
json
.
dumps
(
args_value
,
ensure_ascii
=
False
)
except
(
TypeError
,
ValueError
):
args_json
=
str
(
args_value
)
prev_args
=
self
.
_normalize_prev_arguments
(
args_json
)
self
.
prev_tool_call_arr
.
append
(
{
"name"
:
tool_call
.
function
.
name
,
"arguments"
:
prev_args
,
}
)
try
:
expected_args_json
=
json
.
dumps
(
prev_args
,
ensure_ascii
=
False
)
except
(
TypeError
,
ValueError
):
expected_args_json
=
args_json
return
found_any
# Serving may subtract the latest delta length from
# streamed_args_for_tool to detect unstreamed suffixes. Since this
# parser emits full arguments at once, store expected+actual so
# the subtraction yields expected_args_json and no resend occurs.
self
.
streamed_args_for_tool
.
append
(
expected_args_json
+
args_json
)
def
_fix_incomplete_tag_in_chunk
(
self
,
chunk
:
str
)
->
str
:
"""
Fallback: fix incomplete <parameter=xxx or <function=xxx tags
(missing >)
Examples: <parameter=-C: -> <parameter=-C>, <parameter=parameter=-n:
-> <parameter=-n>
Also handles missing = cases: <function xxx> -> <function=xxx>,
<functionxxx> -> <function=xxx>
Only fixes tags that pass validation (parameter exists in tool definition)
"""
# First, handle missing = cases for function tags
chunk
=
self
.
_fix_missing_equals_in_function_tag
(
chunk
)
def
extract_tool_calls
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
try
:
origin_model_output
=
model_output
try
:
# Fallback: handle outputs without <tool_call> wrapper.
origin_model_output
=
self
.
_wrap_missing_tool_call_tags
(
origin_model_output
)
model_output
=
origin_model_output
except
Exception
:
pass
for
tag_type
in
[
"parameter"
,
"func
tion
"
]:
pattern
=
f
"<
{
tag_type
}
="
if
pattern
not
in
chunk
:
continue
# Use streaming-like approach: process position by posi
tion
valid_tool_calls
=
[]
content_parts
=
[]
processed_length
=
0
start_idx
=
chunk
.
find
(
pattern
)
after_tag
=
chunk
[
start_idx
:]
gt_pos
=
after_tag
.
find
(
">"
)
lt_pos
=
after_tag
.
find
(
"<"
,
len
(
pattern
))
while
processed_length
<
len
(
model_output
):
# Find next tool call start
tool_start_idx
=
self
.
_find_tool_call_start
(
model_output
,
processed_length
)
# Skip if already well-formed
# Case 1: No more tool calls - add remaining as content
if
tool_start_idx
==
-
1
:
remaining
=
model_output
[
processed_length
:]
if
remaining
:
content_parts
.
append
(
remaining
)
break
# Case 2: Content before tool call
if
tool_start_idx
>
processed_length
:
content_before
=
model_output
[
processed_length
:
tool_start_idx
]
# Skip whitespace-only content between tool calls
# Check if we just ended a tool call and this is pure whitespace
if
processed_length
>
0
:
text_before
=
model_output
[:
processed_length
]
if
(
gt_pos
!=
-
1
and
(
lt_pos
==
-
1
or
gt_pos
<
lt_pos
)
and
pattern
in
after_tag
[:
gt_pos
]
text_before
.
rstrip
().
endswith
(
self
.
tool_call_end_token
)
and
content_before
.
strip
()
==
""
):
continue
# Skip whitespace between tool calls
pass
else
:
content_parts
.
append
(
content_before
)
else
:
content_parts
.
append
(
content_before
)
# Extract tag name (stop at space, newline, or <)
content
=
chunk
[
start_idx
+
len
(
pattern
)
:]
end_pos
=
next
(
(
i
for
i
,
ch
in
enumerate
(
content
)
if
ch
in
(
" "
,
"
\n
"
,
"<"
)),
len
(
content
),
# Case 3: Try to find complete tool call
tool_end_idx
=
self
.
_find_first_complete_tool_call_end
(
model_output
,
tool_start_idx
)
tag_name
=
content
[:
end_pos
]
if
not
tag_name
:
continue
# If tool call is incomplete - add remaining as content and stop
if
tool_end_idx
==
-
1
:
remaining
=
model_output
[
tool_start_idx
:]
if
remaining
:
content_parts
.
append
(
remaining
)
break
# Remove duplicate prefix: <parameter=parameter=xxx -> <parameter=xxx
if
tag_name
.
startswith
(
f
"
{
tag_type
}
="
):
tag_name
=
tag_name
[
len
(
tag_type
)
+
1
:]
# Extract and try to parse the complete tool call
tool_call_text
=
model_output
[
tool_start_idx
:
tool_end_idx
]
parsed_result
=
self
.
extract_tool_calls_basic
(
tool_call_text
,
request
)
# Remove trailing non-alphanumeric chars (keep - and _)
while
tag_name
and
not
(
tag_name
[
-
1
].
isalnum
()
or
tag_name
[
-
1
]
in
(
"-"
,
"_"
)
):
tag_name
=
tag_name
[:
-
1
]
# If parsing succeeded, record the tool call(s)
if
parsed_result
.
tools_called
and
parsed_result
.
tool_calls
:
valid_tool_calls
.
extend
(
parsed_result
.
tool_calls
)
processed_length
=
tool_end_idx
else
:
# Parsing failed - treat this tool call as content
content_parts
.
append
(
tool_call_text
)
processed_length
=
tool_end_idx
if
not
tag_name
:
continue
# Populate prev_tool_call_arr for serving layer to set finish_reason
self
.
_update_prev_tool_call_state
(
valid_tool_calls
)
# Validate parameter exists in tool definition
if
tag_type
==
"parameter"
and
not
self
.
_validate_parameter_name
(
tag_name
):
continue
# Combine content parts
content
=
""
.
join
(
content_parts
)
if
content_parts
else
None
# Apply fix
chunk
=
chunk
.
replace
(
f
"<
{
tag_type
}
=
{
content
[:
end_pos
]
}
"
,
f
"<
{
tag_type
}
=
{
tag_name
}
>"
,
1
return
ExtractedToolCallInformation
(
tools_called
=
(
len
(
valid_tool_calls
)
>
0
),
tool_calls
=
valid_tool_calls
,
content
=
content
if
content
else
None
,
)
except
Exception
:
logger
.
warning
(
"Error in extracting tool call from response."
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
return
chunk
def
extract_tool_calls_basic
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
model_output
=
self
.
_wrap_missing_tool_call_tags
(
model_output
)
# Quick check to avoid unnecessary processing
if
not
self
.
_check_format
(
model_output
):
tool_call_matches
=
self
.
tool_call_complete_regex
.
findall
(
model_output
)
if
len
(
tool_call_matches
)
==
0
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
_fix_missing_equals_in_function_tag
(
self
,
chunk
:
str
)
->
str
:
"""
Fix missing = in function tags: <function xxx> or <functionxxx>
Examples:
<function execute_bash> -> <function=execute_bash>
<functionexecute_bash> -> <function=execute_bash>
Only fixes if function name exists in tool definition
"""
# already correct
if
"<function="
in
chunk
:
return
chunk
# Pattern 1: <function xxx> (with space/newline but no =)
pattern1
=
r
"<function\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match1
=
re
.
search
(
pattern1
,
chunk
)
if
match1
:
func_name
=
match1
.
group
(
1
).
strip
()
# must validate function name exists before fixing
if
func_name
and
self
.
_validate_function_name
(
func_name
):
original
=
match1
.
group
(
0
)
fixed
=
f
"<function=
{
func_name
}
>"
chunk
=
chunk
.
replace
(
original
,
fixed
,
1
)
return
chunk
# Pattern 2: <functionxxx> (no space, no =)
# only match <function followed by letters
pattern2
=
r
"<function([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match2
=
re
.
search
(
pattern2
,
chunk
)
if
match2
:
func_name
=
match2
.
group
(
1
).
strip
()
# must validate function name exists before fixing
if
func_name
and
self
.
_validate_function_name
(
func_name
):
original
=
match2
.
group
(
0
)
fixed
=
f
"<function=
{
func_name
}
>"
chunk
=
chunk
.
replace
(
original
,
fixed
,
1
)
return
chunk
return
chunk
def
_validate_function_name
(
self
,
func_name
:
str
)
->
bool
:
"""Check if function name exists in tool definitions"""
if
not
self
.
tools
:
return
False
try
:
function_calls
=
self
.
_get_function_calls
(
model_output
)
if
len
(
function_calls
)
==
0
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
for
tool
in
self
.
tools
:
tool_calls
:
list
[
ToolCall
]
=
[]
for
function_call_str
in
function_calls
:
tool_call
=
self
.
_parse_xml_function_call
(
function_call_str
,
request
.
tools
)
if
tool_call
:
tool_calls
.
append
(
tool_call
)
if
not
tool_calls
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
for
tool_call
in
tool_calls
:
if
(
hasattr
(
tool
,
"type"
)
and
tool
.
type
==
"function"
and
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
and
tool
.
function
.
name
==
func_name
not
tool_call
.
function
or
tool_call
.
function
.
arguments
is
None
or
not
self
.
_is_valid_json_arguments
(
tool_call
.
function
.
arguments
)
):
return
True
logger
.
warning
(
"Invalid JSON arguments in tool call, falling back to content."
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
return
False
# Populate prev_tool_call_arr for serving layer to set finish_reason
self
.
_update_prev_tool_call_state
(
tool_calls
)
def
_validate_parameter_name
(
self
,
param_name
:
str
)
->
bool
:
"""Check if parameter exists in current function's tool definition"""
if
not
self
.
tools
or
not
self
.
current_function_name
:
return
True
# Extract content before tool calls
content_index
=
model_output
.
find
(
self
.
tool_call_start_token
)
content
=
model_output
[:
content_index
]
# .rstrip()
for
tool
in
self
.
tools
:
if
(
hasattr
(
tool
,
"type"
)
and
tool
.
type
==
"function"
and
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
and
tool
.
function
.
name
==
self
.
current_function_name
):
if
not
hasattr
(
tool
.
function
,
"parameters"
):
return
True
params
=
tool
.
function
.
parameters
if
isinstance
(
params
,
dict
):
properties
=
params
.
get
(
"properties"
,
params
)
return
param_name
in
properties
break
return
ExtractedToolCallInformation
(
tools_called
=
(
len
(
tool_calls
)
>
0
),
tool_calls
=
tool_calls
,
content
=
content
if
content
else
None
,
)
return
True
except
Exception
:
logger
.
warning
(
"Error in extracting tool call from response."
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
_should_skip_element
(
self
,
element
:
str
)
->
bool
:
"""
Determine whether an element should be skipped
def
_find_first_complete_tool_call_end
(
self
,
text
:
str
,
start_pos
:
int
=
0
)
->
int
:
"""Find the end position of the first complete tool call.
Args:
element: Element to evaluate
text: Text to search in
start_pos: Position to start searching from
Returns:
bool: True means should skip, False means should process
"""
Position after the first </tool_call> tag, or -1 if incomplete
# If it's a tool_call XML tag, don't skip
if
(
element
.
startswith
(
self
.
tool_call_start_token
)
or
element
.
startswith
(
self
.
function_start_token
)
or
element
.
startswith
(
self
.
parameter
_start_token
)
)
:
return
False
Example:
"<tool_call>...</tool_call>..." returns position after </tool_call>
"""
# Find tool call start
start_idx
=
text
.
find
(
self
.
tool_call
_start_token
,
start_pos
)
if
start_idx
==
-
1
:
return
-
1
# If currently not parsing tool calls and not blank,
# collect this text instead of skipping
# Only process other XML elements after tool_call appears,
# otherwise treat as plain text
if
self
.
current_call_id
is
None
and
element
:
# Collect text content to buffer
self
.
text_content_buffer
+=
element
return
True
# Still skip, but content has been collected
# If currently parsing tool calls,
# this might be parameter value, don't skip
if
self
.
current_call_id
is
not
None
:
return
False
# Find matching end token
end_idx
=
text
.
find
(
self
.
tool_call_end_token
,
start_idx
+
len
(
self
.
tool_call_start_token
)
)
if
end_idx
==
-
1
:
return
-
1
# Incomplete tool call
#
Skip blank cont
en
t
return
not
element
#
Return position after end tok
en
return
end_idx
+
len
(
self
.
tool_call_end_token
)
def
_find_next_complete_element
(
self
,
start_pos
:
int
)
->
tuple
[
str
|
None
,
int
]:
"""
Find next complete XML element from specified position
def
_find_tool_call_start
(
self
,
text
:
str
,
start_pos
:
int
=
0
)
->
int
:
"""Find the start position of next tool call.
Args:
start_pos: Position to start searching
text: Text to search in
start_pos: Position to start searching from
Returns:
(Complete element string, element end position),
returns (None, start_pos) if no complete element found
Position of <tool_call> token, or -1 if not found
"""
buffer
=
self
.
streaming_buffer
[
start_pos
:]
return
text
.
find
(
self
.
tool_call_start_token
,
start_pos
)
if
not
buffer
:
return
None
,
start_pos
def
_extract_content_between_tool_calls_list
(
self
,
text
:
str
)
->
list
[
str
]
:
"""Extract content segments after each tool call.
if
buffer
.
startswith
(
"<"
):
# Check if this is an incomplete parameter/function tag
# e.g., <parameter=-C: or <function=xxx
is_incomplete_param
=
(
buffer
.
startswith
(
"<parameter="
)
and
">"
not
in
buffer
.
split
(
"
\n
"
)[
0
]
)
is_incomplete_func
=
(
buffer
.
startswith
(
"<function="
)
and
">"
not
in
buffer
.
split
(
"
\n
"
)[
0
]
)
For n tool calls, returns n segments where segment[i] is the content
after tool_call[i] (before tool_call[i+1] or at the end).
if
is_incomplete_param
or
is_incomplete_func
:
# Find the corresponding closing tag
tag_type
=
"parameter"
if
is_incomplete_param
else
"function"
closing_tag
=
f
"</
{
tag_type
}
>"
closing_pos
=
buffer
.
find
(
closing_tag
)
if
closing_pos
!=
-
1
:
# Found closing tag, return complete element including closing tag
complete_element
=
buffer
[:
closing_pos
+
len
(
closing_tag
)]
return
complete_element
,
start_pos
+
closing_pos
+
len
(
closing_tag
)
# Need to ensure no new < appears,
# find the nearest one between < and >
tag_end
=
buffer
.
find
(
"<"
,
1
)
tag_end2
=
buffer
.
find
(
">"
,
1
)
if
tag_end
!=
-
1
and
tag_end2
!=
-
1
:
# Next nearest is <
if
tag_end
<
tag_end2
:
return
buffer
[:
tag_end
],
start_pos
+
tag_end
# Next nearest is >, means found XML element
else
:
return
buffer
[:
tag_end2
+
1
],
start_pos
+
tag_end2
+
1
elif
tag_end
!=
-
1
:
return
buffer
[:
tag_end
],
start_pos
+
tag_end
elif
tag_end2
!=
-
1
:
return
buffer
[:
tag_end2
+
1
],
start_pos
+
tag_end2
+
1
else
:
# If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call> or <function=
if
self
.
current_call_id
is
None
:
# Check if might be start of <tool_call>
if
buffer
==
"<tool_call>"
[:
len
(
buffer
)]:
# Might be start of <tool_call>, wait for more data
return
None
,
start_pos
elif
(
buffer
.
startswith
(
"<function="
)
or
buffer
==
"<function="
[:
len
(
buffer
)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return
None
,
start_pos
else
:
# Not start of <tool_call> or <function=, treat as text
return
buffer
,
start_pos
+
len
(
buffer
)
else
:
# When parsing tool calls,
# wait for more data to get complete tag
return
None
,
start_pos
else
:
# Find text content (until next < or buffer end)
next_tag_pos
=
buffer
.
find
(
"<"
)
if
next_tag_pos
!=
-
1
:
# Found text content
text_content
=
buffer
[:
next_tag_pos
]
return
text_content
,
start_pos
+
next_tag_pos
else
:
# Buffer end is all text, process
# (no longer wait for more data)
remaining
=
buffer
return
remaining
,
start_pos
+
len
(
remaining
)
def
_merge_new_deltas_to_single_response
(
self
,
initial_count
:
int
)
->
DeltaMessage
:
"""
Merge newly generated deltas from this processing
into a single DeltaMessage
Empty or whitespace-only segments are represented as empty string "".
Args:
initial_count: Delta
co
u
nt
before processing
text: Text
cont
aining tool calls
Returns:
Merged DeltaMessage containing all newly generated delta information
List of content segments (one per tool call)
"""
if
len
(
self
.
deltas
)
<=
initial_count
:
return
DeltaMessage
(
content
=
None
)
content_segments
=
[]
pos
=
0
# Get newly generated deltas
new_deltas
=
self
.
deltas
[
initial_count
:]
while
True
:
# Find end of current tool call
end_pos
=
text
.
find
(
self
.
tool_call_end_token
,
pos
)
if
end_pos
==
-
1
:
break
if
len
(
new_deltas
)
==
1
:
# Only one new delta, return directly
return
new_deltas
[
0
]
# Move past the end token
end_pos
+=
len
(
self
.
tool_call_end_token
)
# Merge multiple new deltas
merged_tool_calls
:
list
[
DeltaToolCall
]
=
[]
merged_content
:
str
=
""
# Find start of next tool call
next_start
=
self
.
_find_tool_call_start
(
text
,
end_pos
)
for
delta
in
new_deltas
:
if
delta
.
content
:
merged_content
+=
delta
.
content
if
delta
.
tool_calls
:
# For tool_calls, we need to intelligently merge arguments
for
tool_call
in
delta
.
tool_calls
:
# Find if there's already a tool_call with the same call_id
existing_call
=
None
for
existing
in
merged_tool_calls
:
if
existing
.
id
==
tool_call
.
id
:
existing_call
=
existing
# Extract content between current end and next start (or text end)
content
=
text
[
end_pos
:
next_start
]
if
next_start
!=
-
1
else
text
[
end_pos
:]
# Store content (empty string if whitespace-only)
content_segments
.
append
(
content
if
content
.
strip
()
else
""
)
if
next_start
==
-
1
:
break
pos
=
next_start
if
existing_call
and
existing_call
.
function
:
# Merge to existing tool_call
if
tool_call
.
function
and
tool_call
.
function
.
name
:
existing_call
.
function
.
name
=
tool_call
.
function
.
name
if
(
tool_call
.
function
and
tool_call
.
function
.
arguments
is
not
None
):
if
existing_call
.
function
.
arguments
is
None
:
existing_call
.
function
.
arguments
=
""
# For streaming JSON parameters,
# simply concatenate in order
new_args
=
tool_call
.
function
.
arguments
existing_call
.
function
.
arguments
+=
new_args
if
tool_call
.
type
:
existing_call
.
type
=
tool_call
.
type
else
:
# Add new tool_call
merged_tool_calls
.
append
(
tool_call
)
return
content_segments
return
DeltaMessage
(
content
=
merged_content
if
merged_content
else
None
,
tool_calls
=
merged_t
ool
_c
all
s
,
)
def
_convert_tool_calls_to_deltas
(
self
,
tool_calls
:
list
[
ToolCall
],
starting_index
:
int
=
0
)
->
list
[
DeltaT
ool
C
all
]:
"""Convert complete ToolCall list to DeltaToolCall list.
def
_preprocess_xml_chunk
(
self
,
chunk
:
str
)
->
str
:
"""
Preprocess XML chunk, handle non-standard formats,
and escape special characters
Returns complete tool calls without splitting into fragments.
Args:
chunk: Original XML chunk
tool_calls: List of tool calls to convert
starting_index: Starting index for tool calls (default 0)
Returns:
Processed XML chunk
List of DeltaToolCall with complete arguments
"""
delta_tool_calls
=
[]
for
i
,
tool_call
in
enumerate
[
ToolCall
](
tool_calls
):
index
=
starting_index
+
i
tool_id
=
self
.
_generate_tool_call_id
()
# Check if this is a tool_call related element
is_tool_call
=
False
if
chunk
.
startswith
(
self
.
tool_call_start_token
)
or
chunk
.
startswith
(
self
.
tool_call_end_token
):
is_tool_call
=
True
# Check for function tags (including malformed ones without =)
# <function=xxx>, </function>, <function xxx>, <functionxxx>
if
(
chunk
.
startswith
(
self
.
function_start_token
)
or
chunk
.
startswith
(
self
.
function_end_token
)
or
chunk
.
startswith
(
"<function "
)
or
re
.
match
(
r
"^<function[a-zA-Z_]"
,
chunk
)
):
# <functionXXX without space or =
is_tool_call
=
True
if
chunk
.
startswith
(
self
.
parameter_start_token
)
or
chunk
.
startswith
(
self
.
parameter_end_token
):
is_tool_call
=
True
# Fallback: fix incomplete <parameter= or <function= tags without
# closing >
# This handles cases like: <parameter=-C:\n or <parameter=-B 5\n
# Apply when parsing tool calls OR when chunk looks like a function/
# parameter tag
if
(
self
.
current_call_id
is
not
None
or
chunk
.
startswith
(
"<function"
)
or
chunk
.
startswith
(
"<parameter"
)
):
chunk
=
self
.
_fix_incomplete_tag_in_chunk
(
chunk
)
# Handle <function=name> format -> <function name="name">
processed
=
re
.
sub
(
r
"<function=([^>]+)>"
,
r
'<function name="\1">'
,
chunk
)
# Handle <parameter=name> format -> <parameter name="name">
processed
=
re
.
sub
(
r
"<parameter=([^>]+)>"
,
r
'<parameter name="\1">'
,
processed
)
original_chunk
=
chunk
# If in parameter value accumulation mode
if
self
.
_pre_inside_parameter
:
# Parameter end: output accumulated raw text
# safely then return </parameter>
if
processed
.
startswith
(
"</parameter>"
):
body_text
=
self
.
_pre_param_buffer
# Trigger deferred parsing mode
# literal_eval+json output in end_element
self
.
defer_current_parameter
=
True
self
.
deferred_param_raw_value
=
body_text
# Clean up state
self
.
_pre_inside_parameter
=
False
self
.
_pre_param_buffer
=
""
self
.
_pre_current_param_name
=
None
safe_text
=
self
.
_escape_xml_special_chars
(
body_text
)
return
f
"
{
safe_text
}
</parameter>"
else
:
# If this is the first block of content after entering parameter
# evaluate if deferred parsing is needed;
# If not needed, exit accumulation mode
# and pass through directly
if
self
.
_pre_param_buffer
==
""
:
# Get current parameter type
param_type
=
(
self
.
_get_param_type
(
self
.
_pre_current_param_name
)
if
self
.
_pre_current_param_name
else
"string"
)
# Only these types need deferred parsing to
# handle Python literals containing single quotes
is_object_type
=
param_type
in
[
"object"
]
is_complex_type
=
(
param_type
in
[
"array"
,
"arr"
,
"sequence"
]
or
param_type
.
startswith
(
"dict"
)
or
param_type
.
startswith
(
"list"
)
)
# Only delay when contains container symbols
# and has single quotes and is complex type
has_container_hint
=
(
(
"["
in
original_chunk
)
or
(
"{"
in
original_chunk
)
or
(
"("
in
original_chunk
)
)
# Determine if deferred parsing is needed
need_defer
=
False
if
is_complex_type
:
# Complex type, always need deferred parsing
need_defer
=
True
elif
(
is_object_type
and
has_container_hint
and
(
"'"
in
original_chunk
)
):
# Object type with container symbols
# and single quotes, need deferred parsing
need_defer
=
True
if
not
need_defer
:
# No need for deferred parsing,
# exit parameter mode directly
self
.
_pre_inside_parameter
=
False
return
self
.
_escape_xml_special_chars
(
original_chunk
)
self
.
_pre_param_buffer
+=
original_chunk
return
""
# Parameter start: enable accumulation
if
processed
.
startswith
(
"<parameter name="
):
m
=
re
.
match
(
r
'<parameter name="([^"]+)">'
,
processed
)
if
m
:
self
.
_pre_current_param_name
=
m
.
group
(
1
)
self
.
_pre_inside_parameter
=
True
self
.
_pre_param_buffer
=
""
return
processed
# If processed doesn't contain special_token, escape processed
# This is because XML parsing encounters special characters
# and reports errors, so escaping is needed
if
not
is_tool_call
:
processed
=
self
.
_escape_xml_special_chars
(
processed
)
return
processed
def
_emit_delta
(
self
,
delta
:
DeltaMessage
):
"""Emit Delta response (streaming output)"""
self
.
deltas
.
append
(
delta
)
def
_auto_close_open_parameter_if_needed
(
self
,
incoming_tag
:
str
|
None
=
None
):
"""Before starting to process new elements,
if there are unclosed tags from before,
automatically complete their endings to the parser.
- If there are unclosed parameters,
it's equivalent to feeding `</parameter>`
- When about to start a new function or tool_call,
if there are unclosed functions, complete `</function>`.
- When about to start a new tool_call,
if there are unclosed tool_calls, complete `</tool_call>`.
"""
# First close unclosed parameters
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
# If about to start new function or tool_call,
# and there are unclosed functions, close function first
if
incoming_tag
in
(
"function"
,
"tool_call"
)
and
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
# If about to start new tool_call,
# and there are unclosed tool_calls, close tool_call first
if
incoming_tag
==
"tool_call"
and
self
.
current_call_id
:
self
.
_end_element
(
"tool_call"
)
def
_start_element
(
self
,
name
:
str
,
attrs
:
dict
[
str
,
str
]):
"""Handle XML start element events"""
if
name
==
"root"
:
return
if
name
==
"tool_call"
:
# Before opening new tool_call,
# automatically complete previous unclosed tags
self
.
_auto_close_open_parameter_if_needed
(
"tool_call"
)
self
.
parameters
=
{}
self
.
current_call_id
=
make_tool_call_id
()
self
.
current_param_is_first
=
True
self
.
tool_call_index
+=
1
elif
name
.
startswith
(
"function"
)
or
(
name
==
"function"
):
# If missing tool_call, manually complete
if
not
self
.
current_call_id
:
self
.
_start_element
(
"tool_call"
,
{})
# Before opening new function,
# automatically complete previous unclosed tags (parameter/function)
self
.
_auto_close_open_parameter_if_needed
(
"function"
)
function_name
=
self
.
_extract_function_name
(
name
,
attrs
)
self
.
current_function_name
=
function_name
self
.
current_function_open
=
True
if
function_name
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
function_name
,
arguments
=
""
),
)
]
)
self
.
_emit_delta
(
delta
)
elif
name
.
startswith
(
"parameter"
)
or
(
name
==
"parameter"
):
# If previous parameter hasn't ended normally,
# complete its end first, then start new parameter
self
.
_auto_close_open_parameter_if_needed
(
"parameter"
)
param_name
=
self
.
_extract_parameter_name
(
name
,
attrs
)
self
.
current_param_name
=
param_name
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
start_quote_emitted
=
False
# Reset start quote flag
# Only output parameter name and colon,
# don't output quotes
# decide after parameter value type is determined
if
param_name
:
if
not
self
.
parameters
:
# First parameter
# start JSON, only output parameter name and colon
json_start
=
f
'{{"
{
param_name
}
": '
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
json_start
),
)
]
)
self
.
_emit_delta
(
delta
)
self
.
current_param_is_first
=
True
else
:
# Subsequent parameters
# add comma and parameter name, no quotes
json_continue
=
f
', "
{
param_name
}
": '
delta
=
DeltaMessage
(
tool_calls
=
[
# Create complete DeltaToolCall with full arguments
delta_tool_calls
.
append
(
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
index
=
index
,
id
=
tool_id
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
json_continue
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
),
)
]
)
self
.
_emit_delta
(
delta
)
self
.
current_param_is_first
=
False
def
_char_data
(
self
,
data
:
str
):
"""Handle XML character data events"""
if
data
and
self
.
current_param_name
:
# If preprocessing stage determines deferred parsing is needed,
# only cache character data, no streaming output
if
self
.
defer_current_parameter
:
original_data
=
data
if
self
.
should_emit_end_newline
:
original_data
=
"
\n
"
+
original_data
self
.
should_emit_end_newline
=
False
if
original_data
.
endswith
(
"
\n
"
):
self
.
should_emit_end_newline
=
True
original_data
=
original_data
[:
-
1
]
self
.
current_param_value
+=
original_data
return
param_type
=
self
.
_get_param_type
(
self
.
current_param_name
)
# Check if this is the first time receiving data for this parameter
# If this is the first packet of data and starts with \n, remove \n
if
not
self
.
current_param_value
and
data
.
startswith
(
"
\n
"
):
data
=
data
[
1
:]
# Output start quote for string type (if not already output)
if
(
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]
and
not
self
.
start_quote_emitted
):
quote_delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
'"'
),
)
]
)
self
.
_emit_delta
(
quote_delta
)
self
.
start_quote_emitted
=
True
if
not
data
:
return
original_data
=
data
# Delay output of trailing newline
if
self
.
should_emit_end_newline
:
original_data
=
"
\n
"
+
original_data
self
.
should_emit_end_newline
=
False
if
original_data
.
endswith
(
"
\n
"
):
self
.
should_emit_end_newline
=
True
original_data
=
original_data
[:
-
1
]
self
.
current_param_value
+=
original_data
# convert parameter value by param_type
converted_value
=
self
.
_convert_param_value
(
self
.
current_param_value
,
param_type
)
output_data
=
self
.
_convert_for_json_streaming
(
converted_value
,
param_type
)
delta_data
=
output_data
[
len
(
self
.
current_param_value_converted
)
:]
self
.
current_param_value_converted
=
output_data
return
delta_tool_calls
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
delta_data
),
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
"""Extract tool calls from streaming text using complete parsing.
Strategy:
1. Accumulate text in buffer and track processed position
2. In each iteration, try to extract content or complete tool calls
3. Parse complete tool calls using non-streaming method
4. Convert parsed results to delta sequence
5. Handle EOS token to flush incomplete tool calls as content
"""
# Initialize state for new request
if
not
previous_text
:
self
.
_reset_streaming_state
()
self
.
streaming_request
=
request
# Check for EOS token
has_eos
=
(
self
.
eos_token_id
is
not
None
and
delta_token_ids
and
self
.
eos_token_id
in
delta_token_ids
)
]
# If no delta text, check if we need to return empty delta for finish_reason
if
not
delta_text
and
not
has_eos
:
# Check if this is an EOS token after all tool calls are complete
if
delta_token_ids
and
self
.
tool_call_end_token_id
not
in
delta_token_ids
:
# Count complete tool calls
complete_calls
=
len
(
self
.
tool_call_complete_regex
.
findall
(
current_text
)
)
self
.
_emit_delta
(
delta
)
def
_end_element
(
self
,
name
:
str
):
"""Handle XML end element events"""
# If we have completed tool calls and populated prev_tool_call_arr
if
complete_calls
>
0
and
len
(
self
.
prev_tool_call_arr
)
>
0
:
# Check if all tool calls are closed
open_calls
=
current_text
.
count
(
self
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
tool_call_end_token
)
if
open_calls
==
0
:
# Return empty delta for finish_reason processing
return
DeltaMessage
(
content
=
""
)
return
None
if
name
==
"root"
:
return
# Process all available content
accumulated_deltas
:
list
[
DeltaMessage
]
=
[]
# If function or tool_call ends and there are still unclosed parameters,
# complete parameter end first
if
(
name
.
startswith
(
"function"
)
or
name
==
"function"
or
name
==
"tool_call"
)
and
self
.
current_param_name
:
self
.
_auto_close_open_parameter_if_needed
()
while
self
.
_has_unprocessed_content
(
current_text
):
# Try to process next chunk (content or tool call)
delta
=
self
.
_process_next_chunk
(
current_text
)
if
(
name
.
startswith
(
"parameter"
)
or
name
==
"parameter"
)
and
self
.
current_param_name
:
# End current parameter
param_name
=
self
.
current_param_name
param_value
=
self
.
current_param_value
# If in deferred parsing mode,
# perform overall parsing on raw content
# accumulated in preprocessing stage and output once
if
self
.
defer_current_parameter
:
raw_text
=
(
self
.
deferred_param_raw_value
if
self
.
deferred_param_raw_value
else
param_value
)
parsed_value
=
None
output_arguments
=
None
try
:
# If previously delayed trailing newline,
# add it back before parsing
if
self
.
should_emit_end_newline
:
raw_for_parse
=
raw_text
+
"
\n
"
if
delta
is
None
:
# Cannot proceed further, need more tokens
break
# Accumulate deltas
if
isinstance
(
delta
,
list
):
accumulated_deltas
.
extend
(
delta
)
else
:
raw_for_parse
=
raw_text
parsed_value
=
ast
.
literal_eval
(
raw_for_parse
)
output_arguments
=
json
.
dumps
(
parsed_value
,
ensure_ascii
=
False
)
except
Exception
:
# Fallback: output as string as-is
output_arguments
=
json
.
dumps
(
raw_text
,
ensure_ascii
=
False
)
parsed_value
=
raw_text
accumulated_deltas
.
append
(
delta
)
# Handle EOS: flush any remaining incomplete tool calls as content
if
has_eos
:
remaining_delta
=
self
.
_flush_remaining_content
(
current_text
)
if
remaining_delta
:
accumulated_deltas
.
append
(
remaining_delta
)
# If no remaining content but we have tool calls, return empty delta
elif
len
(
self
.
prev_tool_call_arr
)
>
0
:
# Check if all tool calls are closed
open_calls
=
current_text
.
count
(
self
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
tool_call_end_token
)
if
open_calls
==
0
:
accumulated_deltas
.
append
(
DeltaMessage
(
content
=
""
))
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
output_arguments
),
)
]
)
self
.
_emit_delta
(
delta
)
# Return results
return
self
.
_format_delta_result
(
accumulated_deltas
)
# Clean up and store
self
.
should_emit_end_newline
=
False
self
.
parameters
[
param_name
]
=
parsed_value
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
start_quote_emitted
=
False
self
.
defer_current_parameter
=
False
self
.
deferred_param_raw_value
=
""
return
def
_has_unprocessed_content
(
self
,
current_text
:
str
)
->
bool
:
"""Check if there's unprocessed content in the buffer."""
return
self
.
_processed_length
<
len
(
current_text
)
param_type
=
self
.
_get_param_type
(
param_name
)
def
_process_next_chunk
(
self
,
current_text
:
str
)
->
DeltaMessage
|
list
[
DeltaMessage
]
|
None
:
"""Process next chunk: either regular content or a complete tool call.
# convert complete parameter value by param_type
c
onverted_value
=
self
.
_convert_param_value
(
param_value
,
param_type
)
Args:
c
urrent_text: Current accumulated text
# Decide whether to add end quote based on parameter type
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
# For empty string parameters, need special handling
if
not
param_value
and
not
self
.
start_quote_emitted
:
# No start quote output,
# directly output complete empty string
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
'""'
),
)
]
)
self
.
_emit_delta
(
delta
)
else
:
# Non-empty parameter value, output end quote
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
'"'
),
)
]
)
self
.
_emit_delta
(
delta
)
self
.
should_emit_end_newline
=
False
# Store converted value
self
.
parameters
[
param_name
]
=
converted_value
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
start_quote_emitted
=
False
elif
name
.
startswith
(
"function"
)
or
name
==
"function"
:
# if there are parameters, close JSON object
if
self
.
parameters
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
"}"
),
)
]
)
self
.
_emit_delta
(
delta
)
# return empty object
else
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
"{}"
),
)
]
Returns:
- DeltaMessage or list of DeltaMessage if processed successfully
- None if cannot proceed (need more tokens)
"""
# Find next tool call start
tool_start_idx
=
self
.
_find_tool_call_start
(
current_text
,
self
.
_processed_length
)
self
.
_emit_delta
(
delta
)
self
.
current_function_open
=
False
self
.
current_function_name
=
(
None
# Clear function name to prevent duplicate closing
# Case 1: No tool call found - return remaining content
if
tool_start_idx
==
-
1
:
return
self
.
_process_content
(
current_text
,
self
.
_processed_length
,
len
(
current_text
)
)
elif
name
==
"tool_call"
:
# Before ending tool_call,
# ensure function is closed to complete missing right brace
if
self
.
current_function_open
:
# If there are still unclosed parameters, close them first
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
# Close function, ensure output '}' or '{}'
self
.
_end_element
(
"function"
)
# Final Delta
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
""
),
# Case 2: Content before tool call
if
tool_start_idx
>
self
.
_processed_length
:
return
self
.
_process_content
(
current_text
,
self
.
_processed_length
,
tool_start_idx
)
]
# Case 3: Tool call at current position
# Find end of the first complete tool call
tool_end_idx
=
self
.
_find_first_complete_tool_call_end
(
current_text
,
tool_start_idx
)
self
.
_emit_delta
(
delta
)
# Check if there's text content to output (between tool_calls)
if
self
.
text_content_buffer
.
strip
():
text_delta
=
DeltaMessage
(
content
=
self
.
text_content_buffer
)
self
.
_emit_delta
(
text_delta
)
if
tool_end_idx
==
-
1
:
# Tool call incomplete, wait for more tokens
return
None
self
.
_reset_xml_parser_after_tool_call
()
# Process complete tool call
return
self
.
_process_complete_tool_calls
(
current_text
,
tool_start_idx
,
tool_end_idx
)
def
setup_parser
(
self
):
"""Set up XML parser event handlers"""
self
.
parser
.
buffer_text
=
True
self
.
parser
.
StartElementHandler
=
self
.
_start_element
self
.
parser
.
EndElementHandler
=
self
.
_end_element
self
.
parser
.
CharacterDataHandler
=
self
.
_char_data
def
_process_content
(
self
,
current_text
:
str
,
start_pos
:
int
,
end_pos
:
int
)
->
DeltaMessage
|
None
:
"""Process regular content (non-tool-call text).
def
set_tools
(
self
,
tools
:
list
[
ChatCompletionToolsParam
]
|
None
):
"""Set tool configuration information"""
self
.
tools
=
tools
Args:
current_text: Current accumulated text
start_pos: Start position in buffer
end_pos: End position in buffer
def
_extract_function_name
(
self
,
name
:
str
,
attrs
:
dict
[
str
,
str
])
->
str
|
None
:
"""Extract function name from various formats"""
if
attrs
and
"name"
in
attrs
:
return
attrs
[
"name"
]
Returns:
DeltaMessage with content if non-empty
"""
if
start_pos
>=
end_pos
:
return
None
if
"="
in
name
:
parts
=
name
.
split
(
"="
,
1
)
if
len
(
parts
)
==
2
and
parts
[
0
]
==
"function"
:
return
parts
[
1
]
content
=
current_text
[
start_pos
:
end_pos
]
# Check if we're between tool calls - skip whitespace
if
start_pos
>
0
:
# Check if text before start_pos ends with </tool_call>
text_before
=
current_text
[:
start_pos
]
if
(
text_before
.
rstrip
().
endswith
(
self
.
tool_call_end_token
)
and
content
.
strip
()
==
""
):
# We just ended a tool call, skip whitespace between tool calls
self
.
_processed_length
=
end_pos
return
None
def
_extract_parameter_name
(
self
,
name
:
str
,
attrs
:
dict
[
str
,
str
])
->
str
|
None
:
"""Extract parameter name from various formats"""
if
attrs
and
"name"
in
attrs
:
return
attrs
[
"name"
]
if
"="
in
name
:
parts
=
name
.
split
(
"="
,
1
)
if
len
(
parts
)
==
2
and
parts
[
0
]
==
"parameter"
:
return
parts
[
1
]
# Return content if non-empty
if
content
:
self
.
_processed_length
=
end_pos
return
DeltaMessage
(
content
=
content
)
# Mark as processed even if empty
self
.
_processed_length
=
end_pos
return
None
def
_get_param_type
(
self
,
param_name
:
str
)
->
str
:
"""Get parameter type based on tool configuration, defaults to string
def
_flush_remaining_content
(
self
,
current_text
:
str
)
->
DeltaMessage
|
None
:
"""Flush any remaining unprocessed content as regular content.
Args:
param_name: Parameter name
current_text: Current accumulated text
Returns:
Parameter type
Used when EOS token is encountered to handle incomplete tool calls.
"""
if
not
self
.
tools
or
not
self
.
current_function_name
:
return
"string"
if
not
self
.
_has_unprocessed_content
(
current_text
)
:
return
None
for
tool
in
self
.
tools
:
if
not
hasattr
(
tool
,
"type"
)
or
not
(
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
):
continue
if
(
tool
.
type
==
"function"
and
tool
.
function
.
name
==
self
.
current_function_name
):
if
not
hasattr
(
tool
.
function
,
"parameters"
):
return
"string"
params
=
tool
.
function
.
parameters
if
isinstance
(
params
,
dict
)
and
"properties"
in
params
:
properties
=
params
[
"properties"
]
if
param_name
in
properties
and
isinstance
(
properties
[
param_name
],
dict
):
return
self
.
repair_param_type
(
str
(
properties
[
param_name
].
get
(
"type"
,
"string"
))
)
elif
isinstance
(
params
,
dict
)
and
param_name
in
params
:
param_config
=
params
[
param_name
]
if
isinstance
(
param_config
,
dict
):
return
self
.
repair_param_type
(
str
(
param_config
.
get
(
"type"
,
"string"
))
)
break
return
"string"
remaining
=
current_text
[
self
.
_processed_length
:]
if
remaining
:
self
.
_processed_length
=
len
(
current_text
)
return
DeltaMessage
(
content
=
remaining
)
def
repair_param_type
(
self
,
param_type
:
str
)
->
str
:
"""Repair unknown parameter types by treating them as string
Args:
param_type: Parameter type
self
.
_processed_length
=
len
(
current_text
)
return
None
Returns:
Repaired parameter type
"""
if
(
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]
or
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)
or
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
)
or
param_type
in
[
"boolean"
,
"bool"
,
"binary"
]
or
(
param_type
in
[
"object"
,
"array"
,
"arr"
,
"sequence"
]
or
param_type
.
startswith
(
"dict"
)
or
param_type
.
startswith
(
"list"
)
)
):
return
param_type
else
:
return
"string"
def
_format_delta_result
(
self
,
deltas
:
list
[
DeltaMessage
])
->
DeltaMessage
|
None
:
"""Format delta result for return.
Merges all deltas into a single DeltaMessage.
def
_convert_param_value
(
self
,
param_value
:
str
,
param_type
:
str
)
->
Any
:
"""Convert value based on parameter type
Args:
param_value: Parameter value
param_type: Parameter type
deltas: List of delta messages
Returns:
Converted value
- None if empty
- Single merged DeltaMessage with all content and tool_calls
"""
if
param_value
.
lower
()
==
"null"
:
if
not
deltas
:
return
None
param_type
=
param_type
.
strip
().
lower
()
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
return
param_value
elif
(
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)
):
try
:
return
int
(
param_value
)
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' is not an integer, degenerating to string."
,
param_value
,
)
return
param_value
elif
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
):
try
:
float_param_value
:
float
=
float
(
param_value
)
return
(
float_param_value
if
float_param_value
-
int
(
float_param_value
)
!=
0
else
int
(
float_param_value
)
)
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' is not a float, degenerating to string."
,
param_value
,
)
return
param_value
elif
param_type
in
[
"boolean"
,
"bool"
,
"binary"
]:
param_value
=
param_value
.
lower
()
return
param_value
==
"true"
else
:
return
param_value
if
len
(
deltas
)
==
1
:
return
deltas
[
0
]
def
_convert_for_json_streaming
(
self
,
converted_value
:
Any
,
param_type
:
str
)
->
str
:
"""Convert converted_value based on
whether it's empty and if type is string
Args:
converted_value: Converted value
param_type: Parameter type
# Merge multiple deltas into one
merged_content_parts
=
[]
merged_tool_calls
=
[]
Returns:
Converted string for streaming output
"""
# Check if value is empty, but exclude numeric 0
if
converted_value
is
None
or
converted_value
==
""
:
return
""
for
delta
in
deltas
:
if
delta
.
content
:
merged_content_parts
.
append
(
delta
.
content
)
if
delta
.
tool_calls
:
merged_tool_calls
.
extend
(
delta
.
tool_calls
)
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
# String type, remove double quotes
return
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)[
1
:
-
1
]
else
:
# Non-string type, return complete JSON string
if
not
isinstance
(
converted_value
,
str
):
return
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)
else
:
return
converted_value
# Create merged DeltaMessage
merged_content
=
""
.
join
(
merged_content_parts
)
if
merged_content_parts
else
None
def
_reset_xml_parser_after_tool_call
(
self
):
"""
Each tool_call is treated as a separate XML document,
so we need to reset the parser after each tool_call.
"""
# Build kwargs - only include tool_calls if non-empty
kwargs
:
dict
[
str
,
Any
]
=
{
"content"
:
merged_content
}
if
merged_tool_calls
:
kwargs
[
"tool_calls"
]
=
merged_tool_calls
# recreate XML parser
self
.
parser
=
ParserCreate
()
self
.
setup_parser
()
# Reset current tool_call state
if
self
.
current_call_id
:
self
.
last_completed_call_id
=
self
.
current_call_id
self
.
current_call_id
=
None
self
.
current_function_name
=
None
self
.
current_function_open
=
False
self
.
parameters
=
{}
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
current_param_is_first
=
False
self
.
should_emit_end_newline
=
False
self
.
start_quote_emitted
=
False
self
.
text_content_buffer
=
""
# Reset preprocessing and deferred parsing state
self
.
_pre_inside_parameter
=
False
self
.
_pre_param_buffer
=
""
self
.
_pre_current_param_name
=
None
self
.
defer_current_parameter
=
False
self
.
deferred_param_raw_value
=
""
@
ToolParserManager
.
register_module
(
"step3p5"
)
class
Step3p5ToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
TokenizerLike
):
super
().
__init__
(
tokenizer
)
self
.
parser
=
StreamingXMLToolCallParser
()
return
DeltaMessage
(
**
kwargs
)
# Add missing attributes for compatibility with serving_chat.py
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
def
_process_complete_tool_calls
(
self
,
current_text
:
str
,
start_pos
:
int
,
end_pos
:
int
)
->
list
[
DeltaMessage
]
|
None
:
"""Process complete tool calls and convert to delta sequence.
logger
.
info
(
"vLLM Successfully import tool parser %s !"
,
self
.
__class__
.
__name__
)
Args:
current_text: Current accumulated text
start_pos: Start position (should be at <tool_call>)
end_pos: End position (after </tool_call>)
def
extract_tool_calls
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
self
.
parser
.
reset_streaming_state
()
# Reset tool call tracking arrays for new extraction
self
.
prev_tool_call_arr
=
[]
self
.
streamed_args_for_tool
=
[]
if
request
:
self
.
parser
.
set_tools
(
request
.
tools
)
result
=
self
.
parser
.
parse_single_streaming_chunks
(
model_output
)
if
not
result
.
tool_calls
:
return
ExtractedToolCallInformation
(
tool_calls
=
[],
tools_called
=
False
,
content
=
result
.
content
,
)
else
:
tool_calls
=
[]
for
tool_call
in
result
.
tool_calls
:
if
tool_call
.
function
and
tool_call
.
function
.
name
:
tool_calls
.
append
(
ToolCall
(
id
=
tool_call
.
id
,
type
=
tool_call
.
type
,
function
=
FunctionCall
(
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
),
)
)
Returns:
List of DeltaMessage if successful, None otherwise
"""
try
:
# Extract text segment containing complete tool call(s)
text_to_parse
=
current_text
[
start_pos
:
end_pos
]
# Update tool call tracking arrays for compatibility
tool_index
=
(
tool_call
.
index
if
tool_call
.
index
is
not
None
else
len
(
self
.
prev_tool_call_arr
)
-
1
# Parse using non-streaming method
result
=
self
.
extract_tool_calls_basic
(
text_to_parse
,
self
.
streaming_request
)
# Ensure we have enough entries in our tracking arrays
while
len
(
self
.
prev_tool_call_arr
)
<=
tool_index
:
self
.
prev_tool_call_arr
.
append
({
"name"
:
""
,
"arguments"
:
""
})
while
len
(
self
.
streamed_args_for_tool
)
<=
tool_index
:
self
.
streamed_args_for_tool
.
append
(
""
)
# Case 1: Successfully parsed tool calls
if
result
.
tools_called
and
result
.
tool_calls
:
# Note: Due to _find_first_complete_tool_call_end, we typically
# process only one tool call at a time
# but we can also process multiple tool calls below
deltas
=
self
.
_build_tool_call_deltas
(
result
.
tool_calls
,
text_to_parse
)
self
.
_update_state_after_tool_calls
(
result
.
tool_calls
,
end_pos
)
return
deltas
if
deltas
else
None
# Update tool call information
self
.
prev_tool_call_arr
[
tool_index
][
"name"
]
=
(
tool_call
.
function
.
name
)
self
.
prev_tool_call_arr
[
tool_index
][
"arguments"
]
=
(
tool_call
.
function
.
arguments
)
# Case 2: Parsing failed - treat as regular content
self
.
_processed_length
=
end_pos
return
[
DeltaMessage
(
content
=
text_to_parse
)]
# Update streamed arguments
if
tool_call
.
function
.
arguments
:
self
.
streamed_args_for_tool
[
tool_index
]
=
(
tool_call
.
function
.
arguments
)
except
Exception
as
e
:
# Exception during parsing - treat as content
logger
.
debug
(
"Failed to parse tool calls: %s, treating as content"
,
e
)
self
.
_processed_length
=
end_pos
failed_text
=
current_text
[
start_pos
:
end_pos
]
return
[
DeltaMessage
(
content
=
failed_text
)]
if
failed_text
else
None
return
ExtractedToolCallInformation
(
tool_calls
=
tool_calls
,
tools_called
=
len
(
tool_calls
)
>
0
,
content
=
result
.
content
,
)
def
_build_tool_call_deltas
(
self
,
tool_calls
:
list
[
ToolCall
],
parsed_text
:
str
)
->
list
[
DeltaMessage
]:
"""Build delta messages from parsed tool calls with interleaved content.
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
if
not
previous_text
:
self
.
parser
.
reset_streaming_state
()
# Reset tool call tracking arrays for new streaming session
self
.
prev_tool_call_arr
=
[]
self
.
streamed_args_for_tool
=
[]
if
request
:
self
.
parser
.
set_tools
(
request
.
tools
)
# Model sometimes outputs separately causing delta_text to be empty.
# If there were tool_calls before and all current tool_calls have ended,
# return an empty tool_call for outer streaming output
# to correctly output tool_call field
if
not
delta_text
and
delta_token_ids
:
open_calls
=
current_text
.
count
(
self
.
parser
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
parser
.
tool_call_end_token
)
if
(
open_calls
==
0
and
self
.
parser
.
tool_call_index
>
0
or
not
self
.
parser
.
tool_call_index
and
current_text
):
return
DeltaMessage
(
content
=
""
)
return
None
Args:
tool_calls: List of parsed tool calls
parsed_text: Original text that was parsed
# Parse the delta text and get the result
result
=
self
.
parser
.
parse_single_streaming_chunks
(
delta_text
)
# Update tool call tracking arrays based on incremental parsing results
if
result
and
result
.
tool_calls
:
for
tool_call
in
result
.
tool_calls
:
if
tool_call
.
function
:
tool_index
=
(
tool_call
.
index
if
tool_call
.
index
is
not
None
else
len
(
self
.
prev_tool_call_arr
)
-
1
Returns:
List of DeltaMessage with tool calls and content interleaved
"""
# Extract content segments between tool calls
content_segments
=
self
.
_extract_content_between_tool_calls_list
(
parsed_text
)
# Convert all tool calls to DeltaToolCall list
delta_tool_calls
=
self
.
_convert_tool_calls_to_deltas
(
tool_calls
,
self
.
_tool_call_index
)
# Ensure we have enough entries in our tracking arrays
while
len
(
self
.
prev_tool_call_arr
)
<=
tool_index
:
self
.
prev_tool_call_arr
.
append
({
"name"
:
""
,
"arguments"
:
""
})
while
len
(
self
.
streamed_args_for_tool
)
<=
tool_index
:
self
.
streamed_args_for_tool
.
append
(
""
)
# Merge all content segments into a single string
merged_content
=
""
.
join
(
content_segments
)
# Update tool name if provided
if
tool_call
.
function
.
name
:
self
.
prev_tool_call_arr
[
tool_index
][
"name"
]
=
(
tool_call
.
function
.
name
)
# Return a single DeltaMessage with all tool calls and content
# Build kwargs - only include non-empty fields
kwargs
:
dict
[
str
,
Any
]
=
{}
if
merged_content
:
kwargs
[
"content"
]
=
merged_content
if
delta_tool_calls
:
kwargs
[
"tool_calls"
]
=
delta_tool_calls
# Update arguments incrementally
if
tool_call
.
function
.
arguments
is
not
None
:
# Concatenate the incremental arguments
# to the existing streamed arguments
self
.
prev_tool_call_arr
[
tool_index
][
"arguments"
]
+=
(
tool_call
.
function
.
arguments
)
self
.
streamed_args_for_tool
[
tool_index
]
+=
(
tool_call
.
function
.
arguments
)
return
result
# Only return DeltaMessage if we have content or tool_calls
if
kwargs
:
return
[
DeltaMessage
(
**
kwargs
)]
else
:
return
[]
def
parser_should_check_for_unstreamed_tool_arg_tokens
(
self
)
->
bool
:
"""
Skip the remaining_call calculation in serving_chat
def
_update_state_after_tool_calls
(
self
,
tool_calls
:
list
[
ToolCall
],
end_pos
:
int
)
->
None
:
"""Update internal state after processing tool calls.
Args:
tool_calls: List of processed tool calls
end_pos: End position in buffer
"""
return
False
# Update processed position
self
.
_processed_length
=
end_pos
# Update tool call index
self
.
_tool_call_index
+=
len
(
tool_calls
)
# Update prev_tool_call_arr for finish_reason
self
.
_update_prev_tool_call_state
(
tool_calls
)
\ No newline at end of file
vllm/v1/core/kv_cache_utils.py
View file @
fcc9c9ea
...
...
@@ -7,6 +7,7 @@ import os
from
collections
import
defaultdict
from
collections.abc
import
Callable
,
Iterable
,
Iterator
,
Sequence
from
dataclasses
import
dataclass
,
replace
from
functools
import
partial
from
typing
import
Any
,
NewType
,
TypeAlias
,
overload
from
vllm
import
envs
...
...
@@ -947,6 +948,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def
_get_kv_cache_groups_uniform_page_size
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
)
->
list
[
KVCacheGroupSpec
]:
"""
...
...
@@ -1007,6 +1009,7 @@ def _get_kv_cache_groups_uniform_page_size(
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
...
...
@@ -1030,9 +1033,9 @@ def _get_kv_cache_groups_uniform_page_size(
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
min_num_layers
=
min
([
len
(
layers
)
for
layers
in
same_type_layers
.
values
()])
min_num_layers
=
min
([
len
(
layers
)
for
layers
in
same_type_layers
.
values
()])
#12
group_size
=
min_num_layers
max_num_layers
=
max
([
len
(
layers
)
for
layers
in
same_type_layers
.
values
()])
max_num_layers
=
max
([
len
(
layers
)
for
layers
in
same_type_layers
.
values
()])
#36
if
max_num_layers
<
min_num_layers
*
1.25
:
# If the number of layers is not much larger than the minimum number of layers,
# use the maximum number of layers as the group size to avoid too many padding
...
...
@@ -1050,6 +1053,15 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers
/
len
(
layers
)
*
100
,
)
num_groups
=
cdiv
(
len
(
layers
),
group_size
)
# for support multi layer mtp, we need to
# make all mtp layers in the same group
if
(
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
speculative_config
.
enable_multi_layers_mtp
):
for
i
in
range
(
0
,
len
(
layers
),
group_size
):
grouped_layers
.
append
(
layers
[
i
:
i
+
group_size
])
else
:
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
...
...
@@ -1120,7 +1132,6 @@ def get_kv_cache_config_from_groups(
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size
=
max
(
len
(
group
.
layer_names
)
for
group
in
kv_cache_groups
)
page_size
=
get_uniform_page_size
(
[
group
.
kv_cache_spec
for
group
in
kv_cache_groups
]
)
...
...
@@ -1247,8 +1258,10 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return
_get_kv_cache_groups_uniform_page_size
(
kv_cache_spec
)
# return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
return
_get_kv_cache_groups_uniform_page_size
(
vllm_config
=
vllm_config
,
kv_cache_spec
=
kv_cache_spec
)
def
generate_scheduler_kv_cache_config
(
kv_cache_configs
:
list
[
KVCacheConfig
],
...
...
@@ -1451,6 +1464,42 @@ def _auto_fit_max_model_len(
)
def
_project_kv_cache_groups_to_worker
(
global_kv_cache_groups
:
list
[
KVCacheGroupSpec
],
worker_spec
:
dict
[
str
,
KVCacheSpec
],
)
->
list
[
KVCacheGroupSpec
]:
"""
Projects global KV cache groups onto a single worker's assigned layers.
In pipeline parallelism, each worker only owns a subset of layers. This
function filters the global groups to include only layers present on the
given worker, adjusting UniformTypeKVCacheSpecs accordingly.
Args:
global_kv_cache_groups: The global KV cache groups for the whole model.
worker_spec: The KV cache spec of each layer on this worker.
Returns:
The projected KV cache groups containing only this worker's layers.
"""
projected_groups
:
list
[
KVCacheGroupSpec
]
=
[]
for
group
in
global_kv_cache_groups
:
worker_layer_names
=
[
layer_name
for
layer_name
in
group
.
layer_names
if
layer_name
in
worker_spec
]
group_spec
=
group
.
kv_cache_spec
if
worker_layer_names
and
isinstance
(
group_spec
,
UniformTypeKVCacheSpecs
):
group_spec
=
UniformTypeKVCacheSpecs
(
block_size
=
group_spec
.
block_size
,
kv_cache_specs
=
{
layer_name
:
group_spec
.
kv_cache_specs
[
layer_name
]
for
layer_name
in
worker_layer_names
},
)
projected_groups
.
append
(
KVCacheGroupSpec
(
worker_layer_names
,
group_spec
))
return
projected_groups
def
get_kv_cache_configs
(
vllm_config
:
VllmConfig
,
kv_cache_specs
:
list
[
dict
[
str
,
KVCacheSpec
]],
...
...
vllm/v1/cudagraph_dispatcher.py
View file @
fcc9c9ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Set
as
AbstractSet
from
dataclasses
import
replace
from
itertools
import
product
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
from
vllm.forward_context
import
BatchDescriptor
from
vllm.logger
import
init_logger
from
vllm.lora.utils
import
get_captured_lora_counts
logger
=
init_logger
(
__name__
)
...
...
@@ -57,6 +61,11 @@ class CudagraphDispatcher:
)
self
.
keys_initialized
=
False
self
.
specialize_lora_count
=
(
self
.
vllm_config
.
lora_config
.
specialize_active_lora
if
self
.
vllm_config
.
lora_config
is
not
None
else
False
)
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self
.
cudagraph_mode
=
CUDAGraphMode
.
NONE
...
...
@@ -64,6 +73,9 @@ class CudagraphDispatcher:
"""Pre-compute the mapping from batch size to padded graph size."""
max_size
=
self
.
compilation_config
.
max_cudagraph_capture_size
capture_sizes
=
self
.
compilation_config
.
cudagraph_capture_sizes
assert
capture_sizes
is
not
None
,
(
"Cudagraph capture sizes must be set when cudagraphs are enabled."
)
self
.
_bs_to_padded_graph_size
:
list
[
int
]
=
[
0
]
*
(
max_size
+
1
)
for
end
,
start
in
zip
(
capture_sizes
+
[
max_size
+
1
],
...
...
@@ -92,8 +104,33 @@ class CudagraphDispatcher:
"Use values from cudagraph_capture_sizes."
)
def
_get_lora_cases
(
self
)
->
list
[
int
]:
"""
Returns list of has_lora values for CUDA graph capture.
This is the single source of truth for LoRA capture cases.
"""
lora_config
=
self
.
vllm_config
.
lora_config
if
lora_config
is
None
:
# No LoRA configured - single case with no LoRA
return
[
0
]
# LoRA is enabled - capture graphs based on cudagraph_specialize_lora
if
self
.
compilation_config
.
cudagraph_specialize_lora
:
captured_counts
=
get_captured_lora_counts
(
lora_config
.
max_loras
,
self
.
specialize_lora_count
)
# Specialize: capture separate graphs for with and without LoRA
return
[
0
]
+
captured_counts
else
:
# No specialization: only capture graphs with LoRA active
return
[
lora_config
.
max_loras
+
1
]
def
_create_padded_batch_descriptor
(
self
,
num_tokens
:
int
,
uniform_decode
:
bool
,
has_lora
:
bool
self
,
num_tokens
:
int
,
uniform_decode
:
bool
,
has_lora
:
bool
,
num_active_loras
:
int
=
0
,
)
->
BatchDescriptor
:
max_num_seqs
=
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
uniform_decode_query_len
=
self
.
uniform_decode_query_len
...
...
@@ -111,6 +148,7 @@ class CudagraphDispatcher:
num_reqs
=
num_reqs
,
uniform
=
uniform_decode
,
has_lora
=
has_lora
,
num_active_loras
=
num_active_loras
,
)
def
add_cudagraph_key
(
...
...
@@ -143,18 +181,27 @@ class CudagraphDispatcher:
lora_cases
=
[
True
]
else
:
lora_cases
=
[
False
]
# Get LoRA cases to capture
# lora_cases = self._get_lora_cases()
self
.
captured_lora_counts
=
[
lora_count
for
lora_count
in
lora_cases
if
lora_count
]
# Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered.
if
cudagraph_mode
.
mixed_mode
()
!=
CUDAGraphMode
.
NONE
:
for
bs
,
has_lora
in
product
(
assert
self
.
compilation_config
.
cudagraph_capture_sizes
is
not
None
,
(
"Cudagraph capture sizes must be set when mixed mode is enabled."
)
for
bs
,
num_active_loras
in
product
(
self
.
compilation_config
.
cudagraph_capture_sizes
,
lora_cases
):
self
.
add_cudagraph_key
(
cudagraph_mode
.
mixed_mode
(),
self
.
_create_padded_batch_descriptor
(
bs
,
False
,
has
_lora
bs
,
False
,
num_active_loras
>
0
,
num_active
_lora
s
).
relax_for_mixed_batch_cudagraphs
(),
)
...
...
@@ -168,15 +215,20 @@ class CudagraphDispatcher:
uniform_decode_query_len
*
self
.
vllm_config
.
scheduler_config
.
max_num_seqs
)
assert
self
.
compilation_config
.
cudagraph_capture_sizes
is
not
None
,
(
"Cudagraph capture sizes must be set when full mode is enabled."
)
cudagraph_capture_sizes_for_decode
=
[
x
for
x
in
self
.
compilation_config
.
cudagraph_capture_sizes
if
x
<=
max_num_tokens
and
x
>=
uniform_decode_query_len
]
for
bs
,
has
_lora
in
product
(
cudagraph_capture_sizes_for_decode
,
lora_cases
):
for
bs
,
num_active
_lora
s
in
product
(
cudagraph_capture_sizes_for_decode
,
lora_cases
):
self
.
add_cudagraph_key
(
CUDAGraphMode
.
FULL
,
self
.
_create_padded_batch_descriptor
(
bs
,
True
,
has_lora
),
self
.
_create_padded_batch_descriptor
(
bs
,
True
,
num_active_loras
>
0
,
num_active_loras
),
)
self
.
keys_initialized
=
True
...
...
@@ -199,14 +251,19 @@ class CudagraphDispatcher:
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
valid_modes: Set of cudagraph modes that are allowed. None means
all modes are allowed.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
"""
# allowed_modes = valid_modes or CUDAGraphMode.valid_runtime_modes()
if
(
not
self
.
keys_initialized
or
self
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
or
num_tokens
>
self
.
compilation_config
.
max_cudagraph_capture_size
# or allowed_modes <= {CUDAGraphMode.NONE}
):
return
CUDAGraphMode
.
NONE
,
BatchDescriptor
(
num_tokens
)
...
...
vllm/v1/spec_decode/eagle.py
View file @
fcc9c9ea
...
...
@@ -3,6 +3,7 @@
import
ast
from
dataclasses
import
replace
from
importlib.util
import
find_spec
from
typing
import
Any
,
cast
import
numpy
as
np
import
torch
...
...
@@ -37,17 +38,21 @@ from vllm.v1.attention.backends.tree_attn import (
)
from
vllm.v1.attention.backends.triton_attn
import
TritonAttentionMetadata
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
UniformTypeKVCacheSpecs
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.sampler
import
_SAMPLING_EPS
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
MultiLayerEagleMetadata
,
SpecDecodeMetadata
from
vllm.v1.spec_decode.utils
import
(
extend_all_queries_by_N
,
compute_new_slot_mapping
,
copy_and_expand_eagle_inputs_kernel
,
eagle_prepare_inputs_padded_kernel
,
eagle_prepare_next_token_padded_kernel
,
)
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.utils
import
AttentionGroup
logger
=
init_logger
(
__name__
)
...
...
@@ -75,11 +80,33 @@ class SpecDecodeBaseProposer:
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
self
.
enable_multi_layers_mtp
=
self
.
speculative_config
.
enable_multi_layers_mtp
self
.
layer_num
=
1
# Unifying eagle, draft model, and parallel drafting support
self
.
parallel_drafting
:
bool
=
self
.
speculative_config
.
parallel_drafting
self
.
extra_slots_per_request
=
(
1
if
not
self
.
parallel_drafting
else
self
.
num_speculative_tokens
)
self
.
net_num_new_slots_per_request
=
self
.
extra_slots_per_request
-
(
1
if
self
.
pass_hidden_states_to_model
else
0
)
self
.
needs_extra_input_slots
=
self
.
net_num_new_slots_per_request
>
0
self
.
parallel_drafting_token_id
:
int
=
0
self
.
parallel_drafting_hidden_state_tensor
:
torch
.
Tensor
|
None
=
None
if
self
.
parallel_drafting
:
self
.
_init_parallel_drafting_params
()
self
.
use_local_argmax_reduction
:
bool
=
(
self
.
speculative_config
.
use_local_argmax_reduction
)
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
# self.max_num_tokens = (
# vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
# )
self
.
max_num_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
self
.
token_arange_np
=
np
.
arange
(
self
.
max_num_tokens
)
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
...
...
@@ -93,6 +120,9 @@ class SpecDecodeBaseProposer:
vllm_config
.
model_config
)
self
.
draft_attn_groups
:
list
[
AttentionGroup
]
=
[]
self
.
kv_cache_gid
:
int
=
-
1
self
.
attn_metadata_builder
:
AttentionMetadataBuilder
|
None
=
None
self
.
draft_indexer_metadata_builder
:
AttentionMetadataBuilder
|
None
=
None
self
.
attn_layer_names
:
list
[
str
]
=
[]
...
...
@@ -116,6 +146,8 @@ class SpecDecodeBaseProposer:
# Use draft model's M-RoPE setting, not target model's
# Draft models may be text-only even if target is multimodal
self
.
uses_mrope
=
self
.
draft_model_config
.
uses_mrope
self
.
uses_xdrope_dim
=
self
.
vllm_config
.
model_config
.
uses_xdrope_dim
self
.
draft_uses_xdrope_dim
=
self
.
draft_model_config
.
uses_xdrope_dim
if
self
.
uses_mrope
:
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
...
...
@@ -139,6 +171,9 @@ class SpecDecodeBaseProposer:
(
self
.
max_num_tokens
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
)
# Will be set when we initialize the attention backend
# self.block_size: int = -1
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_num_slots_for_arange
=
max
(
max_batch_size
+
1
,
self
.
max_num_tokens
)
...
...
@@ -146,6 +181,26 @@ class SpecDecodeBaseProposer:
max_num_slots_for_arange
,
device
=
device
,
dtype
=
torch
.
int32
)
if
self
.
needs_extra_input_slots
:
self
.
_raise_if_padded_drafter_batch_disabled
()
self
.
_raise_if_multimodal
()
self
.
_raise_if_mrope
()
self
.
is_rejected_token_mask
:
torch
.
Tensor
|
None
=
None
self
.
is_masked_token_mask
:
torch
.
Tensor
|
None
=
None
if
self
.
needs_extra_input_slots
:
# For draft models and parallel drafting, we need to keep track of
# which tokens are rejected to update the slot mapping with padding slots.
self
.
is_rejected_token_mask
=
torch
.
zeros
(
(
self
.
max_num_tokens
,),
dtype
=
torch
.
bool
,
device
=
device
)
# For parallel drafting, we also need to keep track of which tokens
# are parallel-padding tokens used to sample at later positions.
# We populate this tensor even when using draft models for simplicity.
self
.
is_masked_token_mask
=
torch
.
zeros
(
(
self
.
max_num_tokens
,),
dtype
=
torch
.
bool
,
device
=
device
)
self
.
inputs_embeds
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
inputs_embeds_size
),
dtype
=
self
.
dtype
,
...
...
@@ -166,36 +221,6 @@ class SpecDecodeBaseProposer:
# Determine allowed attention backends once during initialization.
self
.
allowed_attn_types
:
tuple
|
None
=
None
# if current_platform.is_rocm():
# from vllm.v1.attention.backends.rocm_attn import RocmAttentionMetadata
# rocm_types = [
# TritonAttentionMetadata,
# RocmAttentionMetadata,
# ]
# # ROCM_AITER_FA is an optional backend
# if find_spec(
# AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
# ):
# from vllm.v1.attention.backends.rocm_aiter_fa import (
# AiterFlashAttentionMetadata,
# )
# rocm_types.append(AiterFlashAttentionMetadata)
# # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
# from vllm.model_executor.layers.attention.mla_attention import (
# MLACommonMetadata,
# )
# rocm_types.append(MLACommonMetadata)
# # FlexAttention backend support
# from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadata
# rocm_types.append(FlexAttentionMetadata)
# self.allowed_attn_types = tuple(rocm_types)
# Parse the speculative token tree.
spec_token_tree
=
self
.
speculative_config
.
speculative_token_tree
...
...
@@ -251,7 +276,8 @@ class SpecDecodeBaseProposer:
self
.
_slot_mapping_buffer
[
num_actual
:
num_tokens
].
fill_
(
PADDING_SLOT_ID
)
view
=
self
.
_slot_mapping_buffer
[:
num_tokens
]
return
{
name
:
view
for
name
in
self
.
attn_layer_names
+
self
.
indexer_layer_names
}
# return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
return
{
name
:
view
for
name
in
self
.
_draft_attn_layer_names
}
def
initialize_cudagraph_keys
(
self
,
cudagraph_mode
:
CUDAGraphMode
)
->
None
:
"""Initialize cudagraph dispatcher keys for eagle.
...
...
@@ -270,6 +296,23 @@ class SpecDecodeBaseProposer:
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
eagle_cudagraph_mode
)
def
adjust_input
(
self
,
batch_size
:
int
,
target_token_ids
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
target_hidden_states
:
torch
.
Tensor
,
token_indices_to_sample
:
torch
.
Tensor
,
common_attn_metadata
:
CommonAttentionMetadata
,
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Any
]:
return
(
target_token_ids
,
target_positions
,
target_hidden_states
,
common_attn_metadata
,
)
def
propose
(
self
,
# [num_tokens]
...
...
@@ -280,9 +323,10 @@ class SpecDecodeBaseProposer:
target_hidden_states
:
torch
.
Tensor
,
# [batch_size]
next_token_ids
:
torch
.
Tensor
,
last_
token_indices
:
torch
.
Tensor
|
None
,
token_indices
_to_sample
:
torch
.
Tensor
|
None
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampling_metadata
:
SamplingMetadata
,
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
=
None
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
=
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
...
...
@@ -298,12 +342,28 @@ class SpecDecodeBaseProposer:
)
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
num_tokens
,
last_token_indices
,
common_attn_metadata
=
(
(
target_token_ids
,
target_positions
,
target_hidden_states
,
common_attn_metadata
,
)
=
self
.
adjust_input
(
batch_size
=
batch_size
,
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
token_indices_to_sample
=
token_indices_to_sample
,
common_attn_metadata
=
common_attn_metadata
,
multi_layer_eagle_metadata
=
multi_layer_eagle_metadata
,
)
num_tokens
,
token_indices_to_sample
,
common_attn_metadata
=
(
self
.
set_inputs_first_pass
(
target_token_ids
=
target_token_ids
,
next_token_ids
=
next_token_ids
,
target_positions
=
target_positions
,
last_token_indices
=
last_token_indices
,
target_hidden_states
=
target_hidden_states
,
token_indices_to_sample
=
token_indices_to_sample
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
)
...
...
@@ -355,6 +415,9 @@ class SpecDecodeBaseProposer:
# hidden dims. E.g. large target model and small draft model.
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
###### step3.5-mtp3新增
draft_token_ids_list
=
[]
for
spec_step_idx
in
range
(
self
.
layer_num
):
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
...
...
@@ -375,9 +438,13 @@ class SpecDecodeBaseProposer:
"positions"
:
self
.
_get_positions
(
num_input_tokens
),
"inputs_embeds"
:
inputs_embeds
,
}
if
self
.
pass_hidden_states_to_model
:
model_kwargs
[
"hidden_states"
]
=
self
.
hidden_states
[:
num_input_tokens
]
if
self
.
enable_multi_layers_mtp
:
model_kwargs
[
"spec_step_idx"
]
=
spec_step_idx
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
...
...
@@ -395,36 +462,65 @@ class SpecDecodeBaseProposer:
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
sample_hidden_states
=
last_hidden_states
[
token_indices_to_sample
]
if
self
.
enable_multi_layers_mtp
:
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
spec_step_idx
=
spec_step_idx
)
else
:
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
if
en
v
s
.
VLLM_REJECT_SAMPLE_OPT
:
return
draft_token_ids
.
view
(
-
1
,
1
),
draft_prob
.
view
(
-
1
,
1
,
logits
.
shape
[
-
1
]
)
# Generate the remaining draft tok
ens.
draft_token_ids
_list
.
append
(
draft_token_ids
)
return
draft_token_ids
.
view
(
-
1
,
1
)
if
spec_step_idx
<
self
.
layer_num
-
1
:
prev_token_ids
=
self
.
input_ids
[:
num_tokens
].
clone
()
hidden_states
=
hidden_states
[:
num_tokens
]
next_token_ids
=
draft_token_ids_list
[
-
1
].
int
()
num_tokens
,
token_indices_to_sample
,
common_attn_metadata
=
(
self
.
set_inputs_first_pass
(
target_token_ids
=
prev_token_ids
,
next_token_ids
=
next_token_ids
,
target_positions
=
target_positions
,
target_hidden_states
=
hidden_states
,
token_indices_to_sample
=
token_indices_to_sample
,
cad
=
common_attn_metadata
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
)
)
# Early exit if all draft tokens are generated in one pass
if
self
.
num_speculative_tokens
==
self
.
layer_num
or
self
.
parallel_drafting
:
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
##########################################################################
if
self
.
uses_mrope
:
positions
=
self
.
mrope_positions
[:,
last_
token_indices
]
positions
=
self
.
mrope_positions
[:,
token_indices
_to_sample
]
else
:
positions
=
self
.
positions
[
last_
token_indices
]
positions
=
self
.
positions
[
token_indices
_to_sample
]
if
self
.
method
in
(
"deepseek_mtp"
,
"ernie_mtp"
,
"longcat_flash_mtp"
,
"pangu_ultra_moe_mtp"
,
"step3p5_mtp"
,
# 新增
):
hidden_states
=
self
.
hidden_states
[
last_
token_indices
]
hidden_states
=
self
.
hidden_states
[
token_indices
_to_sample
]
else
:
hidden_states
=
hidden_states
[
last_
token_indices
]
hidden_states
=
hidden_states
[
token_indices
_to_sample
]
if
isinstance
(
attn_metadata
,
TreeAttentionMetadata
):
######
if
self
.
enable_multi_layers_mtp
:
raise
NotImplementedError
(
"Speculative Decoding with multi-layer MTP and tree attention "
"is not supported yet."
)
#####
# Draft using tree attention.
draft_token_ids_list
=
self
.
propose_tree
(
batch_size
=
batch_size
,
...
...
@@ -437,32 +533,22 @@ class SpecDecodeBaseProposer:
# [batch_size, num_tree_tokens]
return
torch
.
cat
(
draft_token_ids_list
,
dim
=
1
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
#
draft_token_ids = logits.argmax(dim=-1)
if
self
.
allowed_attn_types
is
not
None
and
not
isinstance
(
attn_metadata
,
self
.
allowed_attn_types
):
raise
ValueError
(
f
"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens >
1
: "
"decoding with num_speculative_tokens >
layer_num
: "
f
"
{
type
(
attn_metadata
)
}
. Supported types are: "
f
"
{
self
.
allowed_attn_types
}
"
)
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
batch_size_dp_padded
,
batch_size_across_dp
=
self
.
_pad_batch_across_dp
(
num_tokens_unpadded
=
batch_size
,
num_tokens_padded
=
batch_size
cudagraph_runtime_mode
,
input_batch_size
,
batch_size_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
batch_size
)
)
cudagraph_runtime_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
batch_size_dp_padded
)
input_batch_size
=
batch_desc
.
num_tokens
if
batch_size_across_dp
is
not
None
:
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
max_query_len
=
1
common_attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
...
...
@@ -483,7 +569,7 @@ class SpecDecodeBaseProposer:
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_probs_list
=
[
draft_prob
]
for
token_index
in
range
(
self
.
num_speculative_tokens
-
1
):
for
token_index
in
range
(
self
.
num_speculative_tokens
-
self
.
layer_num
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
...
...
@@ -562,23 +648,9 @@ class SpecDecodeBaseProposer:
attn_metadata
=
attn_metadata_builder
.
build_for_drafting
(
# type: ignore
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
token_index
+
1
)
if
self
.
draft_indexer_metadata_builder
:
draft_indexer_metadata
=
(
self
.
draft_indexer_metadata_builder
.
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
token_index
+
1
,
)
)
else
:
draft_indexer_metadata
=
None
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
for
layer_name
in
self
.
indexer_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
_set_positions
(
batch_size
,
clamped_positions
)
...
...
@@ -641,12 +713,17 @@ class SpecDecodeBaseProposer:
target_token_ids
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
last_token_indices
:
torch
.
Tensor
|
None
,
target_hidden_states
:
torch
.
Tensor
,
token_indices_to_sample
:
torch
.
Tensor
|
None
,
cad
:
CommonAttentionMetadata
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
,
)
->
tuple
[
int
,
torch
.
Tensor
,
CommonAttentionMetadata
]:
if
last_token_indices
is
None
:
last_token_indices
=
cad
.
query_start_loc
[
1
:]
-
1
if
not
self
.
needs_extra_input_slots
:
# Default EAGLE pathway: no reshaping of input tensors needed.
# Simply rotate the input ids and leave the positions unchanged,
# Inserting the next token ids at the last slot in each request.
if
token_indices_to_sample
is
None
:
token_indices_to_sample
=
cad
.
query_start_loc
[
1
:]
-
1
num_tokens
=
target_token_ids
.
shape
[
0
]
# Shift the input ids by one token.
...
...
@@ -654,12 +731,120 @@ class SpecDecodeBaseProposer:
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self
.
input_ids
[
last_
token_indices
]
=
next_token_ids
self
.
input_ids
[
token_indices
_to_sample
]
=
next_token_ids
# copy inputs to buffer for cudagraph
if
self
.
uses_xdrope_dim
>
0
and
self
.
draft_uses_xdrope_dim
==
0
:
target_positions
=
target_positions
[
0
]
self
.
_set_positions
(
num_tokens
,
target_positions
)
return
num_tokens
,
last_token_indices
,
cad
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
return
num_tokens
,
token_indices_to_sample
,
cad
else
:
assert
self
.
is_rejected_token_mask
is
not
None
assert
self
.
is_masked_token_mask
is
not
None
# 1.
# Call a custom triton kernel to copy input_ids and positions
# into the correct slots in the preallocated buffers self.input_ids,
# self.positions.
batch_size
=
cad
.
batch_size
()
# Since we might have to copy a lot of data for prefills, we select the
# block size based on the max query length and limit to max 256 slots/block.
max_num_tokens_per_request
=
(
cad
.
max_query_len
+
self
.
net_num_new_slots_per_request
)
BLOCK_SIZE_TOKENS
=
min
(
256
,
triton
.
next_power_of_2
(
max_num_tokens_per_request
)
)
num_blocks
=
(
max_num_tokens_per_request
+
BLOCK_SIZE_TOKENS
-
1
)
//
BLOCK_SIZE_TOKENS
total_num_input_tokens
=
target_token_ids
.
shape
[
0
]
total_num_output_tokens
=
total_num_input_tokens
+
(
self
.
net_num_new_slots_per_request
*
batch_size
)
token_indices_to_sample
=
torch
.
empty
(
batch_size
*
self
.
extra_slots_per_request
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
# Destination indices to write target_hidden_states into drafting buffer.
out_hidden_state_mapping
=
torch
.
empty
(
total_num_input_tokens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# Kernel grid: one program per request (row)
grid
=
(
batch_size
,
num_blocks
)
query_start_loc
=
cad
.
query_start_loc
query_end_loc
=
cad
.
query_start_loc
[
1
:]
-
1
if
num_rejected_tokens_gpu
is
not
None
:
query_end_loc
=
query_end_loc
-
num_rejected_tokens_gpu
copy_and_expand_eagle_inputs_kernel
[
grid
](
# (Padded) Inputs from the target model
target_token_ids_ptr
=
target_token_ids
,
target_positions_ptr
=
target_positions
,
next_token_ids_ptr
=
next_token_ids
,
# sampled tokens, one per request
# Outputs to the drafting buffers
out_input_ids_ptr
=
self
.
input_ids
,
out_positions_ptr
=
self
.
positions
,
# Doesn't support mrope for now
out_is_rejected_token_mask_ptr
=
self
.
is_rejected_token_mask
,
out_is_masked_token_mask_ptr
=
self
.
is_masked_token_mask
,
out_new_token_indices_ptr
=
token_indices_to_sample
,
out_hidden_state_mapping_ptr
=
out_hidden_state_mapping
,
# Input metadata
query_start_loc_ptr
=
query_start_loc
,
query_end_loc_ptr
=
query_end_loc
,
padding_token_id
=
0
,
parallel_drafting_token_id
=
self
.
parallel_drafting_token_id
,
# Sizing info
# Note that we can deduce batch_size for free from the grid size
total_input_tokens
=
total_num_input_tokens
,
num_padding_slots_per_request
=
self
.
extra_slots_per_request
,
shift_input_ids
=
self
.
pass_hidden_states_to_model
,
BLOCK_SIZE_TOKENS
=
BLOCK_SIZE_TOKENS
,
)
if
self
.
pass_hidden_states_to_model
:
assert
self
.
parallel_drafting_hidden_state_tensor
is
not
None
self
.
hidden_states
[
out_hidden_state_mapping
]
=
target_hidden_states
# Use torch.where to avoid DtoH sync from boolean indexing
mask
=
self
.
is_masked_token_mask
[:
total_num_output_tokens
]
torch
.
where
(
mask
.
unsqueeze
(
1
),
self
.
parallel_drafting_hidden_state_tensor
,
self
.
hidden_states
[:
total_num_output_tokens
],
out
=
self
.
hidden_states
[:
total_num_output_tokens
],
)
# 2.
# Recompute the slot mapping based on the new positions and
# rejection mask.
# Use the first draft attention group's kv_cache_spec for block_size
# (all draft layers share the same kv-cache group)
assert
len
(
self
.
draft_attn_groups
)
>
0
block_size
=
self
.
draft_attn_groups
[
0
].
kv_cache_spec
.
block_size
new_slot_mapping
=
compute_new_slot_mapping
(
cad
=
cad
,
new_positions
=
self
.
positions
[:
total_num_output_tokens
],
is_rejected_token_mask
=
self
.
is_rejected_token_mask
[
:
total_num_output_tokens
],
block_size
=
block_size
,
num_new_tokens
=
self
.
net_num_new_slots_per_request
,
max_model_len
=
self
.
max_model_len
,
)
# 3. Update the common attention metadata with the new (meta)data
new_cad
=
extend_all_queries_by_N
(
cad
,
N
=
self
.
net_num_new_slots_per_request
,
arange
=
self
.
arange
,
new_slot_mapping
=
new_slot_mapping
,
)
return
total_num_output_tokens
,
token_indices_to_sample
,
new_cad
def
model_returns_tuple
(
self
)
->
bool
:
return
self
.
method
not
in
(
"mtp"
,
"draft_model"
)
...
...
@@ -1096,10 +1281,28 @@ class SpecDecodeBaseProposer:
model
=
model
.
module
return
model
.
__class__
.
__name__
def
_get_model
(
self
)
->
nn
.
Module
:
"""
Default method to call get_model(). Can be overridden by subclasses which
need to customize model loading.
"""
from
vllm.compilation.backends
import
set_model_tag
with
set_model_tag
(
"eagle_head"
):
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
self
.
speculative_config
.
draft_model_config
,
# load_config=self.speculative_config.draft_load_config,
)
return
model
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
draft_model_config
=
self
.
vllm_config
.
speculative_config
.
draft_model_config
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
).
keys
()
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
,
# type: ignore[type-abstract]
).
keys
()
)
# FIXME: support hybrid kv for draft model
target_indexer_layer_names
=
set
(
...
...
@@ -1107,23 +1310,26 @@ class SpecDecodeBaseProposer:
self
.
vllm_config
,
DeepseekV32IndexerCache
).
keys
()
)
self
.
model
=
self
.
_get_model
()
from
vllm.compilation.backends
import
set_model_tag
with
set_model_tag
(
"eagle_head"
):
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
draft_model_config
)
draft_attn_layer_names
=
(
# Find draft layers (attention layers added by draft model)
# all_attn_layers = get_layers_from_vllm_config(
# self.vllm_config,
# AttentionLayerBase, # type: ignore[type-abstract]
# )
# self._draft_attn_layer_names = (
# set(all_attn_layers.keys()) - target_attn_layer_names
# )
self
.
_draft_attn_layer_names
=
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
).
keys
()
-
target_attn_layer_names
)
indexer_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
DeepseekV32IndexerCache
)
draft_indexer_layer_names
=
indexer_layers
.
keys
()
-
target_indexer_layer_names
self
.
attn_layer_names
=
list
(
draft_attn_layer_names
-
draft_indexer_layer_names
)
self
.
attn_layer_names
=
list
(
self
.
_
draft_attn_layer_names
-
draft_indexer_layer_names
)
self
.
indexer_layer_names
=
list
(
draft_indexer_layer_names
)
if
self
.
indexer_layer_names
:
...
...
@@ -1160,6 +1366,7 @@ class SpecDecodeBaseProposer:
"Qwen2_5_VLForConditionalGeneration"
,
"Qwen3VLForConditionalGeneration"
,
"Qwen3VLMoeForConditionalGeneration"
,
"HunYuanVLForConditionalGeneration"
,
"GlmOcrForConditionalGeneration"
,
"Qwen3_5ForConditionalGeneration"
,
"Qwen3_5MoeForConditionalGeneration"
,
...
...
@@ -1177,12 +1384,34 @@ class SpecDecodeBaseProposer:
else
:
target_language_model
=
target_model
# share embed_tokens with the target model if needed
self
.
_maybe_share_embeddings
(
target_language_model
)
self
.
_maybe_share_lm_head
(
target_language_model
)
if
self
.
parallel_drafting
and
self
.
pass_hidden_states_to_model
:
assert
self
.
parallel_drafting_hidden_state_tensor
is
not
None
self
.
parallel_drafting_hidden_state_tensor
.
copy_
(
self
.
model
.
combine_hidden_states
(
self
.
model
.
mask_hidden
.
view
(
3
*
self
.
hidden_size
)
)
if
self
.
eagle3_use_aux_hidden_state
else
self
.
model
.
mask_hidden
.
view
(
self
.
hidden_size
)
)
def
_maybe_share_embeddings
(
self
,
target_language_model
:
nn
.
Module
)
->
None
:
"""
Some draft models may not have their own embedding layers, and some may
have a duplicate copy of the target model's embedding layers. In these cases,
we share the target model's embedding layers with the draft model to save
memory.
"""
if
get_pp_group
().
world_size
==
1
:
if
hasattr
(
target_language_model
.
model
,
"embed_tokens"
):
target_embed_tokens
=
target_language_model
.
model
.
embed_tokens
elif
hasattr
(
target_language_model
.
model
,
"embedding"
):
target_embed_tokens
=
target_language_model
.
model
.
embedding
inner_model
=
getattr
(
target_language_model
,
"model"
,
None
)
if
inner_model
is
None
:
raise
AttributeError
(
"Target model does not have 'model' attribute"
)
if
hasattr
(
inner_model
,
"embed_tokens"
):
target_embed_tokens
=
inner_model
.
embed_tokens
elif
hasattr
(
inner_model
,
"embedding"
):
target_embed_tokens
=
inner_model
.
embedding
else
:
raise
AttributeError
(
"Target model does not have 'embed_tokens' or 'embedding' attribute"
...
...
@@ -1237,7 +1466,12 @@ class SpecDecodeBaseProposer:
" from the target model."
)
# share lm_head with the target model if needed
def
_maybe_share_lm_head
(
self
,
target_language_model
:
nn
.
Module
)
->
None
:
"""
Some draft models may not have their own LM head, and some may have a
duplicate copy of the target model's LM head. In these cases, we share
the target model's LM head with the draft model to save memory.
"""
share_lm_head
=
False
if
hasattr
(
self
.
model
,
"has_own_lm_head"
):
# EAGLE model
...
...
@@ -1299,6 +1533,32 @@ class SpecDecodeBaseProposer:
"Shared target model lm_head with MTP shared_head.head."
)
if
self
.
use_local_argmax_reduction
:
if
not
hasattr
(
self
.
model
,
"get_top_tokens"
):
raise
ValueError
(
"use_local_argmax_reduction is enabled but draft model "
f
"
{
self
.
model
.
__class__
.
__name__
}
does not implement "
"get_top_tokens()."
)
# Warn if draft model has vocab remapping, which forces fallback
# to the full-logits path (negating the optimization).
if
(
hasattr
(
self
.
model
,
"draft_id_to_target_id"
)
and
self
.
model
.
draft_id_to_target_id
is
not
None
):
logger
.
warning
(
"use_local_argmax_reduction is enabled but draft model "
"uses draft_id_to_target_id vocab remapping. The "
"optimization will be bypassed (falling back to full "
"logits gather + argmax)."
)
else
:
logger
.
info
(
"Using local argmax reduction for draft token generation "
"(communication: O(2*tp_size) vs O(vocab_size))."
)
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
...
...
@@ -1329,9 +1589,9 @@ class SpecDecodeBaseProposer:
# Make sure to use EAGLE's own buffer during cudagraph capture.
if
(
self
.
attn_layer_names
self
.
_draft_
attn_layer_names
and
slot_mappings
is
not
None
and
self
.
attn_layer_names
[
0
]
in
slot_mappings
and
next
(
iter
(
self
.
_draft_
attn_layer_names
))
in
slot_mappings
):
slot_mapping_dict
=
self
.
_get_slot_mapping
(
num_input_tokens
)
else
:
...
...
@@ -1425,6 +1685,64 @@ class SpecDecodeBaseProposer:
==
1
),
"All drafting layers should belong to the same kv cache group"
# def initialize_attn_backend(
# self,
# kv_cache_config: KVCacheConfig,
# kernel_block_sizes: list[int] | None = None,
# ) -> None:
# """
# Initialize AttentionGroups for draft layers using kv_cache_config.
# Called from the model runner's initialize_metadata_builders.
# """
# all_attn_layers = get_layers_from_vllm_config(
# self.vllm_config,
# AttentionLayerBase, # type: ignore[type-abstract]
# )
# # Find which kv_cache_group the draft layers belong to
# self.validate_same_kv_cache_group(kv_cache_config)
# kv_cache_spec = None
# for gid, group in enumerate(kv_cache_config.kv_cache_groups):
# if self._draft_attn_layer_names & set(group.layer_names):
# self.kv_cache_gid = gid
# kv_cache_spec = group.kv_cache_spec
# break
# attention_groups: dict[tuple[str, str], AttentionGroup] = {}
# if kv_cache_spec is not None:
# for layer_name in self._draft_attn_layer_names:
# attn_backend = all_attn_layers[layer_name].get_attn_backend()
# backend_key = attn_backend.full_cls_name()
# if backend_key not in attention_groups:
# layer_kv_cache_spec = kv_cache_spec
# if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
# layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
# layer_name
# ]
# kernel_block_size = (
# kernel_block_sizes[self.kv_cache_gid]
# if kernel_block_sizes is not None
# and self.kv_cache_gid < len(kernel_block_sizes)
# else None
# )
# attn_group = AttentionGroup(
# backend=attn_backend,
# layer_names=[layer_name],
# kv_cache_spec=layer_kv_cache_spec,
# kv_cache_group_id=self.kv_cache_gid,
# )
# attn_group.create_metadata_builders(
# self.vllm_config,
# self.device,
# kernel_block_size=kernel_block_size,
# )
# attention_groups[backend_key] = attn_group
# else:
# attention_groups[backend_key].layer_names.append(layer_name)
# self.draft_attn_groups = list(attention_groups.values())
def
_pad_batch_across_dp
(
self
,
num_tokens_unpadded
:
int
,
...
...
@@ -1449,6 +1767,50 @@ class SpecDecodeBaseProposer:
return
num_tokens_dp_padded
,
num_toks_across_dp
def
_determine_batch_execution_and_padding
(
self
,
num_tokens
:
int
,
use_cudagraphs
:
bool
=
True
,
)
->
tuple
[
CUDAGraphMode
,
int
,
torch
.
Tensor
|
None
]:
cudagraph_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens
,
)
num_tokens_padded
=
batch_desc
.
num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch
,
num_tokens_across_dp
=
False
,
None
if
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
should_ubatch
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
coordinate_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
parallel_config
=
self
.
vllm_config
.
parallel_config
,
allow_microbatching
=
False
,
num_tokens_padded
=
num_tokens_padded
,
cudagraph_mode
=
cudagraph_mode
.
value
,
)
)
assert
not
should_ubatch
,
"DBO ubatching not implemented for EAGLE"
# Extract DP-synced values
if
num_tokens_across_dp
is
not
None
:
dp_rank
=
self
.
dp_rank
num_tokens_padded
=
int
(
num_tokens_across_dp
[
dp_rank
].
item
())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_padded
,
valid_modes
=
{
CUDAGraphMode
(
synced_cudagraph_mode
)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert
batch_desc
.
num_tokens
==
num_tokens_padded
num_tokens_across_dp
[
dp_rank
]
=
num_tokens_padded
return
cudagraph_mode
,
num_tokens_padded
,
num_tokens_across_dp
class
EagleProposer
(
SpecDecodeBaseProposer
):
def
__init__
(
self
,
...
...
vllm/v1/spec_decode/extract_hidden_states.py
0 → 100755
View file @
fcc9c9ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
contextlib
import
nullcontext
from
typing
import
TYPE_CHECKING
import
torch
import
torch.nn
as
nn
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_transfer
import
has_kv_transfer_group
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.model_loader
import
get_model
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
,
CommonAttentionMetadata
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
PADDING_SLOT_ID
=
-
1
class
ExtractHiddenStatesProposer
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
):
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
num_speculative_tokens
==
1
if
vllm_config
.
speculative_config
.
disable_padded_drafter_batch
:
raise
ValueError
(
"disable_padded_drafter_batch is not supported with "
"extract_hidden_states method"
)
self
.
vllm_config
=
vllm_config
self
.
device
=
device
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
# Model and attention layer tracking (initialized in load_model)
self
.
model
:
nn
.
Module
|
None
=
None
self
.
attn_layer_names
:
list
[
str
]
=
[]
self
.
attn_metadata_builder
:
AttentionMetadataBuilder
|
None
=
None
# Maximum number of tokens for buffers
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
self
.
hf_config
=
vllm_config
.
speculative_config
.
draft_model_config
.
hf_config
layer_ids
=
getattr
(
self
.
hf_config
,
"eagle_aux_hidden_state_layer_ids"
,
None
)
if
not
layer_ids
:
raise
ValueError
(
"eagle_aux_hidden_state_layer_ids must be set in the draft "
"model config for extract_hidden_states method"
)
self
.
num_hidden_states
=
len
(
layer_ids
)
self
.
hidden_size
=
vllm_config
.
model_config
.
get_hidden_size
()
self
.
hidden_states
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
self
.
num_hidden_states
,
self
.
hidden_size
),
dtype
=
self
.
dtype
,
device
=
device
,
)
self
.
cudagraph_dispatcher
=
CudagraphDispatcher
(
self
.
vllm_config
)
self
.
_slot_mapping_buffer
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
)
def
propose
(
self
,
sampled_token_ids
:
torch
.
Tensor
,
target_hidden_states
:
list
[
torch
.
Tensor
],
common_attn_metadata
:
CommonAttentionMetadata
,
scheduler_output
:
SchedulerOutput
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
KVConnectorOutput
|
None
]:
"""Propose draft tokens by calling the ExtractHiddenStatesModel model.
The ExtractHiddenStatesModel caches the hidden states in the KV cache
without performing actual attention computation. This allows us to
extract and store hidden states for later use (e.g., KV transfer).
This proposer doesn't actually perform speculation - it returns the
sampled tokens as "draft" tokens, ensuring they always verify (match).
The main purpose is to cache hidden states, not to speculate.
Args:
sampled_token_ids: Sampled token IDs from the target model
target_hidden_states: List of hidden state tensors from target model
(one per aux hidden state layer)
common_attn_metadata: Attention metadata
scheduler_output: Scheduler output for KV connector
slot_mappings: Slot mappings for KV cache (unused, provided for
interface compatibility)
Returns:
Tuple of:
- Draft tokens matching sampled tokens, shape [batch_size, 1]
- KV connector output (if KV transfer is active), else None
"""
assert
self
.
model
is
not
None
and
isinstance
(
target_hidden_states
,
list
)
# target_hidden_states is a list of tensors (one per layer)
# Each tensor has shape [num_tokens, hidden_size]
# Stack to shape: [num_tokens, num_hidden_states, hidden_size]
stacked_hidden_states
=
torch
.
stack
(
target_hidden_states
,
dim
=
1
)
num_tokens
=
stacked_hidden_states
.
shape
[
0
]
# Copy hidden states to buffer
self
.
hidden_states
[:
num_tokens
]
=
stacked_hidden_states
assert
self
.
attn_metadata_builder
is
not
None
attn_metadata
=
self
.
attn_metadata_builder
.
build_for_drafting
(
common_attn_metadata
=
common_attn_metadata
,
draft_index
=
0
)
# We assume all cache-only layers belong to the same KV cache group,
# thus using the same attention metadata.
per_layer_attn_metadata
=
{}
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
)
)
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
with
(
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
),
),
(
KVConnectorModelRunnerMixin
.
_get_kv_connector_output
(
scheduler_output
)
if
has_kv_transfer_group
()
else
nullcontext
()
)
as
kv_connector_output
,
):
self
.
model
(
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
)
# Return the sampled tokens as "draft" tokens
# Shape: [batch_size, 1] to match num_speculative_tokens=1
return
sampled_token_ids
.
unsqueeze
(
-
1
),
kv_connector_output
def
_get_slot_mapping
(
self
,
num_tokens
:
int
,
slot_mapping
:
torch
.
Tensor
|
None
=
None
,
)
->
dict
[
str
,
torch
.
Tensor
]:
"""Return slot_mapping dict for cache-only attention layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if
slot_mapping
is
not
None
:
num_actual
=
slot_mapping
.
shape
[
0
]
self
.
_slot_mapping_buffer
[:
num_actual
].
copy_
(
slot_mapping
)
if
num_tokens
>
num_actual
:
self
.
_slot_mapping_buffer
[
num_actual
:
num_tokens
].
fill_
(
PADDING_SLOT_ID
)
view
=
self
.
_slot_mapping_buffer
[:
num_tokens
]
return
{
name
:
view
for
name
in
self
.
attn_layer_names
}
def
_determine_batch_execution_and_padding
(
self
,
num_tokens
:
int
,
use_cudagraphs
:
bool
=
True
,
)
->
tuple
[
CUDAGraphMode
,
int
,
torch
.
Tensor
|
None
]:
cudagraph_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens
,
valid_modes
=
({
CUDAGraphMode
.
NONE
}
if
not
use_cudagraphs
else
None
),
)
num_tokens_padded
=
batch_desc
.
num_tokens
# Extra coordination when running data-parallel since we need to
# coordinate across ranks
# TODO(Flechman): support DBO ubatching
should_ubatch
,
num_tokens_across_dp
=
False
,
None
if
self
.
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
should_ubatch
,
num_tokens_across_dp
,
synced_cudagraph_mode
=
(
coordinate_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
parallel_config
=
self
.
vllm_config
.
parallel_config
,
allow_microbatching
=
False
,
num_tokens_padded
=
num_tokens_padded
,
cudagraph_mode
=
cudagraph_mode
.
value
,
)
)
assert
not
should_ubatch
,
(
"DBO ubatching not implemented for extract_hidden_states"
)
# Extract DP-synced values
if
num_tokens_across_dp
is
not
None
:
dp_rank
=
self
.
dp_rank
num_tokens_padded
=
int
(
num_tokens_across_dp
[
dp_rank
].
item
())
# Re-dispatch with DP padding so we have the correct
# batch_descriptor
cudagraph_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_padded
,
valid_modes
=
{
CUDAGraphMode
(
synced_cudagraph_mode
)},
)
# Assert to make sure the agreed upon token count is correct
# otherwise num_tokens_across_dp will no-longer be valid
assert
batch_desc
.
num_tokens
==
num_tokens_padded
num_tokens_across_dp
[
dp_rank
]
=
num_tokens_padded
return
cudagraph_mode
,
num_tokens_padded
,
num_tokens_across_dp
def
initialize_cudagraph_keys
(
self
,
cudagraph_mode
:
CUDAGraphMode
)
->
None
:
"""Initialize cudagraph dispatcher keys.
Only supports PIECEWISE cudagraphs (via mixed_mode).
Should be called after adjust_cudagraph_sizes_for_spec_decode.
"""
assert
self
.
vllm_config
.
speculative_config
is
not
None
if
(
not
self
.
vllm_config
.
speculative_config
.
enforce_eager
and
cudagraph_mode
.
mixed_mode
()
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
]
):
proposer_cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
else
:
proposer_cudagraph_mode
=
CUDAGraphMode
.
NONE
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
proposer_cudagraph_mode
)
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
num_tokens
:
int
,
use_cudagraphs
:
bool
=
True
,
is_graph_capturing
:
bool
=
False
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
)
->
None
:
assert
self
.
model
is
not
None
,
"Model must be initialized before dummy_run"
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
,
use_cudagraphs
=
use_cudagraphs
)
)
if
num_tokens_across_dp
is
not
None
:
num_tokens_across_dp
[
self
.
dp_rank
]
=
num_input_tokens
# Use our own slot mapping buffer during cudagraph capture.
if
(
self
.
attn_layer_names
and
slot_mappings
is
not
None
and
self
.
attn_layer_names
[
0
]
in
slot_mappings
):
slot_mapping_dict
=
self
.
_get_slot_mapping
(
num_input_tokens
)
else
:
slot_mapping_dict
=
slot_mappings
or
{}
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
slot_mapping_dict
,
):
self
.
model
(
hidden_states
=
self
.
hidden_states
[:
num_input_tokens
],
)
def
_build_attn_metadata_builder
(
self
,
draft_attn_layers
:
dict
[
str
,
AttentionLayerBase
]
)
->
AttentionMetadataBuilder
:
"""Build the attention metadata builder from draft attention layers."""
if
not
draft_attn_layers
:
raise
ValueError
(
"No attention layers found for ExtractHiddenStatesModel"
)
layer
=
next
(
iter
(
draft_attn_layers
.
values
()))
attn_backend
=
layer
.
get_attn_backend
()
return
attn_backend
.
get_builder_cls
()(
layer
.
get_kv_cache_spec
(
self
.
vllm_config
),
self
.
attn_layer_names
,
self
.
vllm_config
,
self
.
device
,
)
def
prepare_next_token_ids_padded
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampled_token_ids
:
torch
.
Tensor
,
requests
:
dict
[
str
,
CachedRequestState
],
gpu_input_batch
:
InputBatch
,
discard_request_mask
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Prepare next token IDs for speculative decoding.
Since num_speculative_tokens == 1, sampled_token_ids has shape
(batch_size, 1). For each request we either use the sampled token
(if valid and not discarded) or a backup token from the request state.
"""
num_reqs
=
gpu_input_batch
.
num_reqs
device
=
sampled_token_ids
.
device
# Compute backup tokens for discarded / invalid requests
backup_tokens_gpu
=
torch
.
tensor
(
[
requests
[
gpu_input_batch
.
req_ids
[
i
]].
get_token_id
(
common_attn_metadata
.
seq_lens_cpu
[
i
].
item
()
)
for
i
in
range
(
num_reqs
)
],
dtype
=
torch
.
int32
,
device
=
device
,
)
assert
discard_request_mask
.
dtype
==
torch
.
bool
# With num_speculative_tokens == 1, there is exactly one token
sampled
=
sampled_token_ids
[:,
0
]
is_valid
=
(
sampled
>=
0
)
&
(
sampled
<
gpu_input_batch
.
vocab_size
)
valid_sampled_tokens_count
=
is_valid
.
to
(
torch
.
int32
)
use_sampled
=
is_valid
&
~
discard_request_mask
[:
num_reqs
]
next_token_ids
=
torch
.
where
(
use_sampled
,
sampled
.
to
(
torch
.
int32
),
backup_tokens_gpu
)
return
next_token_ids
,
valid_sampled_tokens_count
def
load_model
(
self
,
target_model
:
nn
.
Module
)
->
None
:
"""Load the ExtractHiddenStatesModel model.
This method instantiates the ExtractHiddenStatesModel model which is used
to cache hidden states during speculative decoding. The model uses
cache-only attention (no computation, just caching KV states).
Args:
target_model: The target model (passed for compatibility with
EagleProposer interface, but not used here)
"""
# Get the target model's attention layers before loading draft model
target_attn_layer_names
=
set
(
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
).
keys
()
# type: ignore[type-abstract]
)
assert
self
.
vllm_config
.
speculative_config
is
not
None
draft_model_config
=
self
.
vllm_config
.
speculative_config
.
draft_model_config
from
vllm.compilation.backends
import
set_model_tag
with
set_model_tag
(
"extract_hidden_states"
):
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
,
model_config
=
draft_model_config
)
# Identify draft model's attention layers (difference from target)
all_attn_layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
AttentionLayerBase
,
# type: ignore[type-abstract]
)
draft_attn_layers
=
{
name
:
layer
for
name
,
layer
in
all_attn_layers
.
items
()
if
name
not
in
target_attn_layer_names
}
self
.
attn_layer_names
=
list
(
draft_attn_layers
.
keys
())
assert
len
(
draft_attn_layers
)
==
1
,
(
"ExtractHiddenStatesModel should have exactly one "
f
"attention layer, found
{
len
(
draft_attn_layers
)
}
"
)
self
.
attn_metadata_builder
=
self
.
_build_attn_metadata_builder
(
draft_attn_layers
)
def
validate_same_kv_cache_group
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""Validate all drafting layers belong to the same KV cache group.
With exactly one attention layer (asserted in load_model), this is
trivially satisfied.
"""
assert
len
(
self
.
attn_layer_names
)
==
1
vllm/v1/spec_decode/metadata.py
View file @
fcc9c9ea
...
...
@@ -67,3 +67,41 @@ class SpecDecodeMetadata:
bonus_logits_indices
=
bonus_logits_indices
,
logits_indices
=
logits_indices
,
)
@
dataclass
class
MultiLayerEagleMetadata
:
# [batch_size]
cached_len
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_token_ids
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num, hidden_size]
cached_hidden_states
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_slot_mappings
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_positions
:
torch
.
Tensor
|
None
=
None
@
classmethod
def
make_dummy
(
cls
,
layer_num
:
int
,
hidden_size
:
int
,
device
:
torch
.
device
,
)
->
"MultiLayerEagleMetadata"
:
cached_len
=
torch
.
zeros
((
1
),
dtype
=
torch
.
int64
,
device
=
device
)
cached_token_ids
=
torch
.
zeros
((
1
,
layer_num
),
dtype
=
torch
.
int32
,
device
=
device
)
cached_hidden_states
=
torch
.
zeros
(
(
1
,
layer_num
,
hidden_size
),
dtype
=
torch
.
float32
,
device
=
device
)
cached_slot_mappings
=
torch
.
zeros
(
(
1
,
layer_num
),
dtype
=
torch
.
int64
,
device
=
device
)
cached_positions
=
torch
.
zeros
((
1
,
layer_num
),
dtype
=
torch
.
int64
,
device
=
device
)
return
cls
(
cached_len
=
cached_len
,
cached_token_ids
=
cached_token_ids
,
cached_hidden_states
=
cached_hidden_states
,
cached_slot_mappings
=
cached_slot_mappings
,
cached_positions
=
cached_positions
,
)
\ No newline at end of file
vllm/v1/spec_decode/multi_layer_eagle.py
0 → 100755
View file @
fcc9c9ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backend
import
(
CommonAttentionMetadata
,
)
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.metadata
import
MultiLayerEagleMetadata
logger
=
init_logger
(
__name__
)
BLOCK_HIDDEN
=
128
BLOCK_TOKENS
=
128
class
MultiLayerEagleProposer
(
EagleProposer
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
runner
=
None
,
):
super
().
__init__
(
vllm_config
,
device
,
runner
)
self
.
layer_num
:
int
=
getattr
(
self
.
speculative_config
.
draft_model_config
.
hf_text_config
,
"n_predict"
,
0
)
self
.
num_speculative_tokens
:
int
=
(
self
.
speculative_config
.
num_speculative_tokens
)
def
adjust_input
(
self
,
batch_size
:
int
,
target_token_ids
:
torch
.
Tensor
,
target_positions
:
torch
.
Tensor
,
target_hidden_states
:
torch
.
Tensor
,
token_indices_to_sample
:
torch
.
Tensor
,
common_attn_metadata
:
CommonAttentionMetadata
,
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Any
]:
assert
multi_layer_eagle_metadata
is
not
None
if
token_indices_to_sample
is
None
:
token_indices_to_sample
=
common_attn_metadata
.
query_start_loc
[
1
:]
-
1
MAX_SHIFT
=
self
.
layer_num
assert
MAX_SHIFT
>
0
prev_token_ids
=
target_token_ids
.
clone
()
prev_positions
=
target_positions
.
clone
()
prev_hidden_states
=
target_hidden_states
.
clone
()
slot_mapping
=
common_attn_metadata
.
slot_mapping
start_token_indices
=
common_attn_metadata
.
query_start_loc
[:
-
1
]
end_token_indices
=
common_attn_metadata
.
query_start_loc
[
1
:]
-
1
pos_for_shift
=
(
target_positions
[
0
]
if
target_positions
.
dim
()
==
2
else
target_positions
)
start_token_pos
=
pos_for_shift
[
start_token_indices
]
shift
=
torch
.
minimum
(
end_token_indices
-
token_indices_to_sample
,
start_token_pos
,
)
shift
=
torch
.
clamp
(
shift
,
min
=
0
)
# Metadata updates (matches the original reference implementation).
token_indices_to_sample
.
add_
(
shift
)
common_attn_metadata
.
seq_lens
.
sub_
(
shift
)
cached_lens
=
multi_layer_eagle_metadata
.
cached_len
shift
=
torch
.
minimum
(
shift
,
cached_lens
)
_multi_layer_eagle_shift_and_cache
(
batch_size
=
batch_size
,
max_shift
=
MAX_SHIFT
,
src_token_ids
=
target_token_ids
,
dst_token_ids
=
prev_token_ids
,
src_positions
=
target_positions
,
dst_positions
=
prev_positions
,
src_hidden_states
=
target_hidden_states
,
dst_hidden_states
=
prev_hidden_states
,
src_slot_mapping
=
slot_mapping
,
dst_slot_mapping
=
slot_mapping
,
start_token_indices
=
start_token_indices
,
end_token_indices
=
end_token_indices
,
token_indices_to_sample
=
token_indices_to_sample
,
shift
=
shift
,
cached_lens
=
cached_lens
,
cached_prev_token_ids
=
multi_layer_eagle_metadata
.
cached_token_ids
,
cached_prev_positions
=
multi_layer_eagle_metadata
.
cached_positions
,
cached_prev_hidden_states
=
multi_layer_eagle_metadata
.
cached_hidden_states
,
cached_slot_mappings
=
multi_layer_eagle_metadata
.
cached_slot_mappings
,
common_attn_metadata
=
common_attn_metadata
,
)
return
prev_token_ids
,
prev_positions
,
prev_hidden_states
,
common_attn_metadata
def
prepare_inputs
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampled_token_ids
:
list
[
list
[
int
]],
num_draft_tokens
:
list
[
int
],
)
->
tuple
[
CommonAttentionMetadata
,
torch
.
Tensor
]:
"""
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
raise
Exception
(
"speculative_config.disable_padded_drafter_batch"
" is not supported now for MultiLayerEagleProposer."
)
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
num_tokens
:
int
,
use_cudagraphs
:
bool
=
True
,
is_graph_capturing
:
bool
=
False
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
None
=
None
,
)
->
None
:
cudagraph_runtime_mode
,
num_input_tokens
,
num_tokens_across_dp
=
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
,
use_cudagraphs
=
use_cudagraphs
)
)
# Make sure to use EAGLE's own buffer during cudagraph capture.
if
(
self
.
_draft_attn_layer_names
and
slot_mappings
is
not
None
and
next
(
iter
(
self
.
_draft_attn_layer_names
))
in
slot_mappings
):
slot_mapping_dict
=
self
.
_get_slot_mapping
(
num_input_tokens
)
else
:
slot_mapping_dict
=
slot_mappings
or
{}
adjust_input_kwargs
=
{
"batch_size"
:
1
,
"target_token_ids"
:
self
.
input_ids
[:
num_input_tokens
],
"target_positions"
:
self
.
_get_positions
(
num_input_tokens
),
"target_hidden_states"
:
self
.
hidden_states
[:
num_input_tokens
],
"token_indices_to_sample"
:
torch
.
tensor
(
[
num_input_tokens
-
1
],
dtype
=
torch
.
int32
,
device
=
self
.
device
),
"common_attn_metadata"
:
CommonAttentionMetadata
(
query_start_loc
=
torch
.
tensor
(
[
0
,
num_input_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
device
),
query_start_loc_cpu
=
torch
.
tensor
(
[
0
,
num_input_tokens
],
dtype
=
torch
.
int32
,
device
=
"cpu"
),
seq_lens
=
torch
.
tensor
(
[
num_input_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
device
),
num_reqs
=
1
,
num_actual_tokens
=
num_input_tokens
,
max_query_len
=
num_input_tokens
,
max_seq_len
=
self
.
max_model_len
,
block_table_tensor
=
torch
.
tensor
(
[],
dtype
=
torch
.
int32
,
device
=
self
.
device
),
slot_mapping
=
self
.
arange
[:
num_input_tokens
],
logits_indices_padded
=
None
,
num_logits_indices
=
None
,
causal
=
True
,
encoder_seq_lens
=
None
,
),
"multi_layer_eagle_metadata"
:
MultiLayerEagleMetadata
.
make_dummy
(
layer_num
=
self
.
layer_num
,
hidden_size
=
self
.
hidden_size
,
device
=
self
.
device
,
),
}
# NOTE ensure the jit kernel in _adjust_input can be compiled
self
.
adjust_input
(
**
adjust_input_kwargs
)
for
fwd_idx
in
range
(
self
.
layer_num
):
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
slot_mapping
=
slot_mapping_dict
,
):
if
self
.
supports_mm_inputs
:
input_ids
=
None
inputs_embeds
=
self
.
inputs_embeds
[:
num_input_tokens
]
else
:
input_ids
=
self
.
input_ids
[:
num_input_tokens
]
inputs_embeds
=
None
model_kwargs
=
{
"input_ids"
:
input_ids
,
"positions"
:
self
.
_get_positions
(
num_input_tokens
),
"hidden_states"
:
self
.
hidden_states
[:
num_input_tokens
],
"inputs_embeds"
:
inputs_embeds
,
"spec_step_idx"
:
fwd_idx
,
}
self
.
model
(
**
model_kwargs
)
def
_multi_layer_eagle_shift_and_cache
(
*
,
batch_size
:
int
,
max_shift
:
int
,
src_token_ids
:
torch
.
Tensor
,
dst_token_ids
:
torch
.
Tensor
,
src_positions
:
torch
.
Tensor
,
dst_positions
:
torch
.
Tensor
,
src_hidden_states
:
torch
.
Tensor
,
dst_hidden_states
:
torch
.
Tensor
,
src_slot_mapping
:
torch
.
Tensor
,
dst_slot_mapping
:
torch
.
Tensor
,
start_token_indices
:
torch
.
Tensor
,
end_token_indices
:
torch
.
Tensor
,
token_indices_to_sample
:
torch
.
Tensor
,
shift
:
torch
.
Tensor
,
cached_lens
:
torch
.
Tensor
,
cached_prev_token_ids
:
torch
.
Tensor
,
cached_prev_positions
:
torch
.
Tensor
,
cached_prev_hidden_states
:
torch
.
Tensor
,
cached_slot_mappings
:
torch
.
Tensor
,
common_attn_metadata
:
CommonAttentionMetadata
,
):
if
batch_size
==
0
:
return
assert
max_shift
>
0
assert
cached_prev_positions
.
is_contiguous
()
assert
cached_prev_token_ids
.
is_contiguous
()
assert
cached_prev_hidden_states
.
is_contiguous
()
assert
cached_slot_mappings
.
is_contiguous
()
assert
src_hidden_states
.
is_contiguous
()
assert
dst_hidden_states
.
is_contiguous
()
# If src/dst are the same tensor, shifting is unsafe without a separate src.
if
src_slot_mapping
.
data_ptr
()
==
dst_slot_mapping
.
data_ptr
():
src_slot_mapping
=
src_slot_mapping
.
clone
()
# Cache extraction for the next call.
store_start
=
torch
.
maximum
(
start_token_indices
,
(
token_indices_to_sample
+
1
-
max_shift
),
)
store_lens
=
torch
.
clamp
(
token_indices_to_sample
-
store_start
+
1
,
min
=
0
,
max
=
max_shift
,
)
# Avoid device sync: query length == (end - start + 1) == diff of
# query_start_loc (CPU copy).
max_window_len
=
int
(
(
common_attn_metadata
.
query_start_loc_cpu
[
1
:]
-
common_attn_metadata
.
query_start_loc_cpu
[:
-
1
]
)
.
max
()
.
item
()
)
num_blocks
=
max
(
1
,
(
max_window_len
+
BLOCK_TOKENS
-
1
)
//
BLOCK_TOKENS
)
_shift_and_gather_cache_1d_kernel
[(
batch_size
,
num_blocks
)](
src_token_ids
,
dst_token_ids
,
cached_prev_token_ids
,
start_token_indices
,
end_token_indices
,
shift
,
cached_lens
,
store_start
,
store_lens
,
MAX_SHIFT
=
max_shift
,
PADDED_SHIFT
=
triton
.
next_power_of_2
(
max_shift
),
BLOCK_TOKENS
=
BLOCK_TOKENS
,
)
_shift_and_gather_cache_1d_kernel
[(
batch_size
,
num_blocks
)](
src_slot_mapping
,
dst_slot_mapping
,
cached_slot_mappings
,
start_token_indices
,
end_token_indices
,
shift
,
cached_lens
,
store_start
,
store_lens
,
MAX_SHIFT
=
max_shift
,
PADDED_SHIFT
=
triton
.
next_power_of_2
(
max_shift
),
BLOCK_TOKENS
=
BLOCK_TOKENS
,
)
_shift_and_gather_cache_1d_kernel
[(
batch_size
,
num_blocks
)](
src_positions
,
dst_positions
,
cached_prev_positions
,
start_token_indices
,
end_token_indices
,
shift
,
cached_lens
,
store_start
,
store_lens
,
MAX_SHIFT
=
max_shift
,
PADDED_SHIFT
=
triton
.
next_power_of_2
(
max_shift
),
BLOCK_TOKENS
=
BLOCK_TOKENS
,
)
hidden_size
=
int
(
dst_hidden_states
.
shape
[
1
])
# Hidden blocking avoids extremely large Triton tiles (and huge cubins)
# when hidden_size is large.
num_hidden_blocks
=
max
(
1
,
(
hidden_size
+
BLOCK_HIDDEN
-
1
)
//
BLOCK_HIDDEN
)
_shift_and_gather_hidden_kernel
[(
batch_size
,
num_blocks
,
num_hidden_blocks
)](
src_hidden_states
,
dst_hidden_states
,
cached_prev_hidden_states
,
start_token_indices
,
end_token_indices
,
shift
,
cached_lens
,
store_start
,
store_lens
,
MAX_SHIFT
=
max_shift
,
PADDED_SHIFT
=
triton
.
next_power_of_2
(
max_shift
),
HIDDEN_SIZE
=
hidden_size
,
BLOCK_TOKENS
=
BLOCK_TOKENS
,
BLOCK_HIDDEN
=
BLOCK_HIDDEN
,
num_warps
=
4
,
)
cached_lens
.
copy_
(
store_lens
)
return
@
triton
.
jit
def
_shift_and_gather_cache_1d_kernel
(
src_ptr
,
dst_ptr
,
cached_ptr
,
start_ptr
,
end_ptr
,
shift_ptr
,
cached_len_ptr
,
store_start_ptr
,
store_len_ptr
,
MAX_SHIFT
:
tl
.
constexpr
,
PADDED_SHIFT
:
tl
.
constexpr
,
BLOCK_TOKENS
:
tl
.
constexpr
,
):
# Per-sequence "shift + gather" for packed 1D arrays (token ids, positions,
# slot mappings, ...).
#
# We operate on a packed batch where each sequence (request) occupies a
# contiguous window [start, end] (inclusive) in a flattened tensor.
# For the next speculative step, we build a right-shifted version of each
# window. The shift amount can differ per sequence.
#
# For a single sequence (0-based index i within its window):
# - Prefix (i < shift):
# dst[start + i] = cached[cached_len - shift + i]
# - Body (i >= shift):
# dst[start + i] = src[start + i - shift]
#
# The vacated prefix is filled from a small per-sequence cache (up to
# MAX_SHIFT elements) that stores values from previous speculative steps.
#
# Example:
# cached_tail = [a3, a4]
# src_window = [b0, b1, b2, b3, b4]
# shift = 2
# -> dst_window = [a3, a4, b0, b1, b2]
#
# After dst is produced, we refresh cached_ptr[seq, :] with a suffix of dst
# (specified by store_start / store_len) so the next call can populate its
# prefix from cache.
pid_seq
=
tl
.
program_id
(
0
)
pid_blk
=
tl
.
program_id
(
1
)
start
=
tl
.
load
(
start_ptr
+
pid_seq
).
to
(
tl
.
int32
)
end
=
tl
.
load
(
end_ptr
+
pid_seq
).
to
(
tl
.
int32
)
shift
=
tl
.
load
(
shift_ptr
+
pid_seq
).
to
(
tl
.
int32
)
cached_len
=
tl
.
load
(
cached_len_ptr
+
pid_seq
).
to
(
tl
.
int32
)
assert
cached_len
>=
shift
# get dst indices
base
=
pid_blk
*
BLOCK_TOKENS
k
=
tl
.
arange
(
0
,
BLOCK_TOKENS
)
offs
=
base
+
k
dst_idx
=
start
+
offs
# get dst mask
window_len
=
end
-
start
+
1
mask
=
offs
<
window_len
# load from cached
base_cached
=
cached_ptr
+
pid_seq
*
MAX_SHIFT
cached_idx
=
cached_len
-
shift
+
offs
cached_mask
=
offs
<
shift
val_cached
=
tl
.
load
(
base_cached
+
cached_idx
,
mask
=
mask
&
cached_mask
,
other
=
0
)
# load from src
src_idx
=
start
+
offs
-
shift
val_src
=
tl
.
load
(
src_ptr
+
src_idx
,
mask
=
mask
&
~
cached_mask
,
other
=
0
)
# store to dst
val
=
tl
.
where
(
cached_mask
,
val_cached
,
val_src
)
tl
.
store
(
dst_ptr
+
dst_idx
,
val
,
mask
=
mask
)
# Store into the per-sequence cache.
#
# Cache layout: [batch_size, MAX_SHIFT] (flattened). We always write the
# full MAX_SHIFT region (zero-padded when store_len < MAX_SHIFT) to keep the
# cache contiguous.
store_start
=
tl
.
load
(
store_start_ptr
+
pid_seq
).
to
(
tl
.
int32
)
store_len
=
tl
.
load
(
store_len_ptr
+
pid_seq
).
to
(
tl
.
int32
)
m
=
tl
.
arange
(
0
,
PADDED_SHIFT
)
store_mask
=
m
<
MAX_SHIFT
dst_idx
=
store_start
+
m
val
=
tl
.
load
(
dst_ptr
+
dst_idx
,
mask
=
store_mask
&
(
m
<
store_len
),
other
=
0
)
tl
.
store
(
base_cached
+
m
,
val
,
mask
=
store_mask
)
@
triton
.
jit
def
_shift_and_gather_hidden_kernel
(
src_ptr
,
dst_ptr
,
cached_ptr
,
start_ptr
,
end_ptr
,
shift_ptr
,
cached_len_ptr
,
store_start_ptr
,
store_len_ptr
,
MAX_SHIFT
:
tl
.
constexpr
,
PADDED_SHIFT
:
tl
.
constexpr
,
HIDDEN_SIZE
:
tl
.
constexpr
,
BLOCK_TOKENS
:
tl
.
constexpr
,
BLOCK_HIDDEN
:
tl
.
constexpr
,
):
# Per-sequence "shift + gather" for hidden states.
#
# This kernel implements the same logical transformation as
# _shift_and_gather_cache_1d_kernel, but operates on hidden states with
# shape [num_tokens, hidden_size].
#
# Layout:
# - src_ptr / dst_ptr: packed hidden states [num_tokens, hidden_size]
# - cached_ptr: per-sequence cache [batch_size, MAX_SHIFT, hidden_size]
#
# For each sequence window [start, end] (inclusive) and its shift value, for
# 0-based index i within the window:
# - Prefix (i < shift):
# dst[start + i, :] = cached[seq, cached_len - shift + i, :]
# - Body (i >= shift):
# dst[start + i, :] = src[start + i - shift, :]
#
# We tile over tokens (BLOCK_TOKENS) and hidden dim (BLOCK_HIDDEN) to avoid
# extremely large Triton tiles when hidden_size is large. As in the 1D
# kernel, we refresh cached_ptr[seq, :, :] with a suffix of dst so the next
# call can populate its prefix from cache.
pid_seq
=
tl
.
program_id
(
0
)
pid_blk
=
tl
.
program_id
(
1
)
pid_hid
=
tl
.
program_id
(
2
)
start
=
tl
.
load
(
start_ptr
+
pid_seq
).
to
(
tl
.
int32
)
end
=
tl
.
load
(
end_ptr
+
pid_seq
).
to
(
tl
.
int32
)
shift
=
tl
.
load
(
shift_ptr
+
pid_seq
).
to
(
tl
.
int32
)
cached_len
=
tl
.
load
(
cached_len_ptr
+
pid_seq
).
to
(
tl
.
int32
)
assert
cached_len
>=
shift
# get dst indices
base
=
pid_blk
*
BLOCK_TOKENS
k
=
tl
.
arange
(
0
,
BLOCK_TOKENS
)
tok_offs
=
base
+
k
dst_tok
=
start
+
tok_offs
n
=
pid_hid
*
BLOCK_HIDDEN
+
tl
.
arange
(
0
,
BLOCK_HIDDEN
)
dst_ptrs
=
dst_ptr
+
dst_tok
[:,
None
]
*
HIDDEN_SIZE
+
n
[
None
,
:]
*
1
# get dst mask
window_len
=
end
-
start
+
1
tok_mask
=
tok_offs
<
window_len
n_mask
=
n
<
HIDDEN_SIZE
mask
=
tok_mask
[:,
None
]
&
n_mask
[
None
,
:]
# load from cached
base_cached
=
cached_ptr
+
pid_seq
*
HIDDEN_SIZE
*
MAX_SHIFT
cached_tok
=
cached_len
-
shift
+
tok_offs
cached_ptrs
=
base_cached
+
cached_tok
[:,
None
]
*
HIDDEN_SIZE
+
n
[
None
,
:]
*
1
cached_mask
=
tok_offs
<
shift
val_cached
=
tl
.
load
(
cached_ptrs
,
mask
=
mask
&
cached_mask
[:,
None
],
other
=
0
)
# load from src
src_tok
=
start
+
tok_offs
-
shift
src_ptrs
=
src_ptr
+
src_tok
[:,
None
]
*
HIDDEN_SIZE
+
n
[
None
,
:]
*
1
val_src
=
tl
.
load
(
src_ptrs
,
mask
=
mask
&
~
cached_mask
[:,
None
],
other
=
0
)
# store to dst
val
=
tl
.
where
(
cached_mask
[:,
None
],
val_cached
,
val_src
)
tl
.
store
(
dst_ptrs
,
val
,
mask
=
mask
)
# store to cached
store_start
=
tl
.
load
(
store_start_ptr
+
pid_seq
).
to
(
tl
.
int32
)
store_len
=
tl
.
load
(
store_len_ptr
+
pid_seq
).
to
(
tl
.
int32
)
m
=
tl
.
arange
(
0
,
PADDED_SHIFT
)
m_mask
=
(
m
<
MAX_SHIFT
)
&
(
m
<
store_len
)
store_tok
=
store_start
+
m
dst_ptrs
=
dst_ptr
+
store_tok
[:,
None
]
*
HIDDEN_SIZE
+
n
[
None
,
:]
*
1
store_ptrs
=
base_cached
+
m
[:,
None
]
*
HIDDEN_SIZE
+
n
[
None
,
:]
*
1
mask
=
m_mask
[:,
None
]
&
n_mask
[
None
,
:]
val
=
tl
.
load
(
dst_ptrs
,
mask
=
mask
,
other
=
0
)
tl
.
store
(
store_ptrs
,
val
,
mask
=
mask
)
vllm/v1/spec_decode/ngram_proposer_gpu.py
0 → 100755
View file @
fcc9c9ea
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.
This version uses a fully vectorized approach with unfold and argmax for
finding the first match across all sequences in parallel.
"""
import
torch
from
torch
import
nn
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CompilationConfig
,
CompilationMode
,
CUDAGraphMode
,
VllmConfig
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.utils
import
record_function_or_nullcontext
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
@
support_torch_compile
()
class
NgramGPUKernel
(
nn
.
Module
):
"""GPU-accelerated N-gram proposer using fully async tensor operations."""
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
,
device
:
torch
.
device
=
"cuda"
):
super
().
__init__
()
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_min
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_max
is
not
None
self
.
min_n
=
vllm_config
.
speculative_config
.
prompt_lookup_min
self
.
max_n
=
vllm_config
.
speculative_config
.
prompt_lookup_max
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_seqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
device
=
device
def
_find_first_and_extract_all_n_parallel
(
self
,
token_ids
:
torch
.
Tensor
,
seq_lengths
:
torch
.
Tensor
,
min_ngram_len
:
int
,
max_ngram_len
:
int
,
num_draft_tokens
:
int
,
)
->
torch
.
Tensor
:
"""
Find suffix n-gram matches and extract following tokens.
Searches for the earliest prior occurrence of the trailing n-gram,
tries multiple lengths, and picks the longest valid match.
Args:
token_ids: Token IDs for each sequence
seq_lengths: Actual length of each sequence (excluding padding)
min_ngram_len: Minimum n-gram size to search for (e.g., 2)
max_ngram_len: Maximum n-gram size to search for (e.g., 5)
num_draft_tokens: Number of tokens to extract after match (k)
Returns:
Draft token predictions; -1 means invalid/no match.
"""
batch_size
=
token_ids
.
shape
[
0
]
max_seq_len
=
token_ids
.
shape
[
1
]
device
=
token_ids
.
device
num_ngram_sizes
=
max_ngram_len
-
min_ngram_len
+
1
# All n-gram sizes to try.
ngram_lengths
=
torch
.
arange
(
min_ngram_len
,
max_ngram_len
+
1
,
device
=
device
)
batch_indices
=
torch
.
arange
(
batch_size
,
device
=
device
)
# Earliest match per (sequence, ngram_len); -1 means no match.
first_match_positions
=
torch
.
full
(
(
batch_size
,
num_ngram_sizes
),
-
1
,
dtype
=
torch
.
long
,
device
=
device
)
for
i
,
ngram_len
in
enumerate
(
range
(
min_ngram_len
,
max_ngram_len
+
1
)):
# Sliding windows of size ngram_len; unfold is O(1) view.
search_windows
=
token_ids
.
unfold
(
1
,
ngram_len
,
1
)
num_windows
=
search_windows
.
shape
[
1
]
# Trailing suffix (last ngram_len tokens) for each sequence.
suffix_starts
=
seq_lengths
-
ngram_len
suffix_indices
=
suffix_starts
.
unsqueeze
(
1
)
+
torch
.
arange
(
ngram_len
,
device
=
device
)
suffix
=
torch
.
gather
(
token_ids
,
1
,
suffix_indices
.
clamp
(
min
=
0
))
# Window matches for each sequence.
matches
=
(
search_windows
==
suffix
.
unsqueeze
(
1
)).
all
(
dim
=-
1
)
# Match must leave room for at least one draft token.
max_valid_suffix_start
=
seq_lengths
-
ngram_len
-
1
window_positions
=
torch
.
arange
(
num_windows
,
device
=
device
)
valid_mask
=
window_positions
<=
max_valid_suffix_start
.
unsqueeze
(
1
)
final_matches
=
matches
&
valid_mask
# Find earliest match (argmax=0 when empty; verify with has_match).
first_match_idx
=
torch
.
argmax
(
final_matches
.
int
(),
dim
=
1
)
has_match
=
final_matches
[
batch_indices
,
first_match_idx
]
# Store valid match positions (window index = position).
first_match_positions
[:,
i
]
=
torch
.
where
(
has_match
,
first_match_idx
,
-
1
)
# Select the longest n-gram with a match.
best_ngram_idx
=
(
first_match_positions
>=
0
).
int
().
flip
(
dims
=
[
1
]).
argmax
(
dim
=
1
)
best_ngram_idx
=
num_ngram_sizes
-
1
-
best_ngram_idx
# Flip back
# Match position for the best n-gram.
best_match_pos
=
first_match_positions
[
batch_indices
,
best_ngram_idx
]
# Avoid data-dependent branching.
has_any_match
=
best_match_pos
>=
0
# Length of the best matching n-gram.
best_ngram_lengths
=
ngram_lengths
[
best_ngram_idx
]
# Start position right after the matched suffix.
draft_start
=
torch
.
where
(
has_any_match
,
best_match_pos
+
best_ngram_lengths
,
torch
.
zeros_like
(
best_match_pos
),
)
tokens_available
=
seq_lengths
-
draft_start
# Gather indices for draft tokens.
draft_indices
=
draft_start
.
unsqueeze
(
1
)
+
torch
.
arange
(
num_draft_tokens
,
device
=
device
)
draft_indices
=
draft_indices
.
clamp
(
min
=
0
,
max
=
max_seq_len
-
1
)
# Extract draft tokens; gather always runs.
draft_tokens
=
torch
.
gather
(
token_ids
,
1
,
draft_indices
)
# Mask positions beyond available tokens.
position_indices
=
torch
.
arange
(
num_draft_tokens
,
device
=
device
).
unsqueeze
(
0
)
valid_positions
=
position_indices
<
tokens_available
.
unsqueeze
(
1
)
draft_tokens
=
torch
.
where
(
valid_positions
,
draft_tokens
,
torch
.
full_like
(
draft_tokens
,
-
1
),
)
# If no match, mask all positions.
draft_tokens
=
torch
.
where
(
has_any_match
.
unsqueeze
(
1
),
draft_tokens
,
torch
.
full_like
(
draft_tokens
,
-
1
),
)
return
draft_tokens
def
forward
(
self
,
num_tokens_no_spec
:
torch
.
Tensor
,
token_ids_gpu
:
torch
.
Tensor
,
combined_mask
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Forward pass for N-gram proposal using GPU tensor operations.
Args:
num_tokens_no_spec: Number of tokens for each sequence [batch_size]
token_ids_gpu: Token IDs [batch_size, max_len]
combined_mask: Whether each sequence is valid for spec decode [batch_size]
Returns:
draft_tokens: [batch_size, k] on GPU
num_valid_draft_tokens: [batch_size] int32 on GPU, count of
leading valid (non -1) tokens per request.
"""
device
=
token_ids_gpu
.
device
# Infer batch size to preserve dynamic shape.
actual_batch_size
=
token_ids_gpu
.
shape
[
0
]
# Allocate in forward so torch.compile can optimize.
# NOTE(patchy): Do NOT pre-allocate this as a buffer
# it breaks torch.compile
draft_tokens
=
torch
.
full
(
(
actual_batch_size
,
self
.
k
),
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
results
=
self
.
_find_first_and_extract_all_n_parallel
(
token_ids_gpu
,
num_tokens_no_spec
,
min_ngram_len
=
self
.
min_n
,
max_ngram_len
=
self
.
max_n
,
num_draft_tokens
=
self
.
k
,
)
draft_tokens
=
torch
.
where
(
combined_mask
.
unsqueeze
(
1
),
results
,
-
1
)
# Count leading contiguous valid (non -1) tokens per request.
is_valid
=
draft_tokens
!=
-
1
# [batch, k]
cum_valid
=
is_valid
.
int
().
cumsum
(
dim
=
1
)
# [batch, k]
positions
=
torch
.
arange
(
1
,
self
.
k
+
1
,
device
=
device
).
unsqueeze
(
0
)
num_valid_draft_tokens
=
(
cum_valid
==
positions
).
int
().
sum
(
dim
=
1
)
return
draft_tokens
,
num_valid_draft_tokens
def
load_model
(
self
,
*
args
,
**
kwargs
):
"""No model to load for N-gram proposer."""
pass
class
NgramProposerGPU
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
device
:
torch
.
device
,
runner
=
None
):
assert
vllm_config
.
speculative_config
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_min
is
not
None
assert
vllm_config
.
speculative_config
.
prompt_lookup_max
is
not
None
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
custom_ops
=
[
"none"
],
splitting_ops
=
[],
compile_sizes
=
[],
inductor_compile_config
=
{
"enable_auto_functionalized_v2"
:
False
,
"max_autotune"
:
True
,
"aggressive_fusion"
:
True
,
"triton.autotune_pointwise"
:
True
,
"coordinate_descent_tuning"
:
True
,
"use_mixed_mm"
:
False
,
},
cudagraph_mode
=
CUDAGraphMode
.
NONE
,
)
model_config
=
vllm_config
.
model_config
speculative_config
=
vllm_config
.
speculative_config
scheduler_config
=
vllm_config
.
scheduler_config
self
.
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
,
model_config
=
model_config
,
speculative_config
=
speculative_config
,
scheduler_config
=
scheduler_config
,
)
self
.
min_n
=
vllm_config
.
speculative_config
.
prompt_lookup_min
self
.
max_n
=
vllm_config
.
speculative_config
.
prompt_lookup_max
self
.
k
=
vllm_config
.
speculative_config
.
num_speculative_tokens
self
.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
self
.
max_num_seqs
=
vllm_config
.
scheduler_config
.
max_num_seqs
self
.
device
=
device
self
.
kernel
=
NgramGPUKernel
(
vllm_config
=
self
.
vllm_config
,
prefix
=
"ngram_gpu_kernel"
,
device
=
device
)
self
.
kernel
.
to
(
device
)
self
.
kernel
.
eval
()
self
.
_dummy_run
()
def
_dummy_run
(
self
):
token_ids
,
num_tokens
,
sampled_flags
,
valid_mask
=
self
.
_generate_dummy_data
(
batch_size
=
self
.
max_num_seqs
,
max_seq_len
=
self
.
max_model_len
,
pattern_len
=
self
.
k
,
device
=
self
.
device
,
)
combined_mask
=
sampled_flags
&
valid_mask
&
(
num_tokens
>=
self
.
min_n
)
for
_
in
range
(
3
):
with
set_forward_context
(
None
,
self
.
vllm_config
):
_
,
_
=
self
.
kernel
(
num_tokens
,
token_ids
,
combined_mask
)
def
_generate_dummy_data
(
self
,
batch_size
:
int
,
max_seq_len
:
int
,
pattern_len
:
int
,
device
:
str
=
"cuda"
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Generate random test data with n-gram repetitions.
Args:
batch_size: Number of sequences in the batch
max_seq_len: Maximum sequence length
pattern_len: Length of patterns to inject for matching
device: Device to place tensors on
Returns:
token_ids: [batch_size, max_seq_len] tensor
num_tokens: [batch_size] tensor
sampled_flags: [batch_size] bool tensor
valid_mask: [batch_size] bool tensor
"""
token_ids
=
torch
.
zeros
(
batch_size
,
max_seq_len
,
dtype
=
torch
.
int32
,
device
=
device
,
)
num_tokens
=
torch
.
randint
(
pattern_len
,
max_seq_len
,
(
batch_size
,),
dtype
=
torch
.
int32
,
device
=
device
)
sampled_flags
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
device
)
valid_mask
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
device
)
return
token_ids
,
num_tokens
,
sampled_flags
,
valid_mask
def
propose
(
self
,
num_tokens_no_spec
:
torch
.
Tensor
,
# [batch_size]
token_ids_gpu
:
torch
.
Tensor
,
# [batch_size, max_len]
valid_sampled_token_ids_gpu
:
torch
.
Tensor
,
# [batch_size, num_spec_tokens + 1]
valid_sampled_tokens_count
:
torch
.
Tensor
,
# [batch_size]
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Propose draft tokens using GPU-accelerated n-gram matching.
Scatter sampled tokens into `token_ids_gpu`, compute temporary
updated lengths, then run the kernel.
Args:
num_tokens_no_spec: Number of tokens per sequence (read-only)
token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
valid_sampled_tokens_count: Count of valid tokens per sequence
Returns:
draft_tokens: Proposed draft token IDs [batch_size, k]
num_valid_draft_tokens: Count of leading valid draft tokens
per request [batch_size]
"""
assert
token_ids_gpu
.
device
==
self
.
device
assert
num_tokens_no_spec
.
device
==
self
.
device
batch_size
=
num_tokens_no_spec
.
shape
[
0
]
max_seq_len
=
token_ids_gpu
.
shape
[
1
]
max_new_tokens
=
valid_sampled_token_ids_gpu
.
shape
[
1
]
# num_spec_tokens + 1
# Scatter newly sampled tokens into token_ids_gpu.
offsets
=
torch
.
arange
(
max_new_tokens
,
device
=
self
.
device
)
write_positions
=
num_tokens_no_spec
.
unsqueeze
(
1
)
+
offsets
.
unsqueeze
(
0
)
valid_write_mask
=
offsets
.
unsqueeze
(
0
)
<
valid_sampled_tokens_count
.
unsqueeze
(
1
)
in_bounds
=
write_positions
<
max_seq_len
scatter_mask
=
(
valid_write_mask
&
(
valid_sampled_token_ids_gpu
!=
-
1
)
&
in_bounds
)
write_positions_long
=
write_positions
.
clamp
(
max
=
max_seq_len
-
1
).
long
()
existing_values
=
token_ids_gpu
.
gather
(
1
,
write_positions_long
)
tokens_cast
=
valid_sampled_token_ids_gpu
.
to
(
token_ids_gpu
.
dtype
)
tokens_to_scatter
=
torch
.
where
(
scatter_mask
,
tokens_cast
,
existing_values
,
)
token_ids_gpu
.
scatter_
(
1
,
write_positions_long
,
tokens_to_scatter
)
num_tokens_tmp
=
num_tokens_no_spec
+
valid_sampled_tokens_count
# Compute validity masks.
sampled_flags
=
valid_sampled_tokens_count
>
0
valid_mask
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
bool
,
device
=
self
.
device
)
with
set_forward_context
(
None
,
self
.
vllm_config
):
combined_mask
=
sampled_flags
&
valid_mask
&
(
num_tokens_tmp
>=
self
.
min_n
)
with
record_function_or_nullcontext
(
"ngram_proposer_gpu: kernel"
):
draft_tokens
,
num_valid_draft_tokens
=
self
.
kernel
(
num_tokens_tmp
,
token_ids_gpu
,
combined_mask
,
)
return
draft_tokens
,
num_valid_draft_tokens
def
update_token_ids_ngram
(
self
,
sampled_token_ids
:
torch
.
Tensor
|
list
[
list
[
int
]],
gpu_input_batch
:
InputBatch
,
token_ids_gpu
:
torch
.
Tensor
,
num_tokens_no_spec
:
torch
.
Tensor
,
discard_request_mask
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Prepare speculative decoding inputs on device:
compute next token ids and valid counts, honoring discarded requests
and rejected tokens, without CPU-GPU sync.
"""
num_reqs
=
gpu_input_batch
.
num_reqs
if
isinstance
(
sampled_token_ids
,
list
):
# When disable_padded_drafter_batch=True, sampled_token_ids is
# an irregular list[list[int]] where sublists may have different
# lengths (including empty lists for discarded requests).
# Pad all sublists to the same length with -1 before converting
# to tensor.
max_len
=
max
(
(
len
(
sublist
)
for
sublist
in
sampled_token_ids
),
default
=
0
,
)
# Ensure at least length 1 for tensor creation
max_len
=
max
(
max_len
,
1
)
padded_list
=
[
sublist
+
[
-
1
]
*
(
max_len
-
len
(
sublist
))
for
sublist
in
sampled_token_ids
]
sampled_token_ids
=
torch
.
tensor
(
padded_list
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
assert
isinstance
(
sampled_token_ids
,
torch
.
Tensor
),
(
"sampled_token_ids should be a torch.Tensor for ngram_gpu"
)
# Backup last valid token before speculative tokens.
backup_indices
=
(
num_tokens_no_spec
[:
num_reqs
]
-
1
).
clamp
(
min
=
0
).
long
()
backup_next_token_ids
=
torch
.
gather
(
token_ids_gpu
[:
num_reqs
],
dim
=
1
,
index
=
backup_indices
.
unsqueeze
(
1
)
).
squeeze
(
1
)
valid_sampled_token_ids_gpu
=
sampled_token_ids
.
clone
()
# Invalidate sampled tokens for discarded requests.
discard_mask_expanded
=
discard_request_mask
[:
num_reqs
].
unsqueeze
(
1
)
valid_sampled_token_ids_gpu
.
masked_fill_
(
discard_mask_expanded
,
-
1
)
# Mask valid tokens within each request.
valid_mask
=
(
valid_sampled_token_ids_gpu
!=
-
1
)
&
(
valid_sampled_token_ids_gpu
<
gpu_input_batch
.
vocab_size
)
# Count valid tokens per request.
valid_sampled_tokens_count
=
valid_mask
.
sum
(
dim
=
1
)
# Rightmost valid index per row.
last_valid_indices
=
valid_sampled_tokens_count
-
1
last_valid_indices_safe
=
torch
.
clamp
(
last_valid_indices
,
min
=
0
)
# Last valid token from each row; undefined if none.
selected_tokens
=
torch
.
gather
(
valid_sampled_token_ids_gpu
,
1
,
last_valid_indices_safe
.
unsqueeze
(
1
)
).
squeeze
(
1
)
# Use last token if valid; otherwise fallback to backup.
next_token_ids
=
torch
.
where
(
last_valid_indices
!=
-
1
,
selected_tokens
,
backup_next_token_ids
,
)
return
next_token_ids
,
valid_sampled_tokens_count
,
valid_sampled_token_ids_gpu
def
load_model
(
self
,
*
args
,
**
kwargs
):
self
.
kernel
.
load_model
(
*
args
,
**
kwargs
)
def
update_scheduler_for_invalid_drafts
(
num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
,
num_valid_draft_tokens_cpu
:
torch
.
Tensor
,
scheduler_output
:
"SchedulerOutput"
,
req_id_to_index
:
dict
[
str
,
int
],
)
->
None
:
"""Trim invalid speculative slots using per-request valid draft counts.
Args:
num_valid_draft_tokens_event: Event for async D2H completion.
num_valid_draft_tokens_cpu: CPU buffer of valid draft counts.
scheduler_output: Scheduler metadata to update in-place.
req_id_to_index: Request-id to batch-index mapping.
"""
req_data
=
scheduler_output
.
scheduled_cached_reqs
num_valid_draft_tokens_event
.
synchronize
()
for
req_id
in
req_data
.
req_ids
:
req_index
=
req_id_to_index
.
get
(
req_id
)
if
req_index
is
None
:
continue
spec_token_ids
=
scheduler_output
.
scheduled_spec_decode_tokens
.
get
(
req_id
)
if
spec_token_ids
is
None
:
continue
scheduled_k
=
len
(
spec_token_ids
)
valid_k
=
int
(
num_valid_draft_tokens_cpu
[
req_index
].
item
())
valid_k
=
max
(
0
,
min
(
valid_k
,
scheduled_k
))
tokens_to_trim
=
scheduled_k
-
valid_k
scheduler_output
.
total_num_scheduled_tokens
-=
tokens_to_trim
scheduler_output
.
num_scheduled_tokens
[
req_id
]
-=
tokens_to_trim
if
valid_k
==
0
:
scheduler_output
.
scheduled_spec_decode_tokens
.
pop
(
req_id
,
None
)
else
:
scheduler_output
.
scheduled_spec_decode_tokens
[
req_id
]
=
spec_token_ids
[
:
valid_k
]
def
update_ngram_gpu_tensors_incremental
(
input_batch
:
InputBatch
,
token_ids_gpu_tensor
:
torch
.
Tensor
,
num_tokens_no_spec_gpu
:
torch
.
Tensor
,
new_reqs
:
list
[
CachedRequestState
],
device
:
torch
.
device
,
_pinned_idx_buf
:
torch
.
Tensor
,
_pinned_val_buf
:
torch
.
Tensor
,
)
->
None
:
"""Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
for ngram GPU proposer.
"""
prev_req_id_to_index
=
input_batch
.
prev_req_id_to_index
curr_req_id_to_index
=
input_batch
.
req_id_to_index
if
not
curr_req_id_to_index
:
return
active_indices
=
list
(
curr_req_id_to_index
.
values
())
n_active
=
len
(
active_indices
)
# Use resident pinned buffers to avoid per-call allocation.
active_idx_cpu
=
_pinned_idx_buf
[:
n_active
]
active_idx_cpu
.
copy_
(
torch
.
as_tensor
(
active_indices
,
dtype
=
torch
.
long
))
active_idx_gpu
=
active_idx_cpu
.
to
(
device
=
device
,
non_blocking
=
True
)
new_req_ids
=
{
req
.
req_id
for
req
in
new_reqs
}
# First run, no previous state.
if
prev_req_id_to_index
is
None
:
for
idx
in
active_indices
:
num_tokens
=
input_batch
.
num_tokens_no_spec
[
idx
]
if
num_tokens
>
0
:
token_ids_gpu_tensor
[
idx
,
:
num_tokens
].
copy_
(
input_batch
.
token_ids_cpu_tensor
[
idx
,
:
num_tokens
],
non_blocking
=
True
,
)
_sync_num_tokens
(
input_batch
,
num_tokens_no_spec_gpu
,
active_idx_cpu
,
active_idx_gpu
,
n_active
,
device
,
_pinned_val_buf
,
)
return
# Detect index changes for reorder.
reorder_src
:
list
[
int
]
=
[]
reorder_dst
:
list
[
int
]
=
[]
for
req_id
,
curr_idx
in
curr_req_id_to_index
.
items
():
if
req_id
in
new_req_ids
:
continue
prev_idx
=
prev_req_id_to_index
.
get
(
req_id
)
if
prev_idx
is
not
None
and
prev_idx
!=
curr_idx
:
reorder_src
.
append
(
prev_idx
)
reorder_dst
.
append
(
curr_idx
)
if
reorder_src
:
src_tensor
=
torch
.
tensor
(
reorder_src
,
dtype
=
torch
.
long
,
device
=
device
)
dst_tensor
=
torch
.
tensor
(
reorder_dst
,
dtype
=
torch
.
long
,
device
=
device
)
temp_token_ids
=
token_ids_gpu_tensor
[
src_tensor
].
clone
()
temp_num_tokens
=
num_tokens_no_spec_gpu
[
src_tensor
].
clone
()
token_ids_gpu_tensor
[
dst_tensor
]
=
temp_token_ids
num_tokens_no_spec_gpu
[
dst_tensor
]
=
temp_num_tokens
# Full copy for new/resumed requests.
for
req_state
in
new_reqs
:
new_req_idx
=
curr_req_id_to_index
.
get
(
req_state
.
req_id
)
if
new_req_idx
is
None
:
continue
num_tokens
=
input_batch
.
num_tokens_no_spec
[
new_req_idx
]
if
num_tokens
>
0
:
token_ids_gpu_tensor
[
new_req_idx
,
:
num_tokens
].
copy_
(
input_batch
.
token_ids_cpu_tensor
[
new_req_idx
,
:
num_tokens
],
non_blocking
=
True
,
)
# Always batch-sync sequence lengths from CPU for ALL active requests.
_sync_num_tokens
(
input_batch
,
num_tokens_no_spec_gpu
,
active_idx_cpu
,
active_idx_gpu
,
n_active
,
device
,
_pinned_val_buf
,
)
def
_sync_num_tokens
(
input_batch
:
InputBatch
,
num_tokens_no_spec_gpu
:
torch
.
Tensor
,
active_idx_cpu
:
torch
.
Tensor
,
active_idx_gpu
:
torch
.
Tensor
,
n_active
:
int
,
device
:
torch
.
device
,
_pinned_val_buf
:
torch
.
Tensor
,
)
->
None
:
"""Batch-sync GPU sequence lengths from CPU source of truth.
Inputs:
input_batch: Batch container with CPU length tensor.
num_tokens_no_spec_gpu: Destination GPU length tensor.
active_idx_cpu: Active request indices on CPU.
active_idx_gpu: Active request indices on GPU.
n_active: Number of active requests.
device: Target CUDA device.
_pinned_val_buf: Resident pinned int32 staging buffer.
Outputs:
None (updates num_tokens_no_spec_gpu in-place).
"""
src_cpu
=
input_batch
.
num_tokens_no_spec_cpu_tensor
vals
=
_pinned_val_buf
[:
n_active
]
vals
.
copy_
(
src_cpu
.
index_select
(
0
,
active_idx_cpu
))
num_tokens_no_spec_gpu
.
index_copy_
(
0
,
active_idx_gpu
,
vals
.
to
(
device
=
device
,
non_blocking
=
True
),
)
def
copy_num_valid_draft_tokens
(
num_valid_draft_tokens_cpu
:
torch
.
Tensor
,
num_valid_draft_tokens_copy_stream
:
torch
.
cuda
.
Stream
,
num_valid_draft_tokens_event
:
torch
.
cuda
.
Event
,
num_valid_draft_tokens
:
torch
.
Tensor
|
None
,
batch_size
:
int
,
)
->
None
:
"""
Async D2H copy of per-request valid draft counts.
"""
if
num_valid_draft_tokens
is
None
:
return
num_reqs_to_copy
=
min
(
batch_size
,
num_valid_draft_tokens
.
shape
[
0
])
if
num_reqs_to_copy
<=
0
:
return
default_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
num_valid_draft_tokens_copy_stream
):
num_valid_draft_tokens_copy_stream
.
wait_stream
(
default_stream
)
num_valid_draft_tokens_cpu
[:
num_reqs_to_copy
].
copy_
(
num_valid_draft_tokens
[:
num_reqs_to_copy
],
non_blocking
=
True
)
num_valid_draft_tokens_event
.
record
()
vllm/v1/spec_decode/utils.py
View file @
fcc9c9ea
...
...
@@ -5,7 +5,11 @@ import torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
async_tensor_h2d
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
)
PADDING_SLOT_ID
=
-
1
@
triton
.
jit
def
eagle_prepare_inputs_padded_kernel
(
...
...
@@ -182,3 +186,219 @@ class DraftProbs(ABC): # type: ignore[call-arg]
target_device
=
self
.
draft_probs
.
device
,
pin_memory
=
True
)
return
self
.
draft_probs
[
index_tensor
]
def
compute_new_slot_mapping
(
cad
:
CommonAttentionMetadata
,
new_positions
:
torch
.
Tensor
,
is_rejected_token_mask
:
torch
.
Tensor
,
block_size
:
int
,
num_new_tokens
:
int
,
max_model_len
:
int
,
):
batch_size
,
n_blocks_per_req
=
cad
.
block_table_tensor
.
shape
req_indices
=
torch
.
arange
(
batch_size
,
device
=
cad
.
query_start_loc
.
device
)
req_indices
=
torch
.
repeat_interleave
(
req_indices
,
cad
.
naive_query_lens
()
+
num_new_tokens
,
output_size
=
len
(
new_positions
),
)
# Clamp the positions to prevent an out-of-bounds error when indexing
# into block_table_tensor.
clamped_positions
=
torch
.
clamp
(
new_positions
,
max
=
max_model_len
-
1
)
block_table_indices
=
(
req_indices
*
n_blocks_per_req
+
clamped_positions
//
block_size
)
block_nums
=
cad
.
block_table_tensor
.
view
(
-
1
)[
block_table_indices
]
block_offsets
=
clamped_positions
%
block_size
new_slot_mapping
=
block_nums
*
block_size
+
block_offsets
# Mask out the position ids that exceed the max model length.
exceeds_max_model_len
=
new_positions
>=
max_model_len
new_slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# Mask out rejected tokens to prevent saves to the KV cache.
new_slot_mapping
.
masked_fill_
(
is_rejected_token_mask
,
PADDING_SLOT_ID
)
return
new_slot_mapping
def
extend_all_queries_by_N
(
common_attn_metadata
:
CommonAttentionMetadata
,
N
:
int
,
arange
:
torch
.
Tensor
,
new_slot_mapping
:
torch
.
Tensor
,
)
->
CommonAttentionMetadata
:
"""
Creates a new CommonAttentionMetadata with all query lengths increased by N.
Also all seq lens are increased by N.
This is useful e.g. in speculative decoding with parallel drafting, where we
extend each sequence by N tokens and predict all tokens in one pass.
The slot mapping is computed externally, as it requires more information.
"""
cad
=
common_attn_metadata
# query start loc must be increased by [+0, +N, +2N, ..., +batch_size * N]
new_query_start_loc
=
cad
.
query_start_loc
+
N
*
arange
[:
len
(
cad
.
query_start_loc
)]
new_query_start_loc_cpu
=
cad
.
query_start_loc_cpu
+
N
*
torch
.
arange
(
len
(
cad
.
query_start_loc_cpu
),
dtype
=
torch
.
int32
)
new_cad
=
cad
.
replace
(
query_start_loc
=
new_query_start_loc
,
query_start_loc_cpu
=
new_query_start_loc_cpu
,
seq_lens
=
cad
.
seq_lens
+
N
,
# each request is extended by N tokens -> batch_size * N tokens are added
num_actual_tokens
=
cad
.
num_actual_tokens
+
cad
.
batch_size
()
*
N
,
# All query lens increase by N, so max query len increases by N
max_query_len
=
cad
.
max_query_len
+
N
,
max_seq_len
=
cad
.
max_seq_len
+
N
,
slot_mapping
=
new_slot_mapping
,
)
return
new_cad
# Unified copy/expand kernel
@
triton
.
jit
def
copy_and_expand_eagle_inputs_kernel
(
# (Padded) Inputs from the target model
target_token_ids_ptr
,
# [total_tokens_in_batch]
target_positions_ptr
,
# [total_tokens_in_batch]
next_token_ids_ptr
,
# [num_reqs]
# Outputs to the drafting buffers
out_input_ids_ptr
,
# [total_draft_tokens_in_batch] (output)
out_positions_ptr
,
# [total_draft_tokens_in_batch] (output)
out_is_rejected_token_mask_ptr
,
# [total_draft_tokens_in_batch] (output)
out_is_masked_token_mask_ptr
,
# [total_draft_tokens_in_batch] (output)
out_new_token_indices_ptr
,
# [num_padding_slots_per_request * num_reqs] (output)
out_hidden_state_mapping_ptr
,
# [total_tokens_in_batch]
# Input metadata
query_start_loc_ptr
,
# [num_reqs + 1], last value is the total num input tokens
query_end_loc_ptr
,
# [num_reqs]
padding_token_id
,
# tl.int32
parallel_drafting_token_id
,
# tl.int32
# Sizing info
total_input_tokens
,
# tl.int32
num_padding_slots_per_request
,
# tl.int32
shift_input_ids
,
# tl.bool
BLOCK_SIZE_TOKENS
:
tl
.
constexpr
,
# Blocks along token dim to handle prefills
):
"""
Copy and expand inputs from the target model to the drafting buffers for Eagle
speculative decoding. This kernel handles padding slots and parallel drafting
tokens, if enabled.
"""
request_idx
=
tl
.
program_id
(
axis
=
0
)
token_batch_idx
=
tl
.
program_id
(
axis
=
1
)
# Load query locations
query_start_loc
=
tl
.
load
(
query_start_loc_ptr
+
request_idx
)
next_query_start_loc
=
tl
.
load
(
query_start_loc_ptr
+
request_idx
+
1
)
query_end_loc
=
tl
.
load
(
query_end_loc_ptr
+
request_idx
)
# Calculate number of valid tokens to copy and input offset
# With shift_input_ids=True, we skip the first token
# Output layout: each request gets (input_len + num_padding_slots_per_request) slots
# But with shift, we lose one token per request
if
shift_input_ids
:
num_valid_tokens
=
query_end_loc
-
query_start_loc
input_offset
=
1
output_start
=
query_start_loc
+
request_idx
*
(
num_padding_slots_per_request
-
1
)
else
:
num_valid_tokens
=
query_end_loc
-
query_start_loc
+
1
input_offset
=
0
output_start
=
query_start_loc
+
request_idx
*
num_padding_slots_per_request
# Number of rejected tokens from previous speculation
num_rejected
=
next_query_start_loc
-
query_end_loc
-
1
# Total output tokens for this request
total_output_tokens
=
(
num_valid_tokens
+
num_padding_slots_per_request
+
num_rejected
)
# Process tokens in this block
j
=
token_batch_idx
*
BLOCK_SIZE_TOKENS
+
tl
.
arange
(
0
,
BLOCK_SIZE_TOKENS
)
# Compute masks for different output regions:
# [0, num_valid_tokens): valid tokens copied from input
# [num_valid_tokens]: bonus token from next_token_ids
# (num_valid_tokens, num_valid_tokens + num_padding_slots_per_request):
# parallel drafting slots
# [num_valid_tokens + num_padding_slots_per_request, total_output_tokens):
# rejected slots
in_bounds
=
j
<
total_output_tokens
is_valid_region
=
j
<
num_valid_tokens
is_bonus_region
=
j
==
num_valid_tokens
is_parallel_draft_region
=
(
j
>
num_valid_tokens
)
&
(
j
<
num_valid_tokens
+
num_padding_slots_per_request
)
is_rejected_region
=
j
>=
num_valid_tokens
+
num_padding_slots_per_request
# Compute output indices
out_idx
=
output_start
+
j
# For valid tokens, compute input index
in_idx
=
query_start_loc
+
input_offset
+
j
# Clamp to avoid out-of-bounds access (masked loads still need valid addresses)
in_idx_clamped
=
tl
.
minimum
(
in_idx
,
total_input_tokens
-
1
)
# Load input tokens (masked to valid region)
token_ids
=
tl
.
load
(
target_token_ids_ptr
+
in_idx_clamped
,
mask
=
is_valid_region
&
in_bounds
,
other
=
0
)
# Load the starting position for this request (first position in the sequence)
start_pos
=
tl
.
load
(
target_positions_ptr
+
query_start_loc
)
# Load bonus token for this request
bonus_token
=
tl
.
load
(
next_token_ids_ptr
+
request_idx
)
# Build final token_ids based on region
token_ids
=
tl
.
where
(
is_bonus_region
,
bonus_token
,
token_ids
)
token_ids
=
tl
.
where
(
is_parallel_draft_region
,
parallel_drafting_token_id
,
token_ids
)
token_ids
=
tl
.
where
(
is_rejected_region
,
padding_token_id
,
token_ids
)
# Build final positions:
# Positions are NOT shifted - they start from the first input position and increment
# Output position j gets start_pos + j
# (e.g., input positions [5,6,7] -> output [5,6,7,8,9,...])
positions
=
start_pos
+
j
# Rejected positions are don't-care, set to 0
positions
=
tl
.
where
(
is_rejected_region
,
0
,
positions
)
# Compute output masks
is_rejected_out
=
is_rejected_region
&
in_bounds
is_masked_out
=
is_parallel_draft_region
&
in_bounds
# Compute indices of new tokens (bonus + parallel drafting) for sampling
# New tokens are at positions
# [num_valid_tokens, num_valid_tokens + num_padding_slots_per_request)
is_new_token_region
=
(
j
>=
num_valid_tokens
)
&
(
j
<
num_valid_tokens
+
num_padding_slots_per_request
)
new_token_local_idx
=
(
j
-
num_valid_tokens
)
# 0 for bonus, 1, 2, ... for parallel drafting
new_token_out_idx
=
(
request_idx
*
num_padding_slots_per_request
+
new_token_local_idx
)
# Compute hidden state mapping (source index -> destination index)
# This maps each input position to its corresponding output position
# Hidden states don't get shifted, so we map all input tokens (including rejected)
if
shift_input_ids
:
num_input_tokens_this_request
=
next_query_start_loc
-
query_start_loc
is_input_region
=
j
<
num_input_tokens_this_request
src_idx
=
query_start_loc
+
j
tl
.
store
(
out_hidden_state_mapping_ptr
+
src_idx
,
out_idx
,
mask
=
is_input_region
)
# Store outputs
tl
.
store
(
out_input_ids_ptr
+
out_idx
,
token_ids
,
mask
=
in_bounds
)
tl
.
store
(
out_positions_ptr
+
out_idx
,
positions
,
mask
=
in_bounds
)
tl
.
store
(
out_is_rejected_token_mask_ptr
+
out_idx
,
is_rejected_out
,
mask
=
in_bounds
)
tl
.
store
(
out_is_masked_token_mask_ptr
+
out_idx
,
is_masked_out
,
mask
=
in_bounds
)
tl
.
store
(
out_new_token_indices_ptr
+
new_token_out_idx
,
out_idx
,
mask
=
is_new_token_region
&
in_bounds
,
)
\ No newline at end of file
vllm/v1/worker/gpu_input_batch.py
View file @
fcc9c9ea
...
...
@@ -61,6 +61,13 @@ class CachedRequestState:
pooling_params
:
PoolingParams
|
None
=
None
pooling_states
:
PoolingStates
|
None
=
None
# for multi layer eagle proposer
cached_len
:
torch
.
Tensor
|
None
=
None
cached_token_ids
:
torch
.
Tensor
|
None
=
None
cached_hidden_states
:
torch
.
Tensor
|
None
=
None
cached_slot_mappings
:
torch
.
Tensor
|
None
=
None
cached_positions
:
torch
.
Tensor
|
None
=
None
def
__post_init__
(
self
):
self
.
num_prompt_tokens
=
length_from_prompt_token_ids_or_embeds
(
self
.
prompt_token_ids
,
self
.
prompt_embeds
...
...
@@ -103,6 +110,8 @@ class InputBatch:
is_spec_decode
:
bool
=
False
,
is_pooling_model
:
bool
=
False
,
cp_kv_cache_interleave_size
:
int
=
1
,
multi_layer_eagle_num
:
int
=
0
,
hidden_size
:
int
|
None
=
None
,
):
ori_max_num_reqs
=
max_num_reqs
if
is_spec_decode
and
envs
.
VLLM_REJECT_SAMPLE_OPT
:
...
...
@@ -223,7 +232,45 @@ class InputBatch:
(
max_num_reqs
,),
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
num_accepted_tokens_cpu
=
self
.
num_accepted_tokens_cpu_tensor
.
numpy
()
# Multi layer eagle
self
.
multi_layer_eagle_num
=
multi_layer_eagle_num
if
multi_layer_eagle_num
>
0
:
self
.
cached_len
=
torch
.
zeros
(
(
max_num_reqs
,),
dtype
=
torch
.
int64
,
device
=
device
)
self
.
cached_token_ids
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
cached_hidden_states
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
hidden_size
,
),
dtype
=
torch
.
float
,
device
=
device
,
)
self
.
cached_slot_mappings
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
cached_positions
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int64
,
device
=
device
,
)
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,),
dtype
=
np
.
int64
)
self
.
lora_id_to_request_ids
:
dict
[
int
,
set
[
str
]]
=
{}
...
...
@@ -464,6 +511,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self
.
num_accepted_tokens_cpu
[
req_index
]
=
1
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
req_index
]
=
request
.
cached_len
self
.
cached_token_ids
[
req_index
]
=
request
.
cached_token_ids
self
.
cached_hidden_states
[
req_index
]
=
request
.
cached_hidden_states
self
.
cached_slot_mappings
[
req_index
]
=
request
.
cached_slot_mappings
self
.
cached_positions
[
req_index
]
=
request
.
cached_positions
# Add request lora ID
if
request
.
lora_request
:
lora_id
=
request
.
lora_request
.
lora_int_id
...
...
@@ -662,6 +716,20 @@ class InputBatch:
self
.
allowed_token_ids_mask_cpu_tensor
[
i1
],
)
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
i1
],
self
.
cached_len
[
i2
]
=
(
self
.
cached_len
[
i2
],
self
.
cached_len
[
i1
],
)
self
.
cached_token_ids
[[
i1
,
i2
],
...]
=
self
.
cached_token_ids
[[
i2
,
i1
],
...]
self
.
cached_hidden_states
[[
i1
,
i2
],
...]
=
self
.
cached_hidden_states
[
[
i2
,
i1
],
...
]
self
.
cached_slot_mappings
[[
i1
,
i2
],
...]
=
self
.
cached_slot_mappings
[
[
i2
,
i1
],
...
]
self
.
cached_positions
[[
i1
,
i2
],
...]
=
self
.
cached_positions
[[
i2
,
i1
],
...]
def
condense
(
self
)
->
None
:
"""Slide non-empty requests down into lower, empty indices.
...
...
@@ -784,6 +852,21 @@ class InputBatch:
if
bad_words_token_ids
is
not
None
:
self
.
bad_words_token_ids
[
empty_index
]
=
bad_words_token_ids
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
empty_index
]
=
self
.
cached_len
[
last_req_index
]
self
.
cached_token_ids
[
empty_index
]
=
self
.
cached_token_ids
[
last_req_index
]
self
.
cached_hidden_states
[
empty_index
]
=
self
.
cached_hidden_states
[
last_req_index
]
self
.
cached_slot_mappings
[
empty_index
]
=
self
.
cached_slot_mappings
[
last_req_index
]
self
.
cached_positions
[
empty_index
]
=
self
.
cached_positions
[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
fcc9c9ea
...
...
@@ -149,8 +149,15 @@ from vllm.v1.sample.rejection_sampler_opt import OptRejectionSampler
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.spec_decode.draft_model
import
DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
,
MultiLayerEagleMetadata
from
vllm.v1.spec_decode.ngram_proposer_gpu
import
(
copy_num_valid_draft_tokens
,
# update_ngram_gpu_tensors_incremental,
# update_scheduler_for_invalid_drafts,
)
from
vllm.v1.spec_decode.multi_layer_eagle
import
MultiLayerEagleProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.suffix_decoding
import
SuffixDecodingProposer
from
vllm.v1.structured_output.utils
import
apply_grammar_bitmask
...
...
@@ -316,6 +323,7 @@ class ExecuteModelState(NamedTuple):
scheduler_output
:
"SchedulerOutput"
logits
:
torch
.
Tensor
spec_decode_metadata
:
SpecDecodeMetadata
|
None
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
spec_decode_common_attn_metadata
:
CommonAttentionMetadata
|
None
hidden_states
:
torch
.
Tensor
sample_hidden_states
:
torch
.
Tensor
...
...
@@ -336,6 +344,7 @@ class GPUModelRunner(
self
.
vllm_config
=
vllm_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
# self.offload_config = vllm_config.offload_config
self
.
compilation_config
=
vllm_config
.
compilation_config
self
.
lora_config
=
vllm_config
.
lora_config
self
.
load_config
=
vllm_config
.
load_config
...
...
@@ -417,6 +426,9 @@ class GPUModelRunner(
# Sampler
self
.
sampler
=
Sampler
(
logprobs_mode
=
self
.
model_config
.
logprobs_mode
)
# multi layer eagle
self
.
enable_multi_layer_eagle
=
False
self
.
eplb_state
:
EplbState
|
None
=
None
"""
State of the expert parallelism load balancer.
...
...
@@ -439,6 +451,9 @@ class GPUModelRunner(
self
.
encoder_cache
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
use_aux_hidden_state_outputs
=
False
self
.
multi_layer_eagle_num
=
0
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
...
...
@@ -450,6 +465,7 @@ class GPUModelRunner(
|
EagleProposer
|
DraftModelProposer
|
MedusaProposer
|
ExtractHiddenStatesProposer
)
if
self
.
speculative_config
.
method
==
"ngram"
:
self
.
drafter
=
NgramProposer
(
self
.
vllm_config
)
...
...
@@ -462,7 +478,19 @@ class GPUModelRunner(
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
if
(
self
.
speculative_config
.
enable_multi_layers_mtp
and
self
.
speculative_config
.
method
==
"mtp"
):
self
.
enable_multi_layer_eagle
=
True
self
.
drafter
=
MultiLayerEagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
self
.
multi_layer_eagle_num
=
self
.
drafter
.
layer_num
else
:
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
# self.drafter = EagleProposer(self.vllm_config, self.device, self)
if
self
.
speculative_config
.
method
==
"eagle3"
:
self
.
use_aux_hidden_state_outputs
=
(
self
.
drafter
.
eagle3_use_aux_hidden_state
...
...
@@ -471,6 +499,11 @@ class GPUModelRunner(
self
.
drafter
=
MedusaProposer
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
)
elif
self
.
speculative_config
.
method
==
"extract_hidden_states"
:
self
.
drafter
=
ExtractHiddenStatesProposer
(
vllm_config
=
self
.
vllm_config
,
device
=
self
.
device
)
self
.
use_aux_hidden_state_outputs
=
True
else
:
raise
ValueError
(
"Unknown speculative decoding method: "
...
...
@@ -535,6 +568,10 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids
=
bool
(
custom_logitsprocs
),
is_pooling_model
=
self
.
is_pooling_model
,
cp_kv_cache_interleave_size
=
self
.
parallel_config
.
cp_kv_cache_interleave_size
,
multi_layer_eagle_num
=
self
.
multi_layer_eagle_num
if
self
.
enable_multi_layer_eagle
else
0
,
hidden_size
=
self
.
model_config
.
get_hidden_size
(),
)
# Separate cuda stream for overlapping transfer of sampled token ids from
...
...
@@ -623,6 +660,7 @@ class GPUModelRunner(
(
3
,
self
.
max_num_tokens
+
1
),
dtype
=
torch
.
int64
)
# Only relevant for models using XD-RoPE (e.g, HunYuan-VL)
if
self
.
uses_xdrope_dim
>
0
:
# Similar to mrope but use assigned dimension number for RoPE, 4 as default.
...
...
@@ -805,7 +843,6 @@ class GPUModelRunner(
pin_memory
=
self
.
pin_memory
,
with_numpy
=
numpy
,
)
def
_copy_mrope_positions_to_gpu
(
self
,
num_tokens
:
int
)
->
None
:
if
not
self
.
uses_mrope
:
return
...
...
@@ -816,6 +853,7 @@ class GPUModelRunner(
non_blocking
=
True
,
)
return
self
.
mrope_positions
.
gpu
[:,
:
num_tokens
].
copy_
(
self
.
mrope_positions
.
cpu
[:,
:
num_tokens
],
non_blocking
=
True
,
...
...
@@ -1014,6 +1052,9 @@ class GPUModelRunner(
if
self
.
uses_xdrope_dim
>
0
:
self
.
_init_xdrope_positions
(
req_state
)
if
self
.
enable_multi_layer_eagle
:
self
.
_init_multi_layer_eagle_cache
(
req_state
)
reqs_to_add
.
append
(
req_state
)
# Update the states of the running/resumed requests.
...
...
@@ -1265,6 +1306,24 @@ class GPUModelRunner(
req_state
.
mm_features
,
)
def
_init_multi_layer_eagle_cache
(
self
,
req_state
:
CachedRequestState
):
req_state
.
cached_len
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
req_state
.
cached_hidden_states
=
torch
.
zeros
(
self
.
multi_layer_eagle_num
,
self
.
model_config
.
get_hidden_size
(),
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
req_state
.
cached_token_ids
=
torch
.
zeros
(
self
.
multi_layer_eagle_num
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
req_state
.
cached_positions
=
torch
.
zeros
(
self
.
multi_layer_eagle_num
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
req_state
.
cached_slot_mappings
=
torch
.
zeros
(
self
.
multi_layer_eagle_num
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
def
_extract_mm_kwargs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -1689,6 +1748,17 @@ class GPUModelRunner(
self
.
num_decode_draft_tokens
.
np
[
num_reqs
:].
fill
(
-
1
)
self
.
num_decode_draft_tokens
.
copy_to_gpu
()
if
self
.
enable_multi_layer_eagle
:
multi_layer_eagle_metadata
=
MultiLayerEagleMetadata
(
cached_len
=
self
.
input_batch
.
cached_len
[:
num_reqs
],
cached_token_ids
=
self
.
input_batch
.
cached_token_ids
[:
num_reqs
],
cached_hidden_states
=
self
.
input_batch
.
cached_hidden_states
[:
num_reqs
],
cached_slot_mappings
=
self
.
input_batch
.
cached_slot_mappings
[:
num_reqs
],
cached_positions
=
self
.
input_batch
.
cached_positions
[:
num_reqs
],
)
else
:
multi_layer_eagle_metadata
=
None
# Hot-Swap lora model
if
self
.
lora_config
:
assert
(
...
...
@@ -1699,10 +1769,11 @@ class GPUModelRunner(
self
.
input_batch
,
num_scheduled_tokens
,
num_sampled_tokens
)
return
(
logits_indices
,
spec_decode_metadata
,
)
# return (
# logits_indices,
# spec_decode_metadata,
# )
return
(
logits_indices
,
spec_decode_metadata
,
multi_layer_eagle_metadata
)
def
_build_attention_metadata
(
self
,
...
...
@@ -2168,9 +2239,9 @@ class GPUModelRunner(
req
.
mrope_positions
[:,
src_start
:
src_end
].
transpose
(
0
,
1
)
)
else
:
self
.
mrope_positions
.
cpu
[:,
dst_start
:
dst_end
]
=
(
req
.
mrope_positions
[
:,
src_start
:
src_end
]
)
self
.
mrope_positions
.
cpu
[:,
dst_start
:
dst_end
]
=
req
.
mrope_positions
[
:,
src_start
:
src_end
]
mrope_pos_ptr
+=
prompt_part_len
if
completion_part_len
>
0
:
...
...
@@ -2181,9 +2252,7 @@ class GPUModelRunner(
assert
req
.
mrope_position_delta
is
not
None
if
self
.
use_1d_mrope
:
values
=
np
.
arange
(
req
.
mrope_position_delta
+
num_computed_tokens
+
prompt_part_len
,
req
.
mrope_position_delta
+
num_computed_tokens
+
prompt_part_len
,
req
.
mrope_position_delta
+
num_computed_tokens
+
prompt_part_len
...
...
@@ -3457,10 +3526,16 @@ class GPUModelRunner(
max_num_scheduled_tokens
=
int
(
num_scheduled_tokens_np
.
max
())
num_tokens_unpadded
=
scheduler_output
.
total_num_scheduled_tokens
logits_indices
,
spec_decode_metadata
=
self
.
_prepare_inputs
(
# logits_indices, spec_decode_metadata = self._prepare_inputs(
# scheduler_output,
# num_scheduled_tokens_np,
# )
logits_indices
,
spec_decode_metadata
,
multi_layer_eagle_metadata
=
(
self
.
_prepare_inputs
(
scheduler_output
,
num_scheduled_tokens_np
,
)
)
cascade_attn_prefix_lens
=
None
# Disable cascade attention when using microbatching (DBO)
...
...
@@ -3683,6 +3758,7 @@ class GPUModelRunner(
scheduler_output
,
logits
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
hidden_states
,
sample_hidden_states
,
...
...
@@ -3720,6 +3796,7 @@ class GPUModelRunner(
scheduler_output
,
logits
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
hidden_states
,
sample_hidden_states
,
...
...
@@ -3759,6 +3836,7 @@ class GPUModelRunner(
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
slot_mappings
,
)
...
...
@@ -3959,6 +4037,233 @@ class GPUModelRunner(
sampled_count_event
.
synchronize
()
return
counts_cpu
[:
prev_sampled_token_ids
.
shape
[
0
]].
tolist
()
# def propose_draft_token_ids(
# self,
# scheduler_output: "SchedulerOutput",
# sampled_token_ids: torch.Tensor | list[list[int]],
# sampling_metadata: SamplingMetadata,
# hidden_states: torch.Tensor,
# sample_hidden_states: torch.Tensor,
# aux_hidden_states: list[torch.Tensor] | None,
# spec_decode_metadata: SpecDecodeMetadata | None,
# # multi_layer_eagle_metadata: MultiLayerEagleMetadata | None,
# common_attn_metadata: CommonAttentionMetadata,
# slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
# ) -> list[list[int]] | torch.Tensor:
# num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# spec_config = self.speculative_config
# assert spec_config is not None
# if spec_config.method == "ngram":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, NgramProposer)
# draft_token_ids = self.drafter.propose(
# sampled_token_ids,
# self.input_batch.num_tokens_no_spec,
# self.input_batch.token_ids_cpu,
# slot_mappings=slot_mappings,
# )
# elif spec_config.method == "suffix":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, SuffixDecodingProposer)
# draft_token_ids = self.drafter.propose(
# self.input_batch, sampled_token_ids, slot_mappings=slot_mappings
# )
# elif spec_config.method == "medusa":
# assert isinstance(sampled_token_ids, list)
# assert isinstance(self.drafter, MedusaProposer)
# if sample_hidden_states.shape[0] == len(sampled_token_ids):
# # The input to the target model does not include draft tokens.
# hidden_states = sample_hidden_states
# else:
# indices = []
# offset = 0
# assert spec_decode_metadata is not None, (
# "No spec decode metadata for medusa"
# )
# for num_draft, tokens in zip(
# spec_decode_metadata.num_draft_tokens, sampled_token_ids
# ):
# indices.append(offset + len(tokens) - 1)
# offset += num_draft + 1
# indices = torch.tensor(indices, device=self.device)
# hidden_states = sample_hidden_states[indices]
# draft_token_ids = self.drafter.propose(
# target_hidden_states=hidden_states,
# sampling_metadata=sampling_metadata,
# slot_mappings=slot_mappings,
# )
# elif spec_config.uses_extract_hidden_states():
# assert isinstance(self.drafter, ExtractHiddenStatesProposer)
# assert isinstance(sampled_token_ids, torch.Tensor), (
# "sampled_token_ids should be a torch.Tensor for "
# "extract_hidden_states method."
# )
# if not self.use_aux_hidden_state_outputs or aux_hidden_states is None:
# raise ValueError(
# "aux_hidden_states are required when using `extract_hidden_states`"
# )
# target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states]
# draft_token_ids, drafter_kv_connector_output = self.drafter.propose(
# sampled_token_ids=sampled_token_ids,
# target_hidden_states=target_hidden_states,
# common_attn_metadata=common_attn_metadata,
# scheduler_output=scheduler_output,
# slot_mappings=slot_mappings,
# )
# # Combine KVConnectorOutputs or select the non-empty one
# if self.kv_connector_output and drafter_kv_connector_output:
# self.kv_connector_output = KVConnectorOutput.merge(
# self.kv_connector_output, drafter_kv_connector_output
# )
# else:
# self.kv_connector_output = (
# self.kv_connector_output or drafter_kv_connector_output
# )
# next_token_ids, valid_sampled_tokens_count = (
# self.drafter.prepare_next_token_ids_padded(
# common_attn_metadata,
# sampled_token_ids,
# self.requests,
# self.input_batch,
# self.discard_request_mask.gpu,
# )
# )
# self._copy_valid_sampled_token_count(
# next_token_ids, valid_sampled_tokens_count
# )
# elif spec_config.use_eagle() or spec_config.uses_draft_model():
# assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
# if spec_config.disable_padded_drafter_batch:
# # When padded-batch is disabled, the sampled_token_ids should be
# # the cpu-side list[list[int]] of valid sampled tokens for each
# # request, with invalid requests having empty lists.
# assert isinstance(sampled_token_ids, list), (
# "sampled_token_ids should be a python list when"
# "padded-batch is disabled."
# )
# next_token_ids = self.drafter.prepare_next_token_ids_cpu(
# sampled_token_ids,
# self.requests,
# self.input_batch,
# scheduler_output.num_scheduled_tokens,
# )
# else:
# # When using padded-batch, the sampled_token_ids should be
# # the gpu tensor of sampled tokens for each request, of shape
# # (num_reqs, num_spec_tokens + 1) with rejected tokens having
# # value -1.
# assert isinstance(sampled_token_ids, torch.Tensor), (
# "sampled_token_ids should be a torch.Tensor when"
# "padded-batch is enabled."
# )
# next_token_ids, valid_sampled_tokens_count = (
# self.drafter.prepare_next_token_ids_padded(
# common_attn_metadata,
# sampled_token_ids,
# self.requests,
# self.input_batch,
# self.discard_request_mask.gpu,
# )
# )
# self._copy_valid_sampled_token_count(
# next_token_ids, valid_sampled_tokens_count
# )
# num_rejected_tokens_gpu = None
# if spec_decode_metadata is None:
# token_indices_to_sample = None
# # input_ids can be None for multimodal models.
# target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
# target_positions = self._get_positions(num_scheduled_tokens)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[:num_scheduled_tokens]
# else:
# if spec_config.disable_padded_drafter_batch:
# token_indices_to_sample = None
# common_attn_metadata, token_indices = self.drafter.prepare_inputs(
# common_attn_metadata,
# sampled_token_ids,
# spec_decode_metadata.num_draft_tokens,
# )
# target_token_ids = self.input_ids.gpu[token_indices]
# target_positions = self._get_positions(token_indices)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[token_indices] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[token_indices]
# else:
# (
# common_attn_metadata,
# token_indices_to_sample,
# num_rejected_tokens_gpu,
# ) = self.drafter.prepare_inputs_padded(
# common_attn_metadata,
# spec_decode_metadata,
# valid_sampled_tokens_count,
# )
# total_num_tokens = common_attn_metadata.num_actual_tokens
# # When padding the batch, token_indices is just a range
# target_token_ids = self.input_ids.gpu[:total_num_tokens]
# target_positions = self._get_positions(total_num_tokens)
# if self.use_aux_hidden_state_outputs:
# assert aux_hidden_states is not None
# target_hidden_states = torch.cat(
# [h[:total_num_tokens] for h in aux_hidden_states], dim=-1
# )
# else:
# target_hidden_states = hidden_states[:total_num_tokens]
# # if self.supports_mm_inputs:
# if self.supports_mm_inputs and self.drafter.supports_mm_inputs:
# mm_embed_inputs = self._gather_mm_embeddings(
# scheduler_output,
# shift_computed_tokens=1,
# )
# else:
# mm_embed_inputs = None
# draft_result = self.drafter.propose(
# target_token_ids=target_token_ids,
# target_positions=target_positions,
# target_hidden_states=target_hidden_states,
# next_token_ids=next_token_ids,
# token_indices_to_sample=token_indices_to_sample,
# sampling_metadata=sampling_metadata,
# common_attn_metadata=common_attn_metadata,
# mm_embed_inputs=mm_embed_inputs,
# num_rejected_tokens_gpu=num_rejected_tokens_gpu,
# slot_mappings=slot_mappings,
# # multi_layer_eagle_metadata=multi_layer_eagle_metadata,
# )
# if not envs.VLLM_REJECT_SAMPLE_OPT:
# draft_token_ids = draft_result
# else:
# draft_token_ids, draft_probs = draft_result
# if envs.VLLM_REJECT_SAMPLE_OPT:
# draft_req_ids = list(scheduler_output.num_scheduled_tokens.keys())
# if self.draft_probs is None:
# self.draft_probs = DraftProbs(
# draft_probs, draft_req_ids)
# else:
# self.draft_probs.update(draft_probs, draft_req_ids)
# return draft_token_ids
def
propose_draft_token_ids
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -3968,6 +4273,7 @@ class GPUModelRunner(
sample_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
spec_decode_metadata
:
SpecDecodeMetadata
|
None
,
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
,
common_attn_metadata
:
CommonAttentionMetadata
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
,
)
->
list
[
list
[
int
]]
|
torch
.
Tensor
:
...
...
@@ -3975,6 +4281,8 @@ class GPUModelRunner(
spec_config
=
self
.
speculative_config
assert
spec_config
is
not
None
if
spec_config
.
method
==
"ngram"
:
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
assert
isinstance
(
sampled_token_ids
,
list
)
assert
isinstance
(
self
.
drafter
,
NgramProposer
)
draft_token_ids
=
self
.
drafter
.
propose
(
...
...
@@ -3983,6 +4291,15 @@ class GPUModelRunner(
self
.
input_batch
.
token_ids_cpu
,
slot_mappings
=
slot_mappings
,
)
if
isinstance
(
self
.
drafter
,
NgramProposer
):
assert
isinstance
(
sampled_token_ids
,
list
),
(
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids
=
self
.
drafter
.
propose
(
sampled_token_ids
,
self
.
input_batch
.
num_tokens_no_spec
,
self
.
input_batch
.
token_ids_cpu
,
)
elif
spec_config
.
method
==
"suffix"
:
assert
isinstance
(
sampled_token_ids
,
list
)
assert
isinstance
(
self
.
drafter
,
SuffixDecodingProposer
)
...
...
@@ -4015,6 +4332,48 @@ class GPUModelRunner(
sampling_metadata
=
sampling_metadata
,
slot_mappings
=
slot_mappings
,
)
elif
spec_config
.
uses_extract_hidden_states
():
assert
isinstance
(
self
.
drafter
,
ExtractHiddenStatesProposer
)
assert
isinstance
(
sampled_token_ids
,
torch
.
Tensor
),
(
"sampled_token_ids should be a torch.Tensor for "
"extract_hidden_states method."
)
if
not
self
.
use_aux_hidden_state_outputs
or
aux_hidden_states
is
None
:
raise
ValueError
(
"aux_hidden_states are required when using `extract_hidden_states`"
)
target_hidden_states
=
[
h
[:
num_scheduled_tokens
]
for
h
in
aux_hidden_states
]
draft_token_ids
,
drafter_kv_connector_output
=
self
.
drafter
.
propose
(
sampled_token_ids
=
sampled_token_ids
,
target_hidden_states
=
target_hidden_states
,
common_attn_metadata
=
common_attn_metadata
,
scheduler_output
=
scheduler_output
,
slot_mappings
=
slot_mappings
,
)
# Combine KVConnectorOutputs or select the non-empty one
if
self
.
kv_connector_output
and
drafter_kv_connector_output
:
self
.
kv_connector_output
=
KVConnectorOutput
.
merge
(
self
.
kv_connector_output
,
drafter_kv_connector_output
)
else
:
self
.
kv_connector_output
=
(
self
.
kv_connector_output
or
drafter_kv_connector_output
)
next_token_ids
,
valid_sampled_tokens_count
=
(
self
.
drafter
.
prepare_next_token_ids_padded
(
common_attn_metadata
,
sampled_token_ids
,
self
.
requests
,
self
.
input_batch
,
self
.
discard_request_mask
.
gpu
,
)
)
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
)
elif
spec_config
.
use_eagle
()
or
spec_config
.
uses_draft_model
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
|
DraftModelProposer
)
...
...
@@ -4106,7 +4465,7 @@ class GPUModelRunner(
else
:
target_hidden_states
=
hidden_states
[:
total_num_tokens
]
if
self
.
supports_mm_inputs
:
if
self
.
supports_mm_inputs
and
self
.
drafter
.
supports_mm_inputs
:
mm_embed_inputs
=
self
.
_gather_mm_embeddings
(
scheduler_output
,
shift_computed_tokens
=
1
,
...
...
@@ -4119,28 +4478,16 @@ class GPUModelRunner(
target_positions
=
target_positions
,
target_hidden_states
=
target_hidden_states
,
next_token_ids
=
next_token_ids
,
last_
token_indices
=
token_indices_to_sample
,
token_indices
_to_sample
=
token_indices_to_sample
,
sampling_metadata
=
sampling_metadata
,
common_attn_metadata
=
common_attn_metadata
,
mm_embed_inputs
=
mm_embed_inputs
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
slot_mappings
=
slot_mappings
,
multi_layer_eagle_metadata
=
multi_layer_eagle_metadata
,
)
if
not
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_token_ids
=
draft_result
else
:
draft_token_ids
,
draft_probs
=
draft_result
if
envs
.
VLLM_REJECT_SAMPLE_OPT
:
draft_req_ids
=
list
(
scheduler_output
.
num_scheduled_tokens
.
keys
())
if
self
.
draft_probs
is
None
:
self
.
draft_probs
=
DraftProbs
(
draft_probs
,
draft_req_ids
)
else
:
self
.
draft_probs
.
update
(
draft_probs
,
draft_req_ids
)
return
draft_token_ids
return
draft_result
def
update_config
(
self
,
overrides
:
dict
[
str
,
Any
])
->
None
:
allowed_config_names
=
{
"load_config"
,
"model_config"
}
...
...
@@ -5709,6 +6056,8 @@ class GPUModelRunner(
logitsprocs
=
self
.
input_batch
.
logitsprocs
,
logitsprocs_need_output_token_ids
=
self
.
input_batch
.
logitsprocs_need_output_token_ids
,
is_pooling_model
=
self
.
is_pooling_model
,
multi_layer_eagle_num
=
self
.
multi_layer_eagle_num
if
self
.
enable_multi_layer_eagle
else
0
,
hidden_size
=
self
.
model_config
.
get_hidden_size
(),
)
def
_allocate_kv_cache_tensors
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment