Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
84ac1d27
Commit
84ac1d27
authored
May 09, 2026
by
zhangning3
Browse files
add step3p5 mtp3
parent
b27f1671
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
3486 additions
and
1495 deletions
+3486
-1495
examples/offline_inference/spec_decode.py
examples/offline_inference/spec_decode.py
+14
-1
step3p5_benchmark_test.sh
step3p5_benchmark_test.sh
+242
-0
tests/tool_parsers/test_step3p5_tool_parser.py
tests/tool_parsers/test_step3p5_tool_parser.py
+166
-68
tests/v1/spec_decode/test_mtp3.py
tests/v1/spec_decode/test_mtp3.py
+961
-0
vllm/config/speculative.py
vllm/config/speculative.py
+8
-1
vllm/entrypoints/anthropic/serving.py
vllm/entrypoints/anthropic/serving.py
+122
-31
vllm/entrypoints/openai/chat_completion/serving.py
vllm/entrypoints/openai/chat_completion/serving.py
+6
-3
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+1
-0
vllm/model_executor/models/step3p5_mtp.py
vllm/model_executor/models/step3p5_mtp.py
+13
-7
vllm/tool_parsers/abstract_tool_parser.py
vllm/tool_parsers/abstract_tool_parser.py
+6
-0
vllm/tool_parsers/step3p5_tool_parser.py
vllm/tool_parsers/step3p5_tool_parser.py
+994
-1351
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+1
-1
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+32
-14
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+2
-1
vllm/v1/spec_decode/metadata.py
vllm/v1/spec_decode/metadata.py
+44
-0
vllm/v1/spec_decode/multi_layer_eagle.py
vllm/v1/spec_decode/multi_layer_eagle.py
+701
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+87
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+72
-6
vllm/version.py
vllm/version.py
+14
-11
No files found.
examples/offline_inference/spec_decode.py
View file @
84ac1d27
...
...
@@ -57,6 +57,11 @@ def parse_args():
parser
.
add_argument
(
"--tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--enforce-eager"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--enable-chunked-prefill"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--enable-multi-layers-mtp"
,
action
=
"store_true"
,
help
=
"Enable multi-layer MTP (only effective when --method=mtp)."
,
)
parser
.
add_argument
(
"--max-model-len"
,
type
=
int
,
default
=
16384
)
parser
.
add_argument
(
"--temp"
,
type
=
float
,
default
=
0
)
parser
.
add_argument
(
"--top-p"
,
type
=
float
,
default
=
1.0
)
...
...
@@ -66,12 +71,14 @@ def parse_args():
parser
.
add_argument
(
"--model-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--eagle-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--draft-model"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--tokenizer-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--custom-mm-prompts"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--gpu-memory-utilization"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--disable-padded-drafter-batch"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--max-num-seqs"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--parallel-drafting"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--allowed-local-media-path"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
return
parser
.
parse_args
()
...
...
@@ -85,7 +92,11 @@ def main(args):
"please specify model_dir to give a mm based model"
)
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_dir
)
tokenizer_dir
=
args
.
tokenizer_dir
if
tokenizer_dir
is
None
:
tokenizer_dir
=
model_dir
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer_dir
)
if
args
.
custom_mm_prompts
:
prompts
=
llm_prompts
=
get_custom_mm_prompts
(
args
.
num_prompts
)
...
...
@@ -141,6 +152,8 @@ def main(args):
"method"
:
"mtp"
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
}
if
args
.
enable_multi_layers_mtp
:
speculative_config
[
"enable_multi_layers_mtp"
]
=
True
else
:
raise
ValueError
(
f
"unknown method:
{
args
.
method
}
"
)
...
...
step3p5_benchmark_test.sh
0 → 100644
View file @
84ac1d27
#!/usr/bin/env bash
# Benchmark driver configuration.
# Treat expansion of an unset variable as an error.
set -o nounset

# Concurrency levels used when no batch size list is given on the command line.
DEFAULT_BATCH_SIZES=(1 8 16 32 64 128)

# Model location and the name it is served under.
MODEL_PATH="/module/step3.5-fp8/"
SERVED_MODEL_NAME="/module/step3.5-fp8/"

# Dataset passed to the benchmark client.
DATASET_NAME="random"

# Output lengths per role; prefill generates a single token.
DEFAULT_OUTPUT_LEN_DECODE=4096
DEFAULT_OUTPUT_LEN_PREFILL=1

# Role used when --role is not supplied.
DEFAULT_ROLE="decode"

# Value forwarded to --ready-check-timeout.
READY_CHECK_TIMEOUT=3

# Root directory for result files; a per-role subdirectory is created below it.
RESULT_DIR="benchmark_result"
# Print the command-line help text to stdout.
# The heredoc delimiter is quoted, so nothing inside the text is expanded.
print_usage() {
    cat <<'HELP_TEXT'
Usage:
./scripts/step3p5_benchmark_test.sh
./scripts/step3p5_benchmark_test.sh 1,8,16,32
./scripts/step3p5_benchmark_test.sh --role prefill
./scripts/step3p5_benchmark_test.sh --role decode
./scripts/step3p5_benchmark_test.sh --role both
./scripts/step3p5_benchmark_test.sh --role prefill --output-len 1
Description:
- No argument: use default batch sizes, role=decode, output-len=4096
- Optional positional argument: batch size list (comma or space separated)
- Optional flag: --role <prefill|decode|both>
- Optional flag: --output-len <N> (must be positive integer)
- role=both 时串行执行 prefill 再 decode
- Result files are saved under:
<result_dir>/prefill (when role=prefill)
<result_dir>/decode (when role=decode)
HELP_TEXT
}
# Populate the global BATCH_SIZES array.
#
# Arguments:
#   $1 - optional list of batch sizes separated by commas and/or whitespace.
#        When empty or missing, DEFAULT_BATCH_SIZES is copied instead.
# Exits with status 1 if the list contains no entries or any entry is not a
# positive integer without leading zeros.
parse_batch_sizes() {
    local spec="${1:-}"

    if [[ -n "$spec" ]]; then
        # Commas become spaces so that `read -a` splits on either separator.
        read -r -a BATCH_SIZES <<< "${spec//,/ }"
    else
        BATCH_SIZES=("${DEFAULT_BATCH_SIZES[@]}")
        return
    fi

    if [[ ${#BATCH_SIZES[@]} -eq 0 ]]; then
        echo "[ERROR] batch size 列表为空。"
        print_usage
        exit 1
    fi

    local entry
    for entry in "${BATCH_SIZES[@]}"; do
        [[ "$entry" =~ ^[1-9][0-9]*$ ]] && continue
        echo "[ERROR] 非法 batch size: $entry (必须是正整数)"
        exit 1
    done
}
# Validate the requested role and store it in the global ROLE.
#
# Arguments:
#   $1 - role name; falls back to DEFAULT_ROLE when empty or missing.
# Accepted values are "prefill", "decode" and "both"; anything else prints an
# error plus the usage text and exits with status 1.
parse_role() {
    local candidate="${1:-$DEFAULT_ROLE}"

    case "$candidate" in
        prefill|decode|both)
            ROLE="$candidate"
            ;;
        *)
            echo "[ERROR] 非法 role: $candidate (必须是 prefill、decode 或 both)"
            print_usage
            exit 1
            ;;
    esac
}
# Validate the output length and store it in the global RANDOM_OUTPUT_LEN.
#
# Arguments:
#   $1 - output length; must be a positive integer without leading zeros.
# On invalid input prints an error plus the usage text and exits with status 1.
parse_output_len() {
    local requested="$1"

    if [[ "$requested" =~ ^[1-9][0-9]*$ ]]; then
        RANDOM_OUTPUT_LEN="$requested"
        return
    fi

    echo "[ERROR] 非法 output-len: $requested (必须是正整数)"
    print_usage
    exit 1
}
# Print the number of prompts to send for a given batch size.
#
# Arguments:
#   $1 - batch size (integer).
# The result is batch_size + batch_size / 2 (integer division), clamped to the
# range [16, 384], and then raised back to batch_size if the clamp pushed it
# below the batch size.
calc_num_prompts() {
    local concurrency="$1"
    local total=$(( concurrency + concurrency / 2 ))

    (( total >= 16 )) || total=16
    (( total <= 384 )) || total=384
    # Never send fewer prompts than the batch size itself.
    (( total >= concurrency )) || total=$concurrency

    printf '%s\n' "$total"
}
# Entry point: parse the command line and run the benchmark sweep.
#
# Arguments (all optional):
#   <batch sizes>          one positional list, comma or space separated
#   --role | -r <role>     prefill, decode or both (default: DEFAULT_ROLE)
#   --output-len | -o <N>  output length; ignored for the prefill role
#   -h | --help            print usage and exit 0
#
# role=both re-invokes this script once for prefill and once for decode. The
# decode run is still attempted when the prefill run fails, but the script now
# exits with status 1 if either child run failed (previously those failures
# were ignored and the script always exited 0 in this mode).
#
# For a single role, one `vllm bench serve` run is made per batch size. A
# failed run does not stop the sweep; the script exits 1 at the end if any run
# failed.
main() {
  local batch_arg=""
  local role_arg="$DEFAULT_ROLE"
  local output_len_arg=""

  while (( $# > 0 )); do
    case "$1" in
      -h|--help)
        print_usage
        exit 0
        ;;
      --role|-r)
        if [[ -z "${2:-}" ]]; then
          echo "[ERROR] --role 缺少参数。"
          print_usage
          exit 1
        fi
        role_arg="$2"
        shift 2
        ;;
      --output-len|-o)
        if [[ -z "${2:-}" ]]; then
          echo "[ERROR] --output-len 缺少参数。"
          print_usage
          exit 1
        fi
        output_len_arg="$2"
        shift 2
        ;;
      --*)
        echo "[ERROR] 未知参数: $1"
        print_usage
        exit 1
        ;;
      *)
        # Only one positional argument (the batch size list) is accepted.
        if [[ -n "$batch_arg" ]]; then
          echo "[ERROR] 仅支持一个 batch size 列表参数。"
          print_usage
          exit 1
        fi
        batch_arg="$1"
        shift
        ;;
    esac
  done

  parse_batch_sizes "$batch_arg"
  parse_role "$role_arg"

  if [[ "$ROLE" == "both" ]]; then
    local -a prefill_cmd=("$0")
    local -a decode_cmd=("$0")
    if [[ -n "$batch_arg" ]]; then
      prefill_cmd+=("$batch_arg")
      decode_cmd+=("$batch_arg")
    fi
    prefill_cmd+=("--role" "prefill")
    decode_cmd+=("--role" "decode")
    # --output-len is forwarded to the decode run only; the prefill role
    # always overrides it with DEFAULT_OUTPUT_LEN_PREFILL.
    if [[ -n "$output_len_arg" ]]; then
      decode_cmd+=("--output-len" "$output_len_arg")
    fi

    local stage_failures=0
    echo "[INFO] role=both: 将串行执行 prefill 和 decode"
    echo "[INFO] step1: ${prefill_cmd[*]}"
    if ! "${prefill_cmd[@]}"; then
      echo "[WARN] role=both: prefill 阶段执行失败,继续执行 decode。"
      stage_failures=$(( stage_failures + 1 ))
    fi
    echo "[INFO] step2: ${decode_cmd[*]}"
    if ! "${decode_cmd[@]}"; then
      echo "[WARN] role=both: decode 阶段执行失败。"
      stage_failures=$(( stage_failures + 1 ))
    fi
    # Match the single-role behaviour: a failed run yields a non-zero exit.
    if (( stage_failures > 0 )); then
      echo "[WARN] role=both 执行结束,但有 $stage_failures 个阶段失败。"
      exit 1
    fi
    echo "[INFO] role=both 执行完成。"
    return 0
  fi

  if [[ "$ROLE" == "prefill" ]]; then
    if [[ -n "$output_len_arg" && "$output_len_arg" != "$DEFAULT_OUTPUT_LEN_PREFILL" ]]; then
      echo "[WARN] role=prefill 时 output-len 必须为 1,已自动覆盖为 1。"
    fi
    output_len_arg="$DEFAULT_OUTPUT_LEN_PREFILL"
  elif [[ -z "$output_len_arg" ]]; then
    output_len_arg="$DEFAULT_OUTPUT_LEN_DECODE"
  fi
  parse_output_len "$output_len_arg"

  RESULT_SUBDIR="$RESULT_DIR/$ROLE"
  mkdir -p "$RESULT_SUBDIR"

  echo "[INFO] 将执行 ${#BATCH_SIZES[@]} 组 benchmark"
  echo "[INFO] role: $ROLE"
  echo "[INFO] random-output-len: $RANDOM_OUTPUT_LEN"
  echo "[INFO] batch size 列表: ${BATCH_SIZES[*]}"
  echo "[INFO] result_dir: $RESULT_SUBDIR"

  local batch_size
  local num_prompts
  local failed_count=0
  for batch_size in "${BATCH_SIZES[@]}"; do
    num_prompts="$(calc_num_prompts "$batch_size")"
    echo ""
    echo "[INFO] 开始测试: role=$ROLE, batch_size=$batch_size, max_concurrency=$batch_size, num_prompts=$num_prompts, random_output_len=$RANDOM_OUTPUT_LEN"
    if ! vllm bench serve \
      --backend vllm \
      --model "$MODEL_PATH" \
      --served-model-name "$SERVED_MODEL_NAME" \
      --dataset-name "$DATASET_NAME" \
      --random-input-len 65536 \
      --random-output-len "$RANDOM_OUTPUT_LEN" \
      --num-prompts "$num_prompts" \
      --temperature 0 \
      --max-concurrency "$batch_size" \
      --ready-check-timeout "$READY_CHECK_TIMEOUT" \
      --result-dir "$RESULT_SUBDIR" \
      --port 8018 \
      --save-result; then
      echo "[WARN] batch_size=$batch_size 执行失败,继续下一个。"
      failed_count=$(( failed_count + 1 ))
    else
      echo "[INFO] batch_size=$batch_size 执行完成。"
    fi
  done

  echo ""
  if (( failed_count > 0 )); then
    echo "[WARN] 全部执行结束,但有 $failed_count 组失败。"
    exit 1
  fi
  echo "[INFO] 全部 benchmark 执行完成。"
}

main "$@"
tests/tool_parsers/test_step3p5_tool_parser.py
View file @
84ac1d27
...
...
@@ -386,6 +386,57 @@ TX
assert
extracted_tool_calls
.
tool_calls
[
0
].
function
.
name
==
"get_current_weather"
def test_extract_tool_calls_missing_function_name(step3p5_tool_parser, sample_tools):
    """A tool call with no function tag is returned unchanged as content."""
    raw_text = "<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    result = step3p5_tool_parser.extract_tool_calls(raw_text, request=chat_request)

    # Nothing is parsed as a tool call; the raw text comes back verbatim.
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == raw_text
def test_extract_tool_calls_empty_function_name(step3p5_tool_parser, sample_tools):
    """A tool call whose function name is empty is returned unchanged as content."""
    raw_text = (
        "<tool_call><function=>"
        "<parameter=pattern>*.py</parameter>"
        "</function></tool_call>"
    )
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    result = step3p5_tool_parser.extract_tool_calls(raw_text, request=chat_request)

    # Nothing is parsed as a tool call; the raw text comes back verbatim.
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == raw_text
def test_extract_tool_calls_empty_function_name_single_quotes(
    step3p5_tool_parser, sample_tools
):
    """A function name given as empty single quotes also falls back to content."""
    raw_text = "<tool_call><function=''><parameter=pattern>*.py</parameter></tool_call>"
    chat_request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    result = step3p5_tool_parser.extract_tool_calls(raw_text, request=chat_request)

    # Nothing is parsed as a tool call; the raw text comes back verbatim.
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == raw_text
def
test_extract_tool_calls_type_conversion
(
step3p5_tool_parser
):
"""Test parameter type conversion based on tool schema"""
tools
=
[
...
...
@@ -623,7 +674,7 @@ def test_extract_tool_calls_streaming(
expected_tool_calls
,
expected_content
,
):
"""Test
incremental streaming behavior including typed parameters
"""
"""Test
streaming returns complete tool calls when parsed.
"""
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
other_content
=
""
...
...
@@ -647,11 +698,10 @@ def test_extract_tool_calls_streaming(
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
# First chunk should have id, name, and type
if
tool_call
.
id
:
tool_states
[
idx
][
"id"
]
=
tool_call
.
id
...
...
@@ -666,8 +716,15 @@ def test_extract_tool_calls_streaming(
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
# Accumulate arguments incrementally
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
# Arguments should be complete JSON when emitted.
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Verify final content
assert
other_content
==
(
expected_content
or
""
)
# Handle None case
...
...
@@ -682,7 +739,7 @@ def test_extract_tool_calls_streaming(
assert
state
[
"type"
]
==
"function"
assert
state
[
"name"
]
==
expected_tool
.
function
.
name
# Parse
accumulated
arguments
# Parse arguments
arguments_str
=
state
[
"arguments"
]
assert
arguments_str
is
not
None
actual_args
=
json
.
loads
(
arguments_str
)
...
...
@@ -770,7 +827,7 @@ fahrenheit
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
...
...
@@ -786,7 +843,14 @@ fahrenheit
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Verify content was streamed
assert
"Let me check the weather for you:"
in
other_content
...
...
@@ -806,62 +870,69 @@ fahrenheit
assert
args
[
"unit"
]
==
"fahrenheit"
def
test_extract_tool_calls_streaming_
incremental
(
def
test_extract_tool_calls_streaming_
missing_function_name
(
step3p5_tool_parser
,
step3p5_tokenizer
,
sample_tools
):
"""Test that streaming is truly incremental"""
model_output
=
"""I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""
"""Streaming: missing function name should be treated as content."""
model_output
=
(
"<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
chunks
=
[]
for
delta_message
in
stream_delta_message_generator
(
step3p5_tool_parser
,
step3p5_tokenizer
,
model_output
,
request
other_content
=
""
tool_calls
=
[]
for
delta_message
in
stream_delta_message_generator_from_chunks
(
step3p5_tool_parser
,
step3p5_tokenizer
,
[
"<tool_call><parameter=pattern>"
,
"*.py</parameter>"
,
"</function></tool_call>"
,
],
request
,
):
if
delta_message
.
content
:
other_content
+=
delta_message
.
content
if
delta_message
.
tool_calls
:
tool_calls
.
extend
(
delta_message
.
tool_calls
)
assert
other_content
==
model_output
assert
tool_calls
==
[]
def
test_extract_tool_calls_streaming_empty_function_name
(
step3p5_tool_parser
,
step3p5_tokenizer
,
sample_tools
):
"""Streaming: empty function name should be treated as content."""
model_output
=
(
"<tool_call><function=''><parameter=pattern>*.py</parameter>"
"</function></tool_call>"
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
other_content
=
""
tool_calls
=
[]
for
delta_message
in
stream_delta_message_generator_from_chunks
(
step3p5_tool_parser
,
step3p5_tokenizer
,
[
"<tool_call><function="
,
"''><parameter=pattern>*.py</parameter>"
,
"</function></tool_call>"
,
],
request
,
):
chunks
.
append
(
delta_message
)
# Should have multiple chunks
assert
len
(
chunks
)
>
3
# First chunk(s) should be content
assert
chunks
[
0
].
content
is
not
None
assert
chunks
[
0
].
tool_calls
is
None
or
chunks
[
0
].
tool_calls
==
[]
# Should have a chunk with tool header (id, name, type)
header_found
=
False
for
chunk
in
chunks
:
if
chunk
.
tool_calls
and
chunk
.
tool_calls
[
0
].
id
:
header_found
=
True
assert
chunk
.
tool_calls
[
0
].
function
.
name
==
"get_current_weather"
assert
chunk
.
tool_calls
[
0
].
type
==
"function"
# Empty initially
assert
chunk
.
tool_calls
[
0
].
function
.
arguments
==
""
break
assert
header_found
# Should have chunks with incremental arguments
arg_chunks
=
[]
for
chunk
in
chunks
:
if
chunk
.
tool_calls
and
chunk
.
tool_calls
[
0
].
function
.
arguments
:
arg_chunks
.
append
(
chunk
.
tool_calls
[
0
].
function
.
arguments
)
# Arguments should be streamed incrementally
assert
len
(
arg_chunks
)
>
1
# Concatenated arguments should form valid JSON
full_args
=
""
.
join
(
arg_chunks
)
parsed_args
=
json
.
loads
(
full_args
)
assert
parsed_args
[
"city"
]
==
"Dallas"
assert
parsed_args
[
"state"
]
==
"TX"
if
delta_message
.
content
:
other_content
+=
delta_message
.
content
if
delta_message
.
tool_calls
:
tool_calls
.
extend
(
delta_message
.
tool_calls
)
assert
other_content
==
model_output
assert
tool_calls
==
[]
def
test_extract_tool_calls_complex_type_with_single_quote
(
step3p5_tool_parser
):
...
...
@@ -951,7 +1022,7 @@ rectangle
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
...
...
@@ -967,7 +1038,14 @@ rectangle
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Should have exactly two complete tool calls
assert
len
(
tool_states
)
==
2
,
"Should have exactly two complete tool calls"
...
...
@@ -1164,7 +1242,7 @@ rectangle
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
if
tool_call
.
id
:
...
...
@@ -1175,7 +1253,14 @@ rectangle
if
tool_call
.
function
.
name
:
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Should have exactly two complete tool calls
assert
len
(
tool_states
)
==
2
,
"Should have exactly two complete tool calls"
...
...
@@ -1266,7 +1351,7 @@ rectangle
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
...
...
@@ -1282,7 +1367,14 @@ rectangle
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Should have exactly two complete tool calls
assert
len
(
tool_states
)
==
2
,
"Should have exactly two complete tool calls"
...
...
@@ -1344,20 +1436,26 @@ rectangle""",
for
delta_message
in
stream_delta_message_generator_from_chunks
(
step3p5_tool_parser
,
step3p5_tokenizer
,
delta_text_chunks
,
request
):
print
(
delta_message
)
if
delta_message
.
tool_calls
:
for
tool_call
in
delta_message
.
tool_calls
:
idx
=
tool_call
.
index
if
idx
not
in
tool_states
:
tool_states
[
idx
]
=
{
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
}
if
tool_call
.
function
:
if
tool_call
.
function
.
name
:
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
assert
len
(
tool_states
)
==
2
assert
all
(
state
[
"name"
]
for
state
in
tool_states
.
values
())
...
...
@@ -1368,7 +1466,7 @@ rectangle""",
def
test_extract_tool_calls_non_streaming_multiple_tool_calls_no_content_between
(
step3p5_tool_parser
,
sample_tools
):
"""Test non-streaming extraction with tool calls
and no content between them
.
"""Test non-streaming extraction with
multiple
tool calls.
Scenario: Model outputs "hello" + tool call + tool call.
Expected: "hello" as content, first tool call parsed (index=0),
...
...
tests/v1/spec_decode/test_mtp3.py
0 → 100644
View file @
84ac1d27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
types
import
SimpleNamespace
import
pytest
import
torch
from
vllm.v1.spec_decode.metadata
import
MultiLayerEagleMetadata
from
vllm.v1.spec_decode.multi_layer_eagle
import
MultiLayerEagleProposer
HIDDEN_SIZE
=
3
def _make_multi_layer_eagle_metadata(
    *,
    initial_cache: list[dict],
    max_shift: int,
    device: torch.device,
) -> MultiLayerEagleMetadata:
    """Build a ``MultiLayerEagleMetadata`` from per-request cache rows.

    Args:
        initial_cache: One dict per request with keys ``len``, ``token_ids``,
            ``positions`` and ``slot_mapping``. Each list must hold exactly
            ``max_shift`` entries, and everything past the first ``len``
            entries must be zero.
        max_shift: Width of every cache row and upper bound for ``len``.
        device: Device on which all tensors are created.

    Returns:
        Metadata whose cached hidden states are all zeros with shape
        ``(len(initial_cache), max_shift, HIDDEN_SIZE)``.
    """
    padded_fields = ("token_ids", "positions", "slot_mapping")

    for entry in initial_cache:
        assert "len" in entry
        valid = int(entry["len"])
        assert 0 <= valid <= max_shift
        # Test cases pad cache rows to `layer_num` (== max_shift) and specify
        # the number of valid entries via `len`.
        assert (
            len(entry["token_ids"])
            == len(entry["positions"])
            == len(entry["slot_mapping"])
            == max_shift
        )
        # The padding tail of every field has to be zero-filled.
        for field in padded_fields:
            assert all(v == 0 for v in entry[field][valid:])

    def _column(key: str, dtype: torch.dtype) -> torch.Tensor:
        """Stack one field of every cache row into a single tensor."""
        return torch.tensor(
            [entry[key] for entry in initial_cache],
            dtype=dtype,
            device=device,
        )

    lengths = torch.tensor(
        [min(int(entry["len"]), max_shift) for entry in initial_cache],
        dtype=torch.int64,
        device=device,
    )
    token_ids = _column("token_ids", torch.int32)
    positions = _column("positions", torch.int64)
    slot_mappings = _column("slot_mapping", torch.int64)
    hidden_states = torch.zeros(
        (len(initial_cache), max_shift, HIDDEN_SIZE),
        dtype=torch.float32,
        device=device,
    )

    return MultiLayerEagleMetadata(
        cached_len=lengths,
        cached_token_ids=token_ids,
        cached_hidden_states=hidden_states,
        cached_slot_mappings=slot_mappings,
        cached_positions=positions,
    )
@pytest.fixture
def proposer_stub():
    """Return a ``MultiLayerEagleProposer`` created without running ``__init__``.

    Only ``layer_num`` is set (to 3). The test is skipped when CUDA is not
    available.
    """
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        pytest.skip("MultiLayerEagleProposer.adjust_input is CUDA/Triton-only.")

    # Bypass the constructor and populate only the attribute set here.
    stub = MultiLayerEagleProposer.__new__(MultiLayerEagleProposer)
    stub.layer_num = 3
    return stub
LAYER3_CASES
=
[
{
"name"
:
"shift_0_at_sequence_end"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
}
],
"target_token_ids"
:
[
10
,
11
,
12
,
13
],
"target_positions"
:
[
0
,
1
,
2
,
3
],
"token_indices_to_sample"
:
[
3
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
4
],
"seq_lens"
:
[
4
],
"seq_lens_cpu"
:
[
4
],
"num_computed_tokens_cpu"
:
[
0
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
],
"max_seq_len"
:
4
,
},
"expected"
:
{
"prev_token_ids"
:
[
10
,
11
,
12
,
13
],
"prev_positions"
:
[
0
,
1
,
2
,
3
],
"token_indices_to_sample"
:
[
3
],
"seq_lens"
:
[
4
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
],
"cached"
:
[
{
"len"
:
3
,
"token_ids"
:
[
11
,
12
,
13
],
"positions"
:
[
1
,
2
,
3
],
"slot_mapping"
:
[
101
,
102
,
103
],
}
],
},
},
{
"name"
:
"batch2_short_seq_no_shift"
,
"batch_size"
:
2
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
],
"target_token_ids"
:
[
10
,
11
,
20
],
"target_positions"
:
[
0
,
1
,
0
],
"token_indices_to_sample"
:
[
1
,
2
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
2
,
3
],
"seq_lens"
:
[
2
,
1
],
"seq_lens_cpu"
:
[
2
,
1
],
"num_computed_tokens_cpu"
:
[
0
,
0
],
"slot_mapping"
:
[
100
,
101
,
200
],
"max_seq_len"
:
2
,
},
"expected"
:
{
"prev_token_ids"
:
[
10
,
11
,
20
],
"prev_positions"
:
[
0
,
1
,
0
],
"token_indices_to_sample"
:
[
1
,
2
],
"seq_lens"
:
[
2
,
1
],
"slot_mapping"
:
[
100
,
101
,
200
],
"cached"
:
[
{
"len"
:
2
,
"token_ids"
:
[
10
,
11
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
100
,
101
,
0
],
},
{
"len"
:
1
,
"token_ids"
:
[
20
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
200
,
0
,
0
],
},
],
},
},
{
"name"
:
"batch2_short_seq_shift_on_first"
,
"batch_size"
:
2
,
"initial_cache"
:
[
{
"len"
:
1
,
"token_ids"
:
[
99
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
999
,
0
,
0
],
},
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
],
"target_token_ids"
:
[
10
,
11
,
20
],
"target_positions"
:
[
1
,
2
,
0
],
"token_indices_to_sample"
:
[
0
,
2
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
2
,
3
],
"seq_lens"
:
[
2
,
1
],
"seq_lens_cpu"
:
[
2
,
1
],
"num_computed_tokens_cpu"
:
[
1
,
0
],
"slot_mapping"
:
[
100
,
101
,
200
],
"max_seq_len"
:
2
,
},
"expected"
:
{
"prev_token_ids"
:
[
99
,
10
,
20
],
"prev_positions"
:
[
0
,
1
,
0
],
"token_indices_to_sample"
:
[
1
,
2
],
"seq_lens"
:
[
1
,
1
],
"slot_mapping"
:
[
999
,
100
,
200
],
"cached"
:
[
{
"len"
:
2
,
"token_ids"
:
[
99
,
10
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
999
,
100
,
0
],
},
{
"len"
:
1
,
"token_ids"
:
[
20
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
200
,
0
,
0
],
},
],
},
},
{
"name"
:
"short_seq_len_2_shift_0_cache_len_1"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
}
],
"target_token_ids"
:
[
7
,
8
],
"target_positions"
:
[
0
,
1
],
"token_indices_to_sample"
:
[
0
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
2
],
"seq_lens"
:
[
2
],
"seq_lens_cpu"
:
[
2
],
"num_computed_tokens_cpu"
:
[
0
],
"slot_mapping"
:
[
1000
,
1001
],
"max_seq_len"
:
2
,
},
"expected"
:
{
"prev_token_ids"
:
[
7
,
8
],
"prev_positions"
:
[
0
,
1
],
"token_indices_to_sample"
:
[
0
],
"seq_lens"
:
[
2
],
"slot_mapping"
:
[
1000
,
1001
],
"cached"
:
[
{
"len"
:
1
,
"token_ids"
:
[
7
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
1000
,
0
,
0
],
}
],
},
},
{
"name"
:
"short_seq_len_2_shift_1_cache_len_2"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
1
,
"token_ids"
:
[
6
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
999
,
0
,
0
],
}
],
"target_token_ids"
:
[
7
,
8
],
"target_positions"
:
[
1
,
2
],
"token_indices_to_sample"
:
[
0
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
2
],
"seq_lens"
:
[
2
],
"seq_lens_cpu"
:
[
2
],
"num_computed_tokens_cpu"
:
[
1
],
"slot_mapping"
:
[
1000
,
1001
],
"max_seq_len"
:
2
,
},
"expected"
:
{
"prev_token_ids"
:
[
6
,
7
],
"prev_positions"
:
[
0
,
1
],
"token_indices_to_sample"
:
[
1
],
"seq_lens"
:
[
1
],
"slot_mapping"
:
[
999
,
1000
],
"cached"
:
[
{
"len"
:
2
,
"token_ids"
:
[
6
,
7
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
999
,
1000
,
0
],
}
],
},
},
{
"name"
:
"shift_bounded_by_start_pos_zero"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
}
],
"target_token_ids"
:
[
10
,
11
,
12
,
13
],
"target_positions"
:
[
0
,
2
,
3
,
4
],
"token_indices_to_sample"
:
[
1
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
4
],
"seq_lens"
:
[
4
],
"seq_lens_cpu"
:
[
4
],
"num_computed_tokens_cpu"
:
[
0
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
],
"max_seq_len"
:
4
,
},
"expected"
:
{
"prev_token_ids"
:
[
10
,
11
,
12
,
13
],
"prev_positions"
:
[
0
,
2
,
3
,
4
],
"token_indices_to_sample"
:
[
1
],
"seq_lens"
:
[
4
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
],
"cached"
:
[
{
"len"
:
2
,
"token_ids"
:
[
10
,
11
,
0
],
"positions"
:
[
0
,
2
,
0
],
"slot_mapping"
:
[
100
,
101
,
0
],
}
],
},
},
{
"name"
:
"shift_bounded_by_start_pos"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
}
],
"target_token_ids"
:
[
10
,
11
,
12
,
13
,
14
],
"target_positions"
:
[
0
,
1
,
2
,
3
,
4
],
"token_indices_to_sample"
:
[
1
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
5
],
"seq_lens"
:
[
5
],
"seq_lens_cpu"
:
[
5
],
"num_computed_tokens_cpu"
:
[
1
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
104
],
"max_seq_len"
:
5
,
},
"expected"
:
{
"prev_token_ids"
:
[
10
,
11
,
12
,
13
,
14
],
"prev_positions"
:
[
0
,
1
,
2
,
3
,
4
],
"token_indices_to_sample"
:
[
1
],
"seq_lens"
:
[
5
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
104
],
"cached"
:
[
{
"len"
:
2
,
"token_ids"
:
[
10
,
11
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
100
,
101
,
0
],
}
],
},
},
{
"name"
:
"shift_2_bounded_by_remaining"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
}
],
"target_token_ids"
:
[
10
,
11
,
12
,
13
,
14
],
"target_positions"
:
[
0
,
1
,
2
,
3
,
4
],
"token_indices_to_sample"
:
[
2
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
5
],
"seq_lens"
:
[
5
],
"seq_lens_cpu"
:
[
5
],
"num_computed_tokens_cpu"
:
[
2
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
104
],
"max_seq_len"
:
5
,
},
"expected"
:
{
"prev_token_ids"
:
[
10
,
11
,
12
,
13
,
14
],
"prev_positions"
:
[
0
,
1
,
2
,
3
,
4
],
"token_indices_to_sample"
:
[
2
],
"seq_lens"
:
[
5
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
104
],
"cached"
:
[
{
"len"
:
3
,
"token_ids"
:
[
10
,
11
,
12
],
"positions"
:
[
0
,
1
,
2
],
"slot_mapping"
:
[
100
,
101
,
102
],
}
],
},
},
{
"name"
:
"shift_3_full_cache_window"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
}
],
"target_token_ids"
:
[
20
,
21
,
22
,
23
,
24
],
"target_positions"
:
[
0
,
3
,
4
,
5
,
6
],
"token_indices_to_sample"
:
[
1
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
5
],
"seq_lens"
:
[
5
],
"seq_lens_cpu"
:
[
5
],
"num_computed_tokens_cpu"
:
[
3
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
104
],
"max_seq_len"
:
5
,
},
"expected"
:
{
"prev_token_ids"
:
[
20
,
21
,
22
,
23
,
24
],
"prev_positions"
:
[
0
,
3
,
4
,
5
,
6
],
"token_indices_to_sample"
:
[
1
],
"seq_lens"
:
[
5
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
104
],
"cached"
:
[
{
"len"
:
2
,
"token_ids"
:
[
20
,
21
,
0
],
"positions"
:
[
0
,
3
,
0
],
"slot_mapping"
:
[
100
,
101
,
0
],
}
],
},
},
{
"name"
:
"batch2_shift_1_and_1"
,
"batch_size"
:
2
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
],
"target_token_ids"
:
[
10
,
11
,
12
,
13
,
20
,
21
,
22
],
"target_positions"
:
[
0
,
1
,
2
,
3
,
0
,
1
,
2
],
"token_indices_to_sample"
:
[
1
,
5
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
4
,
7
],
"seq_lens"
:
[
4
,
3
],
"seq_lens_cpu"
:
[
4
,
3
],
"num_computed_tokens_cpu"
:
[
1
,
1
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
200
,
201
,
202
],
"max_seq_len"
:
4
,
},
"expected"
:
{
"prev_token_ids"
:
[
10
,
11
,
12
,
13
,
20
,
21
,
22
],
"prev_positions"
:
[
0
,
1
,
2
,
3
,
0
,
1
,
2
],
"token_indices_to_sample"
:
[
1
,
5
],
"seq_lens"
:
[
4
,
3
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
200
,
201
,
202
],
"cached"
:
[
{
"len"
:
2
,
"token_ids"
:
[
10
,
11
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
100
,
101
,
0
],
},
{
"len"
:
2
,
"token_ids"
:
[
20
,
21
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
200
,
201
,
0
],
},
],
},
},
{
"name"
:
"batch4_mixed_shifts"
,
"batch_size"
:
4
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
{
"len"
:
1
,
"token_ids"
:
[
19
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
119
,
0
,
0
],
},
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
],
"target_token_ids"
:
[
10
,
11
,
20
,
21
,
22
,
30
,
31
,
32
,
33
,
40
,
41
,
42
],
"target_positions"
:
[
0
,
1
,
1
,
2
,
3
,
0
,
2
,
3
,
4
,
0
,
1
,
2
],
"token_indices_to_sample"
:
[
1
,
2
,
6
,
10
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
2
,
5
,
9
,
12
],
"seq_lens"
:
[
2
,
3
,
4
,
3
],
"seq_lens_cpu"
:
[
2
,
3
,
4
,
3
],
"num_computed_tokens_cpu"
:
[
0
,
1
,
2
,
1
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
104
,
105
,
106
,
107
,
108
,
109
,
110
,
111
,
],
"max_seq_len"
:
4
,
},
"expected"
:
{
"prev_token_ids"
:
[
10
,
11
,
19
,
20
,
21
,
30
,
31
,
32
,
33
,
40
,
41
,
42
],
"prev_positions"
:
[
0
,
1
,
0
,
1
,
2
,
0
,
2
,
3
,
4
,
0
,
1
,
2
],
"token_indices_to_sample"
:
[
1
,
3
,
6
,
10
],
"seq_lens"
:
[
2
,
2
,
4
,
3
],
"slot_mapping"
:
[
100
,
101
,
119
,
102
,
103
,
105
,
106
,
107
,
108
,
109
,
110
,
111
,
],
"cached"
:
[
{
"len"
:
2
,
"token_ids"
:
[
10
,
11
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
100
,
101
,
0
],
},
{
"len"
:
2
,
"token_ids"
:
[
19
,
20
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
119
,
102
,
0
],
},
{
"len"
:
2
,
"token_ids"
:
[
30
,
31
,
0
],
"positions"
:
[
0
,
2
,
0
],
"slot_mapping"
:
[
105
,
106
,
0
],
},
{
"len"
:
2
,
"token_ids"
:
[
40
,
41
,
0
],
"positions"
:
[
0
,
1
,
0
],
"slot_mapping"
:
[
109
,
110
,
0
],
},
],
},
},
{
"name"
:
"batch2_shift_0_and_2"
,
"batch_size"
:
2
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
],
"target_token_ids"
:
[
30
,
31
,
32
,
40
,
41
,
42
,
43
],
"target_positions"
:
[
0
,
1
,
2
,
0
,
3
,
4
,
5
],
"token_indices_to_sample"
:
[
2
,
4
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
3
,
7
],
"seq_lens"
:
[
3
,
4
],
"seq_lens_cpu"
:
[
3
,
4
],
"num_computed_tokens_cpu"
:
[
0
,
2
],
"slot_mapping"
:
[
100
,
101
,
102
,
200
,
201
,
202
,
203
],
"max_seq_len"
:
4
,
},
"expected"
:
{
"prev_token_ids"
:
[
30
,
31
,
32
,
40
,
41
,
42
,
43
],
"prev_positions"
:
[
0
,
1
,
2
,
0
,
3
,
4
,
5
],
"token_indices_to_sample"
:
[
2
,
4
],
"seq_lens"
:
[
3
,
4
],
"slot_mapping"
:
[
100
,
101
,
102
,
200
,
201
,
202
,
203
],
"cached"
:
[
{
"len"
:
3
,
"token_ids"
:
[
30
,
31
,
32
],
"positions"
:
[
0
,
1
,
2
],
"slot_mapping"
:
[
100
,
101
,
102
],
},
{
"len"
:
2
,
"token_ids"
:
[
40
,
41
,
0
],
"positions"
:
[
0
,
3
,
0
],
"slot_mapping"
:
[
200
,
201
,
0
],
},
],
},
},
{
"name"
:
"continue_req_shift_1_cache_tail_3"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
3
,
"token_ids"
:
[
70
,
71
,
72
],
"positions"
:
[
7
,
8
,
9
],
"slot_mapping"
:
[
170
,
171
,
172
],
}
],
"target_token_ids"
:
[
100
,
101
,
102
,
103
,
104
],
"target_positions"
:
[
10
,
11
,
12
,
13
,
14
],
"token_indices_to_sample"
:
[
3
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
5
],
"seq_lens"
:
[
5
],
"seq_lens_cpu"
:
[
5
],
"num_computed_tokens_cpu"
:
[
0
],
"slot_mapping"
:
[
200
,
201
,
202
,
203
,
204
],
"max_seq_len"
:
5
,
},
"expected"
:
{
"prev_token_ids"
:
[
72
,
100
,
101
,
102
,
103
],
"prev_positions"
:
[
9
,
10
,
11
,
12
,
13
],
"token_indices_to_sample"
:
[
4
],
"seq_lens"
:
[
4
],
"slot_mapping"
:
[
172
,
200
,
201
,
202
,
203
],
"cached"
:
[
{
"len"
:
3
,
"token_ids"
:
[
101
,
102
,
103
],
"positions"
:
[
11
,
12
,
13
],
"slot_mapping"
:
[
201
,
202
,
203
],
}
],
},
},
{
"name"
:
"continue_req_shift_3_cache_tail_3"
,
"batch_size"
:
1
,
"initial_cache"
:
[
{
"len"
:
3
,
"token_ids"
:
[
270
,
271
,
272
],
"positions"
:
[
27
,
28
,
29
],
"slot_mapping"
:
[
370
,
371
,
372
],
}
],
"target_token_ids"
:
[
300
,
301
,
302
,
303
,
304
,
305
,
306
],
"target_positions"
:
[
30
,
31
,
32
,
33
,
34
,
35
,
36
],
"token_indices_to_sample"
:
[
3
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
7
],
"seq_lens"
:
[
7
],
"seq_lens_cpu"
:
[
7
],
"num_computed_tokens_cpu"
:
[
0
],
"slot_mapping"
:
[
400
,
401
,
402
,
403
,
404
,
405
,
406
],
"max_seq_len"
:
7
,
},
"expected"
:
{
"prev_token_ids"
:
[
270
,
271
,
272
,
300
,
301
,
302
,
303
],
"prev_positions"
:
[
27
,
28
,
29
,
30
,
31
,
32
,
33
],
"token_indices_to_sample"
:
[
6
],
"seq_lens"
:
[
4
],
"slot_mapping"
:
[
370
,
371
,
372
,
400
,
401
,
402
,
403
],
"cached"
:
[
{
"len"
:
3
,
"token_ids"
:
[
301
,
302
,
303
],
"positions"
:
[
31
,
32
,
33
],
"slot_mapping"
:
[
401
,
402
,
403
],
}
],
},
},
{
"name"
:
"batch3_mixed_shifts_0_1_2_all_full_cache"
,
"batch_size"
:
3
,
"initial_cache"
:
[
{
"len"
:
0
,
"token_ids"
:
[
0
,
0
,
0
],
"positions"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
0
,
0
,
0
],
},
{
"len"
:
3
,
"token_ids"
:
[
70
,
71
,
72
],
"positions"
:
[
7
,
8
,
9
],
"slot_mapping"
:
[
170
,
171
,
172
],
},
{
"len"
:
3
,
"token_ids"
:
[
270
,
271
,
272
],
"positions"
:
[
17
,
18
,
19
],
"slot_mapping"
:
[
370
,
371
,
372
],
},
],
"target_token_ids"
:
[
10
,
11
,
12
,
13
,
100
,
101
,
102
,
103
,
104
,
200
,
201
,
202
,
203
,
204
,
205
,
],
"target_positions"
:
[
0
,
1
,
2
,
3
,
10
,
11
,
12
,
13
,
14
,
20
,
21
,
22
,
23
,
24
,
25
,
],
"token_indices_to_sample"
:
[
3
,
7
,
12
],
"common_attn_metadata"
:
{
"query_start_loc"
:
[
0
,
4
,
9
,
15
],
"seq_lens"
:
[
4
,
5
,
6
],
"seq_lens_cpu"
:
[
4
,
5
,
6
],
"num_computed_tokens_cpu"
:
[
0
,
0
,
0
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
200
,
201
,
202
,
203
,
204
,
300
,
301
,
302
,
303
,
304
,
305
,
],
"max_seq_len"
:
6
,
},
"expected"
:
{
"prev_token_ids"
:
[
10
,
11
,
12
,
13
,
72
,
100
,
101
,
102
,
103
,
271
,
272
,
200
,
201
,
202
,
203
,
],
"prev_positions"
:
[
0
,
1
,
2
,
3
,
9
,
10
,
11
,
12
,
13
,
18
,
19
,
20
,
21
,
22
,
23
,
],
"token_indices_to_sample"
:
[
3
,
8
,
14
],
"seq_lens"
:
[
4
,
4
,
4
],
"slot_mapping"
:
[
100
,
101
,
102
,
103
,
172
,
200
,
201
,
202
,
203
,
371
,
372
,
300
,
301
,
302
,
303
,
],
"cached"
:
[
{
"len"
:
3
,
"token_ids"
:
[
11
,
12
,
13
],
"positions"
:
[
1
,
2
,
3
],
"slot_mapping"
:
[
101
,
102
,
103
],
},
{
"len"
:
3
,
"token_ids"
:
[
101
,
102
,
103
],
"positions"
:
[
11
,
12
,
13
],
"slot_mapping"
:
[
201
,
202
,
203
],
},
{
"len"
:
3
,
"token_ids"
:
[
201
,
202
,
203
],
"positions"
:
[
21
,
22
,
23
],
"slot_mapping"
:
[
301
,
302
,
303
],
},
],
},
},
]
def _run_adjust_input_case(proposer_stub, case, layer_num):
    """Run one table-driven ``adjust_input`` case and check every expectation.

    The case dict supplies the initial per-request cache, the flattened
    target tokens/positions, the indices to sample and the attention
    metadata; its ``"expected"`` entry holds the values that must be
    observed after ``proposer_stub.adjust_input`` has run.

    Args:
        proposer_stub: Proposer object under test; its ``layer_num``
            attribute is overwritten with ``layer_num`` before the call.
        case: One case dict (e.g. an element of ``LAYER3_CASES``).
        layer_num: Number of layers to configure on the stub; it is also
            used as the cache window width (``max_shift``).

    Note:
        All device tensors are created on ``"cuda"``, so a CUDA device is
        required. The third and fourth return values of ``adjust_input``
        are ignored; the post-call checks read ``token_indices_to_sample``
        and ``common_attn_metadata`` exactly as they were passed in, so the
        cases presumably rely on ``adjust_input`` updating them in place
        (TODO confirm against the proposer implementation).
    """
    # Configure the stub first; the window width is read back from it.
    proposer_stub.layer_num = layer_num
    max_shift = proposer_stub.layer_num
    gpu = torch.device("cuda")

    def on_gpu(values, dtype):
        # Build a tensor on the CUDA device from a plain Python list.
        return torch.tensor(values, dtype=dtype, device=gpu)

    def on_cpu(values, dtype):
        # Build a host-side tensor from a plain Python list.
        return torch.tensor(values, dtype=dtype, device="cpu")

    initial_cache = case["initial_cache"]
    batch_size = case["batch_size"]
    assert len(initial_cache) == batch_size

    # Attention metadata: device tensors plus their CPU mirrors.
    attn_spec = case["common_attn_metadata"]
    query_start_loc_cpu = on_cpu(attn_spec["query_start_loc"], torch.int32)
    common_attn_metadata = SimpleNamespace(
        query_start_loc=query_start_loc_cpu.to(device=gpu),
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=on_gpu(attn_spec["seq_lens"], torch.int32),
        seq_lens_cpu=on_cpu(attn_spec["seq_lens_cpu"], torch.int32),
        num_computed_tokens_cpu=on_cpu(
            attn_spec["num_computed_tokens_cpu"], torch.int32
        ),
        slot_mapping=on_gpu(attn_spec["slot_mapping"], torch.int64),
        max_seq_len=attn_spec["max_seq_len"],
    )

    # Flattened per-token inputs for the whole batch.
    target_token_ids = on_gpu(case["target_token_ids"], torch.int32)
    target_positions = on_gpu(case["target_positions"], torch.int64)
    num_tokens = target_token_ids.numel()
    # Distinct value per element so every hidden-state row is unique.
    target_hidden_states = torch.arange(
        0,
        num_tokens * HIDDEN_SIZE,
        dtype=torch.float32,
        device=gpu,
    ).reshape(-1, HIDDEN_SIZE)
    token_indices_to_sample = on_gpu(
        case["token_indices_to_sample"], torch.int32
    )

    multi_layer_eagle_metadata = _make_multi_layer_eagle_metadata(
        initial_cache=initial_cache,
        max_shift=max_shift,
        device=gpu,
    )

    prev_token_ids, prev_positions, _, _ = proposer_stub.adjust_input(
        batch_size=batch_size,
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        token_indices_to_sample=token_indices_to_sample,
        common_attn_metadata=common_attn_metadata,
        multi_layer_eagle_metadata=multi_layer_eagle_metadata,
    )

    expected = case["expected"]
    expected_rows = expected["cached"]
    assert len(expected_rows) == batch_size

    # Returned tensors, then the inputs that were handed to adjust_input.
    assert prev_token_ids.cpu().tolist() == expected["prev_token_ids"]
    assert prev_positions.cpu().tolist() == expected["prev_positions"]
    assert (
        token_indices_to_sample.cpu().tolist()
        == expected["token_indices_to_sample"]
    )
    assert common_attn_metadata.seq_lens.cpu().tolist() == expected["seq_lens"]
    assert (
        common_attn_metadata.slot_mapping.cpu().tolist()
        == expected["slot_mapping"]
    )

    # (key in the expected row, attribute on the eagle metadata object)
    cache_fields = (
        ("token_ids", "cached_token_ids"),
        ("positions", "cached_positions"),
        ("slot_mapping", "cached_slot_mappings"),
    )

    for row, cached_expect in enumerate(expected_rows):
        # Sanity-check the case table itself before comparing tensors.
        assert cached_expect["len"] <= max_shift
        assert all(
            len(cached_expect[key]) == max_shift for key, _ in cache_fields
        )

        cache_len = int(cached_expect["len"])
        assert (
            int(multi_layer_eagle_metadata.cached_len[row].item()) == cache_len
        )

        # Entries past the valid length must be zero padding in the table.
        for key, _ in cache_fields:
            assert all(v == 0 for v in cached_expect[key][cache_len:])

        # Finally compare the cache rows written by adjust_input.
        for key, attr in cache_fields:
            actual = getattr(multi_layer_eagle_metadata, attr)[row]
            assert actual.cpu().tolist() == cached_expect[key]
_LAYER3_CASE_IDS = [entry["name"] for entry in LAYER3_CASES]


@pytest.mark.parametrize("case", LAYER3_CASES, ids=_LAYER3_CASE_IDS)
def test_adjust_input_layer3_cases(proposer_stub, case):
    """Run every ``LAYER3_CASES`` entry through the shared runner with 3 layers.

    Each parametrized test is labelled with the case's ``"name"`` value.
    """
    _run_adjust_input_case(proposer_stub, case, layer_num=3)
vllm/config/speculative.py
View file @
84ac1d27
...
...
@@ -80,6 +80,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
enable_multi_layers_mtp
:
bool
=
False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size
:
int
|
None
=
Field
(
default
=
None
,
ge
=
1
)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
...
...
@@ -493,7 +497,10 @@ class SpeculativeConfig:
MTPModelTypes
):
self
.
method
=
"mtp"
if
self
.
num_speculative_tokens
>
1
:
if
(
self
.
enable_multi_layers_mtp
is
False
and
self
.
num_speculative_tokens
>
1
):
logger
.
warning
(
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
...
...
vllm/entrypoints/anthropic/serving.py
View file @
84ac1d27
...
...
@@ -166,7 +166,72 @@ class AnthropicServingMessages(OpenAIServingChat):
if
isinstance
(
msg
.
content
,
str
):
openai_msg
[
"content"
]
=
msg
.
content
else
:
cls
.
_convert_message_content
(
msg
,
openai_msg
,
openai_messages
)
# Handle complex content blocks
content_parts
:
list
[
dict
[
str
,
Any
]]
=
[]
tool_calls
:
list
[
dict
[
str
,
Any
]]
=
[]
reasoning_parts
:
list
[
str
]
=
[]
for
block
in
msg
.
content
:
if
block
.
type
==
"text"
and
block
.
text
:
content_parts
.
append
({
"type"
:
"text"
,
"text"
:
block
.
text
})
elif
block
.
type
==
"image"
and
block
.
source
:
content_parts
.
append
(
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
block
.
source
.
get
(
"data"
,
""
)},
}
)
elif
block
.
type
==
"thinking"
and
block
.
thinking
is
not
None
:
reasoning_parts
.
append
(
block
.
thinking
)
elif
block
.
type
==
"tool_use"
:
# Convert tool use to function call format
tool_call
=
{
"id"
:
block
.
id
or
f
"call_
{
int
(
time
.
time
())
}
"
,
"type"
:
"function"
,
"function"
:
{
"name"
:
block
.
name
or
""
,
"arguments"
:
json
.
dumps
(
block
.
input
or
{}),
},
}
tool_calls
.
append
(
tool_call
)
elif
block
.
type
==
"tool_result"
:
if
msg
.
role
==
"user"
:
openai_messages
.
append
(
{
"role"
:
"tool"
,
"tool_call_id"
:
block
.
id
or
""
,
"content"
:
str
(
block
.
content
)
if
block
.
content
else
""
,
}
)
else
:
# Assistant tool result becomes regular text
tool_result_text
=
(
str
(
block
.
content
)
if
block
.
content
else
""
)
content_parts
.
append
(
{
"type"
:
"text"
,
"text"
:
f
"Tool result:
{
tool_result_text
}
"
,
}
)
if
reasoning_parts
:
openai_msg
[
"reasoning"
]
=
""
.
join
(
reasoning_parts
)
# Add tool calls to the message if any
if
tool_calls
:
openai_msg
[
"tool_calls"
]
=
tool_calls
# type: ignore
# Add content parts if any
if
content_parts
:
if
len
(
content_parts
)
==
1
and
content_parts
[
0
][
"type"
]
==
"text"
:
openai_msg
[
"content"
]
=
content_parts
[
0
][
"text"
]
else
:
openai_msg
[
"content"
]
=
content_parts
# type: ignore
elif
not
tool_calls
and
not
reasoning_parts
:
continue
openai_messages
.
append
(
openai_msg
)
...
...
@@ -522,49 +587,75 @@ class AnthropicServingMessages(OpenAIServingChat):
first_item
=
True
finish_reason
=
None
state
=
_ActiveBlockState
()
content_block_index
=
0
active_block_type
:
str
|
None
=
None
active_block_index
:
int
|
None
=
None
active_block_signature
:
str
|
None
=
None
signature_emitted
=
False
active_tool_use_id
:
str
|
None
=
None
# Map from tool call index to tool_use_id
tool_index_to_id
:
dict
[
int
,
str
]
=
{}
def
stop_active_block
():
nonlocal
active_block_type
,
active_block_index
,
content_block_index
nonlocal
active_block_signature
,
signature_emitted
,
active_tool_use_id
events
:
list
[
str
]
=
[]
if
state
.
block_type
is
None
:
if
active_
block_type
is
None
:
return
events
if
(
state
.
block_type
==
"thinking"
and
state
.
block_signature
is
not
None
and
not
state
.
signature_emitted
active_
block_type
==
"thinking"
and
active_
block_signature
is
not
None
and
not
signature_emitted
):
chunk
=
AnthropicStreamEvent
(
index
=
state
.
block_index
,
index
=
active_
block_index
,
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
type
=
"signature_delta"
,
signature
=
state
.
block_signature
,
signature
=
active_
block_signature
,
),
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
events
.
append
(
wrap_data_with_event
(
data
,
"content_block_delta"
))
state
.
signature_emitted
=
True
signature_emitted
=
True
stop_chunk
=
AnthropicStreamEvent
(
index
=
state
.
block_index
,
index
=
active_
block_index
,
type
=
"content_block_stop"
,
)
data
=
stop_chunk
.
model_dump_json
(
exclude_unset
=
True
)
events
.
append
(
wrap_data_with_event
(
data
,
"content_block_stop"
))
state
.
reset
()
state
.
content_block_index
+=
1
active_block_type
=
None
active_block_index
=
None
active_block_signature
=
None
signature_emitted
=
False
active_tool_use_id
=
None
content_block_index
+=
1
return
events
def
start_block
(
block
:
AnthropicContentBlock
):
nonlocal
active_block_type
,
active_block_index
,
content_block_index
nonlocal
active_block_signature
,
signature_emitted
,
active_tool_use_id
chunk
=
AnthropicStreamEvent
(
index
=
state
.
content_block_index
,
index
=
content_block_index
,
type
=
"content_block_start"
,
content_block
=
block
,
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
event
=
wrap_data_with_event
(
data
,
"content_block_start"
)
state
.
start
(
block
)
active_block_type
=
block
.
type
active_block_index
=
content_block_index
if
block
.
type
==
"thinking"
:
active_block_signature
=
uuid
.
uuid4
().
hex
signature_emitted
=
False
active_tool_use_id
=
None
elif
block
.
type
==
"tool_use"
:
active_block_signature
=
None
signature_emitted
=
True
active_tool_use_id
=
block
.
id
else
:
active_block_signature
=
None
signature_emitted
=
True
active_tool_use_id
=
None
return
event
async
for
item
in
generator
:
...
...
@@ -638,7 +729,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if
reasoning_delta
==
""
:
pass
else
:
if
state
.
block_type
!=
"thinking"
:
if
active_
block_type
!=
"thinking"
:
for
event
in
stop_active_block
():
yield
event
start_event
=
start_block
(
...
...
@@ -649,9 +740,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield
start_event
chunk
=
AnthropicStreamEvent
(
index
=
(
state
.
block_index
if
state
.
block_index
is
not
None
else
state
.
content_block_index
active_
block_index
if
active_
block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
...
...
@@ -666,7 +757,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if
origin_chunk
.
choices
[
0
].
delta
.
content
==
""
:
pass
else
:
if
state
.
block_type
!=
"text"
:
if
active_
block_type
!=
"text"
:
for
event
in
stop_active_block
():
yield
event
start_event
=
start_block
(
...
...
@@ -675,9 +766,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield
start_event
chunk
=
AnthropicStreamEvent
(
index
=
(
state
.
block_index
if
state
.
block_index
is
not
None
else
state
.
content_block_index
active_
block_index
if
active_
block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
...
...
@@ -702,7 +793,7 @@ class AnthropicServingMessages(OpenAIServingChat):
else
None
)
if
(
state
.
tool_use_id
!=
tool_call
.
id
active_
tool_use_id
!=
tool_call
.
id
and
tool_name
is
not
None
):
for
event
in
stop_active_block
():
...
...
@@ -720,13 +811,13 @@ class AnthropicServingMessages(OpenAIServingChat):
if
(
tool_call
.
function
and
tool_call
.
function
.
arguments
and
state
.
tool_use_id
==
tool_call
.
id
and
active_
tool_use_id
==
tool_call
.
id
):
chunk
=
AnthropicStreamEvent
(
index
=
(
state
.
block_index
if
state
.
block_index
is
not
None
else
state
.
content_block_index
active_
block_index
if
active_
block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
...
...
@@ -745,13 +836,13 @@ class AnthropicServingMessages(OpenAIServingChat):
tool_use_id
is
not
None
and
tool_call
.
function
and
tool_call
.
function
.
arguments
and
state
.
tool_use_id
==
tool_use_id
and
active_
tool_use_id
==
tool_use_id
):
chunk
=
AnthropicStreamEvent
(
index
=
(
state
.
block_index
if
state
.
block_index
is
not
None
else
state
.
content_block_index
active_
block_index
if
active_
block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
...
...
vllm/entrypoints/openai/chat_completion/serving.py
View file @
84ac1d27
...
...
@@ -1101,10 +1101,10 @@ class OpenAIServingChat(OpenAIServing):
index
=
0
if
(
self
.
_should_check_for_unstreamed_tool_arg_tokens
(
delta_message
,
output
tool_parser
and
self
.
_should_check_for_unstreamed_tool_arg_tokens
(
delta_message
,
output
,
tool_parser
)
and
tool_parser
):
latest_delta_len
=
0
if
(
...
...
@@ -1760,6 +1760,7 @@ class OpenAIServingChat(OpenAIServing):
self
,
delta_message
:
DeltaMessage
|
None
,
output
:
CompletionOutput
,
tool_parser
:
ToolParser
|
None
=
None
,
)
->
bool
:
"""
Check to see if we should check for unstreamed tool arguments tokens.
...
...
@@ -1772,6 +1773,8 @@ class OpenAIServingChat(OpenAIServing):
# include a function that has arguments
output
.
finish_reason
is
not
None
and
self
.
enable_auto_tools
and
tool_parser
is
not
None
and
tool_parser
.
parser_should_check_for_unstreamed_tool_arg_tokens
()
and
self
.
tool_parser
and
delta_message
and
delta_message
.
tool_calls
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
84ac1d27
...
...
@@ -262,6 +262,7 @@ def select_fp8_moe_backend(
supported
,
reason
=
k_cls
.
is_supported_config
(
k_cls
,
config
,
weight_key
,
activation_key
,
activation_format
)
supported
=
True
if
supported
:
logger
.
info_once
(
_make_log_backend
(
backend
),
scope
=
"local"
)
return
backend
,
k_cls
...
...
vllm/model_executor/models/step3p5_mtp.py
View file @
84ac1d27
...
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
...
...
@@ -40,9 +41,11 @@ class SharedHead(nn.Module):
return
self
.
norm
(
hidden_states
)
@
support_torch_compile
class
Step3p5AMultiTokenPredictorLayer
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
,
)
->
None
:
...
...
@@ -52,7 +55,7 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self
.
enorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
config
.
rms_norm_eps
)
self
.
hnorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
config
.
rms_norm_eps
)
self
.
eh_proj
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
config
.
hidden_size
,
bias
=
False
)
self
.
shared
_head
=
SharedHead
(
config
=
config
,
quant_config
=
quant_config
)
self
.
lm
_head
=
SharedHead
(
config
=
config
,
quant_config
=
quant_config
)
self
.
mtp_block
=
Step3p5DecoderLayer
(
vllm_config
,
prefix
=
f
"
{
prefix
}
.mtp_block"
,
...
...
@@ -64,9 +67,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
embed_tokens
:
VocabParallelEmbedding
|
None
=
None
,
spec_step_index
:
int
=
0
,
)
->
torch
.
Tensor
:
assert
inputs_embeds
is
not
None
if
inputs_embeds
is
None
:
assert
embed_tokens
is
not
None
inputs_embeds
=
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
...
...
@@ -92,8 +98,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self
.
layers
=
torch
.
nn
.
ModuleDict
(
{
str
(
idx
):
Step3p5AMultiTokenPredictorLayer
(
vllm_config
,
f
"
{
prefix
}
.layers.
{
idx
}
"
,
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
idx
}
"
,
)
for
idx
in
range
(
self
.
mtp_start_layer_idx
,
...
...
@@ -112,14 +118,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
positions
,
previous_hidden_states
,
inputs_embeds
,
self
.
embed_tokens
,
current_step_idx
,
)
...
...
@@ -131,7 +136,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
mtp_layer
=
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)]
logits
=
self
.
logits_processor
(
mtp_layer
.
shared
_head
.
head
,
mtp_layer
.
shared
_head
(
hidden_states
)
mtp_layer
.
lm
_head
.
head
,
mtp_layer
.
lm
_head
(
hidden_states
)
)
return
logits
...
...
@@ -257,6 +262,7 @@ class Step3p5MTP(nn.Module):
name
=
name
.
replace
(
".transformer."
,
"."
)
if
"shared_head"
in
name
:
name
=
name
.
replace
(
"shared_head.output"
,
"shared_head.head"
)
name
=
name
.
replace
(
"shared_head"
,
"lm_head"
)
if
"embed_tokens"
in
name
:
assert
(
hasattr
(
self
.
config
,
"num_nextn_predict_layers"
)
...
...
vllm/tool_parsers/abstract_tool_parser.py
View file @
84ac1d27
...
...
@@ -118,6 +118,12 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
    """
    Tell the serving code whether it should check for tool-argument
    tokens that have not been streamed yet.

    Returns:
        Always ``True`` in this implementation.
    """
    return True
class
ToolParserManager
:
"""
...
...
vllm/tool_parsers/step3p5_tool_parser.py
View file @
84ac1d27
...
...
@@ -2,13 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
import
json
import
uuid
from
collections.abc
import
Sequence
from
typing
import
Any
from
xml.parsers.expat
import
ParserCreate
import
regex
as
re
from
vllm.entrypoints.chat_utils
import
make_tool_call_id
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
ChatCompletionToolsParam
,
...
...
@@ -23,1500 +22,1144 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from
vllm.logger
import
init_logger
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers.abstract_tool_parser
import
ToolParser
from
vllm.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
)
logger
=
init_logger
(
__name__
)
class
StreamingXMLToolCallParser
:
"""
Simplified streaming XML tool call parser
Supports streaming input, parsing, and output
"""
class
Step3p5ToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
TokenizerLike
):
super
().
__init__
(
tokenizer
)
def
__init__
(
self
):
self
.
reset_streaming_state
()
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
# Override base class type - we use string IDs for tool calls
self
.
current_tool_id
:
str
|
None
=
None
# type: ignore
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
# Tool configuration information
self
.
tools
:
list
[
ChatCompletionToolsParam
]
|
None
=
None
# Sentinel tokens for streaming mode
self
.
tool_call_start_token
:
str
=
"<tool_call>"
self
.
tool_call_end_token
:
str
=
"</tool_call>"
self
.
function_start_token
:
str
=
"<function="
self
.
tool_call_prefix
:
str
=
"<function="
self
.
function_end_token
:
str
=
"</function>"
self
.
parameter_
start_token
:
str
=
"<parameter="
self
.
parameter_
prefix
:
str
=
"<parameter="
self
.
parameter_end_token
:
str
=
"</parameter>"
self
.
is_tool_call_started
:
bool
=
False
self
.
failed_count
:
int
=
0
def reset_streaming_state(self):
    """Return every streaming-parse attribute to its initial value.

    Also replaces ``self.parser`` with a new ``ParserCreate()`` instance and
    calls ``self.setup_parser()`` afterwards.
    """
    # Fresh containers: these must be new objects on every reset.
    self.deltas = []
    self.parameters = {}

    # Counters / positions used while streaming.
    for counter_attr in ("tool_call_index", "last_processed_pos"):
        setattr(self, counter_attr, 0)

    # Identifiers and names that are unset until a tag is seen.
    for unset_attr in (
        "current_call_id",
        "last_completed_call_id",
        "current_function_name",
        "current_param_name",
        "_pre_current_param_name",
    ):
        setattr(self, unset_attr, None)

    # Boolean flags (streaming, preprocessing and deferred parsing).
    for flag_attr in (
        "current_function_open",
        "current_param_is_first",
        "should_emit_end_newline",
        "start_quote_emitted",
        "_pre_inside_parameter",
        "defer_current_parameter",
    ):
        setattr(self, flag_attr, False)

    # Text accumulators (streaming, preprocessing and deferred parsing).
    for buffer_attr in (
        "current_param_value",
        "current_param_value_converted",
        "streaming_buffer",
        "text_content_buffer",
        "_pre_param_buffer",
        "deferred_param_raw_value",
    ):
        setattr(self, buffer_attr, "")

    # recreate parser
    self.parser = ParserCreate()
    self.setup_parser()
def
parse_single_streaming_chunks
(
self
,
xml_chunk
:
str
)
->
DeltaMessage
:
"""
Parse single streaming XML chunk and return Delta response
This is the actual streaming interface that receives chunks
one by one and maintains internal state
# Enhanced streaming state - reset for each new message
self
.
_reset_streaming_state
()
Args:
xml_chunk: Single XML chunk string
Returns:
DeltaMessage: Contains delta information generated by this chunk,
returns empty response if no
comple
te elements
"""
# Record delta count before processing
initial_delta_count
=
len
(
self
.
deltas
)
entry_call_id
=
self
.
current_call_id
entry_tool_call_index
=
self
.
tool_call_index
# Regex patterns
self
.
tool_call_complete_regex
=
re
.
compile
(
r
"<tool_call>(.*?)</tool_call>"
,
re
.
DOTALL
)
self
.
tool_call_function_regex
=
re
.
comp
i
le
(
r
"<function(?:=|\s+)?(.*?)</function>"
,
re
.
DOTALL
)
self
.
tool_call_parameter_regex
=
re
.
compile
(
r
"<parameter=(.*?)</parameter>"
,
re
.
DOTALL
)
self
.
streaming_buffer
+=
xml_chunk
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
found_elements
=
self
.
_process_complete_xml_elements
()
self
.
tool_call_start_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_start_token
)
self
.
tool_call_end_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_end_token
)
fallback_call_id
=
None
if
entry_call_id
is
not
None
:
if
(
self
.
current_call_id
==
entry_call_id
and
self
.
tool_call_index
==
entry_tool_call_index
if
self
.
tool_call_start_token_id
is
None
or
self
.
tool_call_end_token_id
is
None
:
raise
RuntimeError
(
"Step3p5 RL Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
# Get EOS token ID for EOS detection
self
.
eos_token_id
=
getattr
(
self
.
model_tokenizer
,
"eos_token_id"
,
None
)
logger
.
info
(
"vLLM Successfully import tool parser %s !"
,
self
.
__class__
.
__name__
)
def
_generate_tool_call_id
(
self
)
->
str
:
"""Generate a unique tool call ID."""
return
f
"call_
{
uuid
.
uuid4
().
hex
[:
24
]
}
"
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
    """Tell serving to skip its remaining_call calculation.

    Returns:
        ``False`` unconditionally.
    """
    check_unstreamed_args = False
    return check_unstreamed_args
def
_reset_streaming_state
(
self
):
"""Reset all streaming state for a new request."""
self
.
_processed_length
:
int
=
0
# Position of last processed character
self
.
_tool_call_index
:
int
=
0
# Number of tool calls processed so far
self
.
streaming_request
=
None
# Current request being processed
def
_get_arguments_config
(
self
,
func_name
:
str
,
tools
:
list
[
ChatCompletionToolsParam
]
|
None
)
->
dict
:
"""Extract argument configuration for a function."""
if
tools
is
None
:
return
{}
for
config
in
tools
:
if
not
hasattr
(
config
,
"type"
)
or
not
(
hasattr
(
config
,
"function"
)
and
hasattr
(
config
.
function
,
"name"
)
):
fallback_call_id
=
entry_call_id
continue
if
config
.
type
==
"function"
and
config
.
function
.
name
==
func_name
:
if
not
hasattr
(
config
.
function
,
"parameters"
):
return
{}
params
=
config
.
function
.
parameters
if
isinstance
(
params
,
dict
)
and
"properties"
in
params
:
return
params
[
"properties"
]
elif
isinstance
(
params
,
dict
):
return
params
else
:
return
{}
logger
.
warning
(
"Tool '%s' is not defined in the tools list."
,
func_name
)
return
{}
def
_convert_param_value
(
self
,
param_value
:
str
,
param_name
:
str
,
param_config
:
dict
,
func_name
:
str
)
->
Any
:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if
param_value
.
lower
()
==
"null"
:
return
None
if
param_name
not
in
param_config
:
if
param_config
!=
{}:
logger
.
warning
(
"Parsed parameter '%s' is not defined in the tool "
"parameters for tool '%s', directly returning the "
"string value."
,
param_name
,
func_name
,
)
return
param_value
if
(
isinstance
(
param_config
[
param_name
],
dict
)
and
"type"
in
param_config
[
param_name
]
):
param_type
=
str
(
param_config
[
param_name
][
"type"
]).
strip
().
lower
()
else
:
param_type
=
"string"
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
return
param_value
elif
(
self
.
current_call_id
is
not
None
and
self
.
tool_call_index
==
entry_tool_call_index
+
1
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)
):
fallback_call_id
=
self
.
current_call_id
if
found_elements
:
# If complete elements found, check if end events were missed
# some tags may not have been triggered
try
:
new_deltas
=
self
.
deltas
[
initial_delta_count
:]
# If this chunk contains </function>
# but didn't generate '}', then complete it
if
(
fallback_call_id
is
not
None
and
self
.
function_end_token
in
xml_chunk
):
# - Added '}' (non-empty parameter ending)
# - Added '{}' (empty parameter function)
has_function_close
=
any
(
(
td
.
tool_calls
and
any
(
(
tc
.
function
and
tc
.
id
==
fallback_call_id
and
isinstance
(
tc
.
function
.
arguments
,
str
)
and
(
tc
.
function
.
arguments
in
(
"}"
,
"{}"
))
)
for
tc
in
td
.
tool_calls
)
return
int
(
param_value
)
except
(
ValueError
,
TypeError
):
try
:
float_value
=
float
(
param_value
)
if
float_value
.
is_integer
():
return
int
(
float_value
)
except
(
ValueError
,
TypeError
):
pass
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
if
isinstance
(
literal_value
,
bool
):
return
int
(
literal_value
)
if
isinstance
(
literal_value
,
(
int
,
float
)):
return
(
int
(
literal_value
)
if
float
(
literal_value
).
is_integer
()
else
literal_value
)
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not an integer "
"in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
return
param_value
elif
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
):
try
:
float_param_value
=
float
(
param_value
)
return
(
float_param_value
if
float_param_value
-
int
(
float_param_value
)
!=
0
else
int
(
float_param_value
)
)
except
(
ValueError
,
TypeError
):
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
if
isinstance
(
literal_value
,
(
int
,
float
)):
return
(
float
(
literal_value
)
if
float
(
literal_value
)
-
int
(
float
(
literal_value
))
!=
0
else
int
(
float
(
literal_value
))
)
for
td
in
new_deltas
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
return
param_value
elif
param_type
in
[
"boolean"
,
"bool"
,
"binary"
]:
normalized_value
=
param_value
.
strip
().
lower
()
if
normalized_value
in
[
"true"
,
"false"
]:
return
normalized_value
==
"true"
if
normalized_value
in
[
"1"
,
"0"
]:
return
normalized_value
==
"1"
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
if
isinstance
(
literal_value
,
bool
):
return
literal_value
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a boolean "
"in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
return
param_value
else
:
if
(
param_type
in
[
"object"
,
"array"
,
"arr"
]
or
param_type
.
startswith
(
"dict"
)
or
param_type
.
startswith
(
"list"
)
):
try
:
param_value
=
json
.
loads
(
param_value
)
return
param_value
except
(
json
.
JSONDecodeError
,
TypeError
,
ValueError
):
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
if
isinstance
(
literal_value
,
(
list
,
dict
)):
return
literal_value
if
isinstance
(
literal_value
,
(
tuple
,
set
)):
return
list
(
literal_value
)
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' cannot be parsed "
"as JSON in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
if
not
has_function_close
:
# Close potentially unclosed element
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
if
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
# If this chunk contains </tool_call>
# but didn't generate final empty delta, then complete it
return
param_value
try
:
literal_value
=
ast
.
literal_eval
(
param_value
)
# safer
if
isinstance
(
literal_value
,
(
tuple
,
set
)):
return
list
(
literal_value
)
if
(
fallback_call_id
is
not
None
and
self
.
too
l_
c
al
l_end_token
in
xml_chunk
isinstance
(
literal_value
,
(
list
,
dict
,
str
,
int
,
float
,
bool
))
or
litera
l_
v
al
ue
is
None
):
has_toolcall_close
=
any
(
(
td
.
tool_calls
and
any
(
(
tc
.
type
==
"function"
and
tc
.
function
and
tc
.
function
.
arguments
==
""
and
tc
.
id
==
fallback_call_id
)
for
tc
in
td
.
tool_calls
)
)
for
td
in
new_deltas
)
if
not
has_toolcall_close
:
# Close potentially unclosed element
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
if
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
self
.
_end_element
(
"tool_call"
)
except
Exception
as
e
:
logger
.
warning
(
"Error with fallback parsing: %s"
,
e
)
# Merge newly generated deltas into single response
result_delta
=
self
.
_merge_new_deltas_to_single_response
(
initial_delta_count
return
literal_value
except
(
ValueError
,
SyntaxError
,
TypeError
):
pass
logger
.
warning
(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', returning raw string."
,
param_value
,
param_name
,
func_name
,
)
return
result_delta
return
param_value
def
_parse_parameters_fallback
(
self
,
parameters
:
str
,
allowed_param_names
:
set
[
str
]
|
None
=
None
,
)
->
list
[
tuple
[
str
,
str
]]:
"""Fallback parser for malformed parameter tags."""
param_pairs
:
list
[
tuple
[
str
,
str
]]
=
[]
pos
=
0
while
True
:
start
=
parameters
.
find
(
self
.
parameter_prefix
,
pos
)
if
start
==
-
1
:
break
name_start
=
start
+
len
(
self
.
parameter_prefix
)
name_end
=
parameters
.
find
(
">"
,
name_start
)
if
name_end
==
-
1
:
newline_idx
=
parameters
.
find
(
"
\n
"
,
name_start
)
end_tag
=
parameters
.
find
(
self
.
parameter_end_token
,
name_start
)
next_param
=
parameters
.
find
(
self
.
parameter_prefix
,
name_start
)
candidates
=
[
idx
for
idx
in
[
newline_idx
,
end_tag
,
next_param
]
if
idx
!=
-
1
]
if
not
candidates
:
break
name_end
=
min
(
candidates
)
value_start
=
name_end
else
:
value_start
=
name_end
+
1
param_name
=
parameters
[
name_start
:
name_end
].
strip
()
next_param
=
parameters
.
find
(
self
.
parameter_prefix
,
value_start
)
end_tag
=
parameters
.
find
(
self
.
parameter_end_token
,
value_start
)
if
end_tag
==
-
1
or
(
next_param
!=
-
1
and
next_param
<
end_tag
):
end
=
next_param
if
next_param
!=
-
1
else
len
(
parameters
)
pos
=
end
else
:
end
=
end_tag
pos
=
end
+
len
(
self
.
parameter_end_token
)
param_value
=
parameters
[
value_start
:
end
]
if
allowed_param_names
is
None
or
param_name
in
allowed_param_names
:
param_pairs
.
append
((
param_name
,
param_value
))
return
param_pairs
def
_is_valid_json_arguments
(
self
,
arguments
:
str
)
->
bool
:
"""Check if arguments can be loaded as JSON."""
try
:
json
.
loads
(
arguments
)
except
Exception
:
return
False
return
True
def
_parse_xml_function_call
(
self
,
function_call_str
:
str
,
tools
:
list
[
ChatCompletionToolsParam
]
|
None
)
->
ToolCall
|
None
:
# Extract function name
end_index
=
function_call_str
.
index
(
">"
)
# check empty function name
function_name
=
function_call_str
[:
end_index
].
strip
()
if
function_name
.
startswith
(
"="
):
function_name
=
function_name
.
lstrip
(
"="
).
strip
()
if
not
function_name
or
function_name
.
strip
(
"'
\"
"
)
==
""
:
logger
.
warning
(
"Empty function name in tool call."
)
return
None
if
function_name
[
0
]
in
"
\"
'"
and
function_name
[
-
1
]
==
function_name
[
0
]:
function_name
=
function_name
[
1
:
-
1
].
strip
()
if
not
function_name
:
logger
.
warning
(
"Empty function name in tool call."
)
return
None
param_config
=
self
.
_get_arguments_config
(
function_name
,
tools
)
parameters
=
function_call_str
[
end_index
+
1
:]
param_dict
=
{}
match_texts
=
self
.
tool_call_parameter_regex
.
findall
(
parameters
)
use_fallback
=
False
if
match_texts
:
for
match_text
in
match_texts
:
if
self
.
parameter_prefix
in
match_text
or
">"
not
in
match_text
:
use_fallback
=
True
break
else
:
# No complete elements, check if there's unoutput text content
if
self
.
text_content_buffer
and
self
.
tool_call_index
==
0
:
# Has text content but no tool_call yet, output text content
text_delta
=
DeltaMessage
(
content
=
self
.
text_content_buffer
)
self
.
_emit_delta
(
text_delta
)
# Clear buffer to avoid duplicate output
self
.
text_content_buffer
=
""
return
text_delta
# If this chunk contains end tags but wasn't triggered by parser,
# manually complete end events
# Only execute when still on the same call as when entered,
# to prevent accidentally closing new calls
# in multi <tool_call> scenarios
if
fallback_call_id
is
not
None
and
(
self
.
function_end_token
in
xml_chunk
or
self
.
tool_call_end_token
in
xml_chunk
):
# Close potentially unclosed element
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
if
self
.
function_end_token
in
xml_chunk
and
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
if
self
.
tool_call_end_token
in
xml_chunk
:
self
.
_end_element
(
"tool_call"
)
# Return the merged delta result generated by this fallback
result_delta
=
self
.
_merge_new_deltas_to_single_response
(
initial_delta_count
)
return
result_delta
use_fallback
=
self
.
parameter_prefix
in
parameters
# No complete elements, return empty response
return
DeltaMessage
(
content
=
None
)
if
use_fallback
:
allowed_param_names
=
(
set
(
param_config
.
keys
())
if
isinstance
(
param_config
,
dict
)
and
param_config
else
None
)
param_pairs
=
self
.
_parse_parameters_fallback
(
parameters
,
allowed_param_names
)
else
:
param_pairs
=
[]
for
match_text
in
match_texts
:
idx
=
match_text
.
index
(
">"
)
param_name
=
match_text
[:
idx
]
param_value
=
str
(
match_text
[
idx
+
1
:])
param_pairs
.
append
((
param_name
,
param_value
))
for
param_name
,
param_value
in
param_pairs
:
# Remove prefix and trailing \n
if
param_value
.
startswith
(
"
\n
"
):
param_value
=
param_value
[
1
:]
if
param_value
.
endswith
(
"
\n
"
):
param_value
=
param_value
[:
-
1
]
param_dict
[
param_name
]
=
self
.
_convert_param_value
(
param_value
,
param_name
,
param_config
,
function_name
)
def
_escape_xml_special_chars
(
self
,
text
:
str
)
->
str
:
"""
Escape XML special characters
Args:
text: Original text
Returns:
Escaped text
"""
xml_escapes
=
{
"&"
:
"&"
,
"<"
:
"<"
,
">"
:
">"
,
'"'
:
"""
,
"'"
:
"'"
,
}
try
:
arguments
=
json
.
dumps
(
param_dict
,
ensure_ascii
=
False
)
except
Exception
as
e
:
logger
.
warning
(
"Error in converting parameter value: %s"
,
e
)
return
None
return
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
function_name
,
arguments
=
arguments
),
)
for
char
,
escape
in
xml_escapes
.
items
():
text
=
text
.
replace
(
char
,
escape
)
def
_get_function_calls
(
self
,
model_output
:
str
)
->
list
[
str
]:
# Find all tool calls
raw_tool_calls
=
self
.
tool_call_complete_regex
.
findall
(
model_output
)
return
text
# if no closed tool_call tags found, return empty list
if
len
(
raw_tool_calls
)
==
0
:
return
[]
def
_process_complete_xml_elements
(
self
)
->
bool
:
"""
Process complete XML elements in buffer
raw_function_calls
=
[]
for
tool_call
in
raw_tool_calls
:
function_matches
=
self
.
tool_call_function_regex
.
findall
(
tool_call
)
raw_function_calls
.
extend
(
function_matches
)
Returns:
bool: Whether complete elements were found and processed
"""
found_any
=
False
return
raw_function_calls
while
self
.
last_processed_pos
<
len
(
self
.
streaming_buffer
):
# Find next complete xml element
element
,
end_pos
=
self
.
_find_next_complete_element
(
self
.
last_processed_pos
)
if
element
is
None
:
# No complete element found, wait for more data
break
def
_check_format
(
self
,
model_output
:
str
)
->
bool
:
"""Check if model output contains properly formatted tool call.
# Check if this element should be skipped
if
self
.
_should_skip_element
(
element
):
self
.
last_processed_pos
=
end_pos
continue
Requirements:
1. Must have closed tool_call tags (<tool_call>...</tool_call>)
2. Must have closed function tags (<function=...</function>)
3. If parameter tags exist, they must be closed and correct
# Found complete XML element, process it
try
:
preprocessed_element
=
self
.
_preprocess_xml_chunk
(
element
)
# Check if this is the first tool_call start
if
(
(
preprocessed_element
.
strip
().
startswith
(
"<tool_call>"
)
or
preprocessed_element
.
strip
().
startswith
(
"<function name="
)
)
and
self
.
tool_call_index
==
0
)
and
self
.
text_content_buffer
:
# First tool_call starts,
# output previously collected text content first
text_delta
=
DeltaMessage
(
content
=
self
.
text_content_buffer
)
self
.
_emit_delta
(
text_delta
)
# Clear buffer for potential subsequent text content
self
.
text_content_buffer
=
""
# If a new tool_call starts and
# there are already completed tool_calls with function name
if
(
preprocessed_element
.
strip
().
startswith
(
"<tool_call>"
)
and
self
.
tool_call_index
>
0
and
self
.
current_call_id
and
self
.
current_function_name
):
# Reset parser state but preserve generated deltas
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
if
self
.
current_function_open
:
self
.
_end_element
(
"function"
)
# Output final tool_call tail delta
final_delta
=
DeltaMessage
(
role
=
None
,
content
=
None
,
reasoning
=
None
,
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
""
),
)
],
)
self
.
_emit_delta
(
final_delta
)
# Reset XML parser and current call state
self
.
_reset_xml_parser_after_tool_call
()
# Parse preprocessed element
self
.
parser
.
Parse
(
preprocessed_element
,
False
)
found_any
=
True
Returns True if the format is valid, False otherwise.
"""
# Check 1: Must have closed tool_call tags
tool_call_matches
=
self
.
tool_call_complete_regex
.
findall
(
model_output
)
if
len
(
tool_call_matches
)
==
0
:
return
False
except
Exception
as
e
:
logger
.
warning
(
"Error when parsing XML elements: %s"
,
e
)
# Check 2: Must have closed function tags within tool_call
has_valid_function
=
False
for
tool_call_content
in
tool_call_matches
:
function_matches
=
self
.
tool_call_function_regex
.
findall
(
tool_call_content
)
if
len
(
function_matches
)
>
0
:
has_valid_function
=
True
# Check if there's an unclosed function tag
if
(
self
.
tool_call_prefix
in
tool_call_content
and
self
.
function_end_token
not
in
tool_call_content
):
return
False
# Update processed posi
tion
self
.
last_processed_pos
=
end_pos
if
not
has_valid_func
tion
:
return
False
return
found_any
# Check 3: If parameter tags exist, they must be closed and correct
for
tool_call_content
in
tool_call_matches
:
# Count opening and closing parameter tags
param_open_count
=
tool_call_content
.
count
(
self
.
parameter_prefix
)
param_close_count
=
tool_call_content
.
count
(
self
.
parameter_end_token
)
# If there are parameter tags, they must be balanced
if
param_open_count
>
0
:
if
param_open_count
!=
param_close_count
:
return
False
# Check if all parameter tags are properly closed using regex
param_matches
=
self
.
tool_call_parameter_regex
.
findall
(
tool_call_content
)
if
len
(
param_matches
)
!=
param_open_count
:
return
False
def
_fix_incomplete_tag_in_chunk
(
self
,
chunk
:
str
)
->
str
:
"""
Fallback: fix incomplete <parameter=xxx or <function=xxx tags
(missing >)
Examples: <parameter=-C: -> <parameter=-C>, <parameter=parameter=-n:
-> <parameter=-n>
Also handles missing = cases: <function xxx> -> <function=xxx>,
<functionxxx> -> <function=xxx>
Only fixes tags that pass validation (parameter exists in tool definition)
"""
# First, handle missing = cases for function tags
chunk
=
self
.
_fix_missing_equals_in_function_tag
(
chunk
)
return
True
for
tag_type
in
[
"parameter"
,
"function"
]:
pattern
=
f
"<
{
tag_type
}
="
if
pattern
not
in
chunk
:
def
_wrap_missing_tool_call_tags
(
self
,
model_output
:
str
)
->
str
:
"""Wrap bare <function=...></function> blocks with <tool_call> tags."""
if
(
self
.
tool_call_prefix
not
in
model_output
or
self
.
function_end_token
not
in
model_output
):
return
model_output
def
_wrap_bare_functions
(
text
:
str
)
->
str
:
pos
=
0
wrapped_parts
:
list
[
str
]
=
[]
while
True
:
func_idx
=
text
.
find
(
self
.
tool_call_prefix
,
pos
)
if
func_idx
==
-
1
:
wrapped_parts
.
append
(
text
[
pos
:])
break
end_idx
=
text
.
find
(
self
.
function_end_token
,
func_idx
)
if
end_idx
==
-
1
:
wrapped_parts
.
append
(
text
[
pos
:])
break
end_idx
+=
len
(
self
.
function_end_token
)
wrapped_parts
.
append
(
text
[
pos
:
func_idx
])
wrapped_parts
.
append
(
self
.
tool_call_start_token
)
wrapped_parts
.
append
(
text
[
func_idx
:
end_idx
])
wrapped_parts
.
append
(
self
.
tool_call_end_token
)
ws_idx
=
end_idx
while
ws_idx
<
len
(
text
)
and
text
[
ws_idx
].
isspace
():
ws_idx
+=
1
if
text
.
startswith
(
self
.
tool_call_end_token
,
ws_idx
):
if
ws_idx
>
end_idx
:
wrapped_parts
.
append
(
text
[
end_idx
:
ws_idx
])
pos
=
ws_idx
+
len
(
self
.
tool_call_end_token
)
else
:
pos
=
end_idx
return
""
.
join
(
wrapped_parts
)
tool_call_ranges
=
[
match
.
span
()
for
match
in
self
.
tool_call_complete_regex
.
finditer
(
model_output
)
]
if
not
tool_call_ranges
:
return
_wrap_bare_functions
(
model_output
)
wrapped_parts
:
list
[
str
]
=
[]
pos
=
0
for
start
,
end
in
tool_call_ranges
:
if
start
<
pos
:
continue
wrapped_parts
.
append
(
_wrap_bare_functions
(
model_output
[
pos
:
start
]))
wrapped_parts
.
append
(
model_output
[
start
:
end
])
pos
=
end
wrapped_parts
.
append
(
_wrap_bare_functions
(
model_output
[
pos
:]))
return
""
.
join
(
wrapped_parts
)
def
_normalize_prev_arguments
(
self
,
args_value
:
Any
)
->
Any
:
if
isinstance
(
args_value
,
str
):
try
:
return
json
.
loads
(
args_value
)
except
(
TypeError
,
ValueError
,
json
.
JSONDecodeError
):
return
args_value
return
args_value
def
_update_prev_tool_call_state
(
self
,
tool_calls
:
list
[
ToolCall
])
->
None
:
self
.
prev_tool_call_arr
.
clear
()
self
.
streamed_args_for_tool
.
clear
()
for
tool_call
in
tool_calls
:
if
not
tool_call
or
not
tool_call
.
function
:
continue
args_value
=
tool_call
.
function
.
arguments
if
isinstance
(
args_value
,
str
):
args_json
=
args_value
elif
args_value
is
None
:
args_json
=
""
else
:
try
:
args_json
=
json
.
dumps
(
args_value
,
ensure_ascii
=
False
)
except
(
TypeError
,
ValueError
):
args_json
=
str
(
args_value
)
prev_args
=
self
.
_normalize_prev_arguments
(
args_json
)
self
.
prev_tool_call_arr
.
append
(
{
"name"
:
tool_call
.
function
.
name
,
"arguments"
:
prev_args
,
}
)
try
:
expected_args_json
=
json
.
dumps
(
prev_args
,
ensure_ascii
=
False
)
except
(
TypeError
,
ValueError
):
expected_args_json
=
args_json
start_idx
=
chunk
.
find
(
pattern
)
after_tag
=
chunk
[
start_idx
:]
gt_pos
=
after_tag
.
find
(
">"
)
lt_pos
=
after_tag
.
find
(
"<"
,
len
(
pattern
))
# Serving may subtract the latest delta length from
# streamed_args_for_tool to detect unstreamed suffixes. Since this
# parser emits full arguments at once, store expected+actual so
# the subtraction yields expected_args_json and no resend occurs.
self
.
streamed_args_for_tool
.
append
(
expected_args_json
+
args_json
)
# Skip if already well-formed
if
(
gt_pos
!=
-
1
and
(
lt_pos
==
-
1
or
gt_pos
<
lt_pos
)
and
pattern
in
after_tag
[:
gt_pos
]
):
continue
def
extract_tool_calls
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
try
:
origin_model_output
=
model_output
try
:
# Fallback: handle outputs without <tool_call> wrapper.
origin_model_output
=
self
.
_wrap_missing_tool_call_tags
(
origin_model_output
)
model_output
=
origin_model_output
except
Exception
:
pass
# Use streaming-like approach: process position by position
valid_tool_calls
=
[]
content_parts
=
[]
processed_length
=
0
while
processed_length
<
len
(
model_output
):
# Find next tool call start
tool_start_idx
=
self
.
_find_tool_call_start
(
model_output
,
processed_length
)
# Extract tag name (stop at space, newline, or <)
content
=
chunk
[
start_idx
+
len
(
pattern
)
:]
end_pos
=
next
(
(
i
for
i
,
ch
in
enumerate
(
content
)
if
ch
in
(
" "
,
"
\n
"
,
"<"
)),
len
(
content
),
)
tag_name
=
content
[:
end_pos
]
# Case 1: No more tool calls - add remaining as content
if
tool_start_idx
==
-
1
:
remaining
=
model_output
[
processed_length
:]
if
remaining
:
content_parts
.
append
(
remaining
)
break
# Case 2: Content before tool call
if
tool_start_idx
>
processed_length
:
content_before
=
model_output
[
processed_length
:
tool_start_idx
]
# Skip whitespace-only content between tool calls
# Check if we just ended a tool call and this is pure whitespace
if
processed_length
>
0
:
text_before
=
model_output
[:
processed_length
]
if
(
text_before
.
rstrip
().
endswith
(
self
.
tool_call_end_token
)
and
content_before
.
strip
()
==
""
):
# Skip whitespace between tool calls
pass
else
:
content_parts
.
append
(
content_before
)
else
:
content_parts
.
append
(
content_before
)
if
not
tag_name
:
continue
# Case 3: Try to find complete tool call
tool_end_idx
=
self
.
_find_first_complete_tool_call_end
(
model_output
,
tool_start_idx
)
# Remove duplicate prefix: <parameter=parameter=xxx -> <parameter=xxx
if
tag_name
.
startswith
(
f
"
{
tag_type
}
="
):
tag_name
=
tag_name
[
len
(
tag_type
)
+
1
:]
# If tool call is incomplete - add remaining as content and stop
if
tool_end_idx
==
-
1
:
remaining
=
model_output
[
tool_start_idx
:]
if
remaining
:
content_parts
.
append
(
remaining
)
break
# Extract and try to parse the complete tool call
tool_call_text
=
model_output
[
tool_start_idx
:
tool_end_idx
]
parsed_result
=
self
.
extract_tool_calls_basic
(
tool_call_text
,
request
)
# If parsing succeeded, record the tool call(s)
if
parsed_result
.
tools_called
and
parsed_result
.
tool_calls
:
valid_tool_calls
.
extend
(
parsed_result
.
tool_calls
)
processed_length
=
tool_end_idx
else
:
# Parsing failed - treat this tool call as content
content_parts
.
append
(
tool_call_text
)
processed_length
=
tool_end_idx
# Remove trailing non-alphanumeric chars (keep - and _)
while
tag_name
and
not
(
tag_name
[
-
1
].
isalnum
()
or
tag_name
[
-
1
]
in
(
"-"
,
"_"
)
):
tag_name
=
tag_name
[:
-
1
]
# Populate prev_tool_call_arr for serving layer to set finish_reason
self
.
_update_prev_tool_call_state
(
valid_tool_calls
)
if
not
tag_name
:
continu
e
# Combine content parts
content
=
""
.
join
(
content_parts
)
if
content_parts
else
Non
e
# Validate parameter exists in tool definition
if
tag_type
==
"parameter"
and
not
self
.
_validate_parameter_name
(
tag_name
):
continue
return
ExtractedToolCallInformation
(
tools_called
=
(
len
(
valid_tool_calls
)
>
0
),
tool_calls
=
valid_tool_calls
,
content
=
content
if
content
else
None
,
)
# Apply fix
chunk
=
chunk
.
replace
(
f
"<
{
tag_type
}
=
{
content
[:
end_pos
]
}
"
,
f
"<
{
tag_type
}
=
{
tag_name
}
>"
,
1
except
Exception
:
logger
.
warning
(
"Error in extracting tool call from response."
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
return
chunk
def
extract_tool_calls_basic
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
model_output
=
self
.
_wrap_missing_tool_call_tags
(
model_output
)
# Quick check to avoid unnecessary processing
if
not
self
.
_check_format
(
model_output
):
tool_call_matches
=
self
.
tool_call_complete_regex
.
findall
(
model_output
)
if
len
(
tool_call_matches
)
==
0
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
_fix_missing_equals_in_function_tag
(
self
,
chunk
:
str
)
->
str
:
"""
Fix missing = in function tags: <function xxx> or <functionxxx>
Examples:
<function execute_bash> -> <function=execute_bash>
<functionexecute_bash> -> <function=execute_bash>
Only fixes if function name exists in tool definition
"""
# already correct
if
"<function="
in
chunk
:
return
chunk
# Pattern 1: <function xxx> (with space/newline but no =)
pattern1
=
r
"<function\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match1
=
re
.
search
(
pattern1
,
chunk
)
if
match1
:
func_name
=
match1
.
group
(
1
).
strip
()
# must validate function name exists before fixing
if
func_name
and
self
.
_validate_function_name
(
func_name
):
original
=
match1
.
group
(
0
)
fixed
=
f
"<function=
{
func_name
}
>"
chunk
=
chunk
.
replace
(
original
,
fixed
,
1
)
return
chunk
# Pattern 2: <functionxxx> (no space, no =)
# only match <function followed by letters
pattern2
=
r
"<function([a-zA-Z_][a-zA-Z0-9_]*)\s*>"
match2
=
re
.
search
(
pattern2
,
chunk
)
if
match2
:
func_name
=
match2
.
group
(
1
).
strip
()
# must validate function name exists before fixing
if
func_name
and
self
.
_validate_function_name
(
func_name
):
original
=
match2
.
group
(
0
)
fixed
=
f
"<function=
{
func_name
}
>"
chunk
=
chunk
.
replace
(
original
,
fixed
,
1
)
return
chunk
return
chunk
def
_validate_function_name
(
self
,
func_name
:
str
)
->
bool
:
"""Check if function name exists in tool definitions"""
if
not
self
.
tools
:
return
False
try
:
function_calls
=
self
.
_get_function_calls
(
model_output
)
if
len
(
function_calls
)
==
0
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
for
tool
in
self
.
tools
:
if
(
hasattr
(
tool
,
"type"
)
and
tool
.
type
==
"function"
and
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
and
tool
.
function
.
name
==
func_name
):
return
True
tool_calls
:
list
[
ToolCall
]
=
[]
for
function_call_str
in
function_calls
:
tool_call
=
self
.
_parse_xml_function_call
(
function_call_str
,
request
.
tools
)
if
tool_call
:
tool_calls
.
append
(
tool_call
)
if
not
tool_calls
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
for
tool_call
in
tool_calls
:
if
(
not
tool_call
.
function
or
tool_call
.
function
.
arguments
is
None
or
not
self
.
_is_valid_json_arguments
(
tool_call
.
function
.
arguments
)
):
logger
.
warning
(
"Invalid JSON arguments in tool call, falling back to content."
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
return
False
# Populate prev_tool_call_arr for serving layer to set finish_reason
self
.
_update_prev_tool_call_state
(
tool_calls
)
def
_validate_parameter_name
(
self
,
param_name
:
str
)
->
bool
:
"""Check if parameter exists in current function's tool definition"""
if
not
self
.
tools
or
not
self
.
current_function_name
:
return
True
# Extract content before tool calls
content_index
=
model_output
.
find
(
self
.
tool_call_start_token
)
content
=
model_output
[:
content_index
]
# .rstrip()
for
tool
in
self
.
tools
:
if
(
hasattr
(
tool
,
"type"
)
and
tool
.
type
==
"function"
and
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
and
tool
.
function
.
name
==
self
.
current_function_name
):
if
not
hasattr
(
tool
.
function
,
"parameters"
):
return
True
params
=
tool
.
function
.
parameters
if
isinstance
(
params
,
dict
):
properties
=
params
.
get
(
"properties"
,
params
)
return
param_name
in
properties
break
return
ExtractedToolCallInformation
(
tools_called
=
(
len
(
tool_calls
)
>
0
),
tool_calls
=
tool_calls
,
content
=
content
if
content
else
None
,
)
return
True
except
Exception
:
logger
.
warning
(
"Error in extracting tool call from response."
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
_should_skip_element
(
self
,
element
:
str
)
->
bool
:
"""
Determine whether an element should be skipped
def
_find_first_complete_tool_call_end
(
self
,
text
:
str
,
start_pos
:
int
=
0
)
->
int
:
"""Find the end position of the first complete tool call.
Args:
element: Element to evaluate
text: Text to search in
start_pos: Position to start searching from
Returns:
bool: True means should skip, False means should process
"""
Position after the first </tool_call> tag, or -1 if incomplete
# If it's a tool_call XML tag, don't skip
if
(
element
.
startswith
(
self
.
tool_call_start_token
)
or
element
.
startswith
(
self
.
function_start_token
)
or
element
.
startswith
(
self
.
parameter_start_token
)
):
return
False
# If currently not parsing tool calls and not blank,
# collect this text instead of skipping
# Only process other XML elements after tool_call appears,
# otherwise treat as plain text
if
self
.
current_call_id
is
None
and
element
:
# Collect text content to buffer
self
.
text_content_buffer
+=
element
return
True
# Still skip, but content has been collected
# If currently parsing tool calls,
# this might be parameter value, don't skip
if
self
.
current_call_id
is
not
None
:
return
False
Example:
"<tool_call>...</tool_call>..." returns position after </tool_call>
"""
# Find tool call start
start_idx
=
text
.
find
(
self
.
tool_call_start_token
,
start_pos
)
if
start_idx
==
-
1
:
return
-
1
# Find matching end token
end_idx
=
text
.
find
(
self
.
tool_call_end_token
,
start_idx
+
len
(
self
.
tool_call_start_token
)
)
if
end_idx
==
-
1
:
return
-
1
# Incomplete tool call
#
Skip blank cont
en
t
return
not
element
#
Return position after end tok
en
return
end_idx
+
len
(
self
.
tool_call_end_token
)
def
_find_next_complete_element
(
self
,
start_pos
:
int
)
->
tuple
[
str
|
None
,
int
]:
"""
Find next complete XML element from specified position
def
_find_tool_call_start
(
self
,
text
:
str
,
start_pos
:
int
=
0
)
->
int
:
"""Find the start position of next tool call.
Args:
start_pos: Position to start searching
text: Text to search in
start_pos: Position to start searching from
Returns:
(Complete element string, element end position),
returns (None, start_pos) if no complete element found
Position of <tool_call> token, or -1 if not found
"""
buffer
=
self
.
streaming_buffer
[
start_pos
:]
return
text
.
find
(
self
.
tool_call_start_token
,
start_pos
)
if
not
buffer
:
return
None
,
start_pos
def
_extract_content_between_tool_calls_list
(
self
,
text
:
str
)
->
list
[
str
]
:
"""Extract content segments after each tool call.
if
buffer
.
startswith
(
"<"
):
# Check if this is an incomplete parameter/function tag
# e.g., <parameter=-C: or <function=xxx
is_incomplete_param
=
(
buffer
.
startswith
(
"<parameter="
)
and
">"
not
in
buffer
.
split
(
"
\n
"
)[
0
]
)
is_incomplete_func
=
(
buffer
.
startswith
(
"<function="
)
and
">"
not
in
buffer
.
split
(
"
\n
"
)[
0
]
)
For n tool calls, returns n segments where segment[i] is the content
after tool_call[i] (before tool_call[i+1] or at the end).
if
is_incomplete_param
or
is_incomplete_func
:
# Find the corresponding closing tag
tag_type
=
"parameter"
if
is_incomplete_param
else
"function"
closing_tag
=
f
"</
{
tag_type
}
>"
closing_pos
=
buffer
.
find
(
closing_tag
)
if
closing_pos
!=
-
1
:
# Found closing tag, return complete element including closing tag
complete_element
=
buffer
[:
closing_pos
+
len
(
closing_tag
)]
return
complete_element
,
start_pos
+
closing_pos
+
len
(
closing_tag
)
# Need to ensure no new < appears,
# find the nearest one between < and >
tag_end
=
buffer
.
find
(
"<"
,
1
)
tag_end2
=
buffer
.
find
(
">"
,
1
)
if
tag_end
!=
-
1
and
tag_end2
!=
-
1
:
# Next nearest is <
if
tag_end
<
tag_end2
:
return
buffer
[:
tag_end
],
start_pos
+
tag_end
# Next nearest is >, means found XML element
else
:
return
buffer
[:
tag_end2
+
1
],
start_pos
+
tag_end2
+
1
elif
tag_end
!=
-
1
:
return
buffer
[:
tag_end
],
start_pos
+
tag_end
elif
tag_end2
!=
-
1
:
return
buffer
[:
tag_end2
+
1
],
start_pos
+
tag_end2
+
1
else
:
# If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call> or <function=
if
self
.
current_call_id
is
None
:
# Check if might be start of <tool_call>
if
buffer
==
"<tool_call>"
[:
len
(
buffer
)]:
# Might be start of <tool_call>, wait for more data
return
None
,
start_pos
elif
(
buffer
.
startswith
(
"<function="
)
or
buffer
==
"<function="
[:
len
(
buffer
)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return
None
,
start_pos
else
:
# Not start of <tool_call> or <function=, treat as text
return
buffer
,
start_pos
+
len
(
buffer
)
else
:
# When parsing tool calls,
# wait for more data to get complete tag
return
None
,
start_pos
else
:
# Find text content (until next < or buffer end)
next_tag_pos
=
buffer
.
find
(
"<"
)
if
next_tag_pos
!=
-
1
:
# Found text content
text_content
=
buffer
[:
next_tag_pos
]
return
text_content
,
start_pos
+
next_tag_pos
else
:
# Buffer end is all text, process
# (no longer wait for more data)
remaining
=
buffer
return
remaining
,
start_pos
+
len
(
remaining
)
def
_merge_new_deltas_to_single_response
(
self
,
initial_count
:
int
)
->
DeltaMessage
:
"""
Merge newly generated deltas from this processing
into a single DeltaMessage
Empty or whitespace-only segments are represented as empty string "".
Args:
initial_count: Delta
co
u
nt
before processing
text: Text
cont
aining tool calls
Returns:
Merged DeltaMessage containing all newly generated delta information
List of content segments (one per tool call)
"""
if
len
(
self
.
deltas
)
<=
initial_count
:
return
DeltaMessage
(
content
=
None
)
content_segments
=
[]
pos
=
0
# Get newly generated deltas
new_deltas
=
self
.
deltas
[
initial_count
:]
while
True
:
# Find end of current tool call
end_pos
=
text
.
find
(
self
.
tool_call_end_token
,
pos
)
if
end_pos
==
-
1
:
break
if
len
(
new_deltas
)
==
1
:
# Only one new delta, return directly
return
new_deltas
[
0
]
# Move past the end token
end_pos
+=
len
(
self
.
tool_call_end_token
)
# Merge multiple new deltas
merged_tool_calls
:
list
[
DeltaToolCall
]
=
[]
merged_content
:
str
=
""
# Find start of next tool call
next_start
=
self
.
_find_tool_call_start
(
text
,
end_pos
)
for
delta
in
new_deltas
:
if
delta
.
content
:
merged_content
+=
delta
.
content
if
delta
.
tool_calls
:
# For tool_calls, we need to intelligently merge arguments
for
tool_call
in
delta
.
tool_calls
:
# Find if there's already a tool_call with the same call_id
existing_call
=
None
for
existing
in
merged_tool_calls
:
if
existing
.
id
==
tool_call
.
id
:
existing_call
=
existing
break
if
existing_call
and
existing_call
.
function
:
# Merge to existing tool_call
if
tool_call
.
function
and
tool_call
.
function
.
name
:
existing_call
.
function
.
name
=
tool_call
.
function
.
name
if
(
tool_call
.
function
and
tool_call
.
function
.
arguments
is
not
None
):
if
existing_call
.
function
.
arguments
is
None
:
existing_call
.
function
.
arguments
=
""
# For streaming JSON parameters,
# simply concatenate in order
new_args
=
tool_call
.
function
.
arguments
existing_call
.
function
.
arguments
+=
new_args
if
tool_call
.
type
:
existing_call
.
type
=
tool_call
.
type
else
:
# Add new tool_call
merged_tool_calls
.
append
(
tool_call
)
# Extract content between current end and next start (or text end)
content
=
text
[
end_pos
:
next_start
]
if
next_start
!=
-
1
else
text
[
end_pos
:]
return
DeltaMessage
(
content
=
merged_content
if
merged_content
else
None
,
tool_calls
=
merged_tool_calls
,
)
def
_preprocess_xml_chunk
(
self
,
chunk
:
str
)
->
str
:
"""
Preprocess XML chunk, handle non-standard formats,
and escape special characters
# Store content (empty string if whitespace-only)
content_segments
.
append
(
content
if
content
.
strip
()
else
""
)
Args:
chunk: Original XML chunk
if
next_start
==
-
1
:
break
pos
=
next_start
Returns:
Processed XML chunk
"""
return
content_segments
# Check if this is a tool_call related element
is_tool_call
=
False
if
chunk
.
startswith
(
self
.
tool_call_start_token
)
or
chunk
.
startswith
(
self
.
tool_call_end_token
):
is_tool_call
=
True
# Check for function tags (including malformed ones without =)
# <function=xxx>, </function>, <function xxx>, <functionxxx>
if
(
chunk
.
startswith
(
self
.
function_start_token
)
or
chunk
.
startswith
(
self
.
function_end_token
)
or
chunk
.
startswith
(
"<function "
)
or
re
.
match
(
r
"^<function[a-zA-Z_]"
,
chunk
)
):
# <functionXXX without space or =
is_tool_call
=
True
if
chunk
.
startswith
(
self
.
parameter_start_token
)
or
chunk
.
startswith
(
self
.
parameter_end_token
):
is_tool_call
=
True
def
_convert_tool_calls_to_deltas
(
self
,
tool_calls
:
list
[
ToolCall
],
starting_index
:
int
=
0
)
->
list
[
DeltaToolCall
]:
"""Convert complete ToolCall list to DeltaToolCall list.
# Fallback: fix incomplete <parameter= or <function= tags without
# closing >
# This handles cases like: <parameter=-C:\n or <parameter=-B 5\n
# Apply when parsing tool calls OR when chunk looks like a function/
# parameter tag
if
(
self
.
current_call_id
is
not
None
or
chunk
.
startswith
(
"<function"
)
or
chunk
.
startswith
(
"<parameter"
)
):
chunk
=
self
.
_fix_incomplete_tag_in_chunk
(
chunk
)
# Handle <function=name> format -> <function name="name">
processed
=
re
.
sub
(
r
"<function=([^>]+)>"
,
r
'<function name="\1">'
,
chunk
)
# Handle <parameter=name> format -> <parameter name="name">
processed
=
re
.
sub
(
r
"<parameter=([^>]+)>"
,
r
'<parameter name="\1">'
,
processed
)
original_chunk
=
chunk
# If in parameter value accumulation mode
if
self
.
_pre_inside_parameter
:
# Parameter end: output accumulated raw text
# safely then return </parameter>
if
processed
.
startswith
(
"</parameter>"
):
body_text
=
self
.
_pre_param_buffer
# Trigger deferred parsing mode
# literal_eval+json output in end_element
self
.
defer_current_parameter
=
True
self
.
deferred_param_raw_value
=
body_text
# Clean up state
self
.
_pre_inside_parameter
=
False
self
.
_pre_param_buffer
=
""
self
.
_pre_current_param_name
=
None
safe_text
=
self
.
_escape_xml_special_chars
(
body_text
)
return
f
"
{
safe_text
}
</parameter>"
else
:
# If this is the first block of content after entering parameter
# evaluate if deferred parsing is needed;
# If not needed, exit accumulation mode
# and pass through directly
if
self
.
_pre_param_buffer
==
""
:
# Get current parameter type
param_type
=
(
self
.
_get_param_type
(
self
.
_pre_current_param_name
)
if
self
.
_pre_current_param_name
else
"string"
)
# Only these types need deferred parsing to
# handle Python literals containing single quotes
is_object_type
=
param_type
in
[
"object"
]
is_complex_type
=
(
param_type
in
[
"array"
,
"arr"
,
"sequence"
]
or
param_type
.
startswith
(
"dict"
)
or
param_type
.
startswith
(
"list"
)
)
Returns complete tool calls without splitting into fragments.
# Only delay when contains container symbols
# and has single quotes and is complex type
has_container_hint
=
(
(
"["
in
original_chunk
)
or
(
"{"
in
original_chunk
)
or
(
"("
in
original_chunk
)
)
Args:
tool_calls: List of tool calls to convert
starting_index: Starting index for tool calls (default 0)
# Determine if deferred parsing is needed
need_defer
=
False
if
is_complex_type
:
# Complex type, always need deferred parsing
need_defer
=
True
elif
(
is_object_type
and
has_container_hint
and
(
"'"
in
original_chunk
)
):
# Object type with container symbols
# and single quotes, need deferred parsing
need_defer
=
True
if
not
need_defer
:
# No need for deferred parsing,
# exit parameter mode directly
self
.
_pre_inside_parameter
=
False
return
self
.
_escape_xml_special_chars
(
original_chunk
)
self
.
_pre_param_buffer
+=
original_chunk
return
""
# Parameter start: enable accumulation
if
processed
.
startswith
(
"<parameter name="
):
m
=
re
.
match
(
r
'<parameter name="([^"]+)">'
,
processed
)
if
m
:
self
.
_pre_current_param_name
=
m
.
group
(
1
)
self
.
_pre_inside_parameter
=
True
self
.
_pre_param_buffer
=
""
return
processed
# If processed doesn't contain special_token, escape processed
# This is because XML parsing encounters special characters
# and reports errors, so escaping is needed
if
not
is_tool_call
:
processed
=
self
.
_escape_xml_special_chars
(
processed
)
return
processed
def _emit_delta(self, delta: DeltaMessage):
    """Record one streaming delta by appending it to ``self.deltas``.

    Args:
        delta: The delta message to queue.
    """
    queued = self.deltas
    queued.append(delta)
def
_auto_close_open_parameter_if_needed
(
self
,
incoming_tag
:
str
|
None
=
None
):
"""Before starting to process new elements,
if there are unclosed tags from before,
automatically complete their endings to the parser.
- If there are unclosed parameters,
it's equivalent to feeding `</parameter>`
- When about to start a new function or tool_call,
if there are unclosed functions, complete `</function>`.
- When about to start a new tool_call,
if there are unclosed tool_calls, complete `</tool_call>`.
Returns:
List of DeltaToolCall with complete arguments
"""
# First close unclosed parameters
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
# If about to start new function or tool_call,
# and there are unclosed functions, close function first
if
incoming_tag
in
(
"function"
,
"tool_call"
)
and
self
.
current_function_name
:
self
.
_end_element
(
"function"
)
# If about to start new tool_call,
# and there are unclosed tool_calls, close tool_call first
if
incoming_tag
==
"tool_call"
and
self
.
current_call_id
:
self
.
_end_element
(
"tool_call"
)
def
_start_element
(
self
,
name
:
str
,
attrs
:
dict
[
str
,
str
]):
"""Handle XML start element events"""
if
name
==
"root"
:
return
if
name
==
"tool_call"
:
# Before opening new tool_call,
# automatically complete previous unclosed tags
self
.
_auto_close_open_parameter_if_needed
(
"tool_call"
)
self
.
parameters
=
{}
self
.
current_call_id
=
make_tool_call_id
()
self
.
current_param_is_first
=
True
self
.
tool_call_index
+=
1
elif
name
.
startswith
(
"function"
)
or
(
name
==
"function"
):
# If missing tool_call, manually complete
if
not
self
.
current_call_id
:
self
.
_start_element
(
"tool_call"
,
{})
# Before opening new function,
# automatically complete previous unclosed tags (parameter/function)
self
.
_auto_close_open_parameter_if_needed
(
"function"
)
function_name
=
self
.
_extract_function_name
(
name
,
attrs
)
self
.
current_function_name
=
function_name
self
.
current_function_open
=
True
if
function_name
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
function_name
,
arguments
=
""
),
)
]
)
self
.
_emit_delta
(
delta
)
elif
name
.
startswith
(
"parameter"
)
or
(
name
==
"parameter"
):
# If previous parameter hasn't ended normally,
# complete its end first, then start new parameter
self
.
_auto_close_open_parameter_if_needed
(
"parameter"
)
param_name
=
self
.
_extract_parameter_name
(
name
,
attrs
)
self
.
current_param_name
=
param_name
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
start_quote_emitted
=
False
# Reset start quote flag
# Only output parameter name and colon,
# don't output quotes
# decide after parameter value type is determined
if
param_name
:
if
not
self
.
parameters
:
# First parameter
# start JSON, only output parameter name and colon
json_start
=
f
'{{"
{
param_name
}
": '
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
json_start
),
)
]
)
self
.
_emit_delta
(
delta
)
self
.
current_param_is_first
=
True
else
:
# Subsequent parameters
# add comma and parameter name, no quotes
json_continue
=
f
', "
{
param_name
}
": '
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
json_continue
),
)
]
)
self
.
_emit_delta
(
delta
)
self
.
current_param_is_first
=
False
def
_char_data
(
self
,
data
:
str
):
"""Handle XML character data events"""
if
data
and
self
.
current_param_name
:
# If preprocessing stage determines deferred parsing is needed,
# only cache character data, no streaming output
if
self
.
defer_current_parameter
:
original_data
=
data
if
self
.
should_emit_end_newline
:
original_data
=
"
\n
"
+
original_data
self
.
should_emit_end_newline
=
False
if
original_data
.
endswith
(
"
\n
"
):
self
.
should_emit_end_newline
=
True
original_data
=
original_data
[:
-
1
]
self
.
current_param_value
+=
original_data
return
param_type
=
self
.
_get_param_type
(
self
.
current_param_name
)
# Check if this is the first time receiving data for this parameter
# If this is the first packet of data and starts with \n, remove \n
if
not
self
.
current_param_value
and
data
.
startswith
(
"
\n
"
):
data
=
data
[
1
:]
# Output start quote for string type (if not already output)
if
(
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]
and
not
self
.
start_quote_emitted
):
quote_delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
'"'
),
)
]
delta_tool_calls
=
[]
for
i
,
tool_call
in
enumerate
[
ToolCall
](
tool_calls
):
index
=
starting_index
+
i
tool_id
=
self
.
_generate_tool_call_id
()
# Create complete DeltaToolCall with full arguments
delta_tool_calls
.
append
(
DeltaToolCall
(
index
=
index
,
id
=
tool_id
,
function
=
DeltaFunctionCall
(
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
),
type
=
"function"
,
)
self
.
_emit_delta
(
quote_delta
)
self
.
start_quote_emitted
=
True
if
not
data
:
return
original_data
=
data
# Delay output of trailing newline
if
self
.
should_emit_end_newline
:
original_data
=
"
\n
"
+
original_data
self
.
should_emit_end_newline
=
False
if
original_data
.
endswith
(
"
\n
"
):
self
.
should_emit_end_newline
=
True
original_data
=
original_data
[:
-
1
]
self
.
current_param_value
+=
original_data
# convert parameter value by param_type
converted_value
=
self
.
_convert_param_value
(
self
.
current_param_value
,
param_type
)
output_data
=
self
.
_convert_for_json_streaming
(
converted_value
,
param_type
)
delta_data
=
output_data
[
len
(
self
.
current_param_value_converted
)
:]
self
.
current_param_value_converted
=
output_data
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
delta_data
),
)
]
)
self
.
_emit_delta
(
delta
)
def
_end_element
(
self
,
name
:
str
):
"""Handle XML end element events"""
return
delta_tool_calls
if
name
==
"root"
:
return
# If function or tool_call ends and there are still unclosed parameters,
# complete parameter end first
if
(
name
.
startswith
(
"function"
)
or
name
==
"function"
or
name
==
"tool_call"
)
and
self
.
current_param_name
:
self
.
_auto_close_open_parameter_if_needed
()
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
"""Extract tool calls from streaming text using complete parsing.
Strategy:
1. Accumulate text in buffer and track processed position
2. In each iteration, try to extract content or complete tool calls
3. Parse complete tool calls using non-streaming method
4. Convert parsed results to delta sequence
5. Handle EOS token to flush incomplete tool calls as content
"""
# Initialize state for new request
if
not
previous_text
:
self
.
_reset_streaming_state
()
self
.
streaming_request
=
request
# Check for EOS token
has_eos
=
(
self
.
eos_token_id
is
not
None
and
delta_token_ids
and
self
.
eos_token_id
in
delta_token_ids
)
if
(
name
.
startswith
(
"parameter"
)
or
name
==
"parameter"
)
and
self
.
current_param_name
:
# End current parameter
param_name
=
self
.
current_param_name
param_value
=
self
.
current_param_value
# If in deferred parsing mode,
# perform overall parsing on raw content
# accumulated in preprocessing stage and output once
if
self
.
defer_current_parameter
:
raw_text
=
(
self
.
deferred_param_raw_value
if
self
.
deferred_param_raw_value
else
param_value
)
parsed_value
=
None
output_arguments
=
None
try
:
# If previously delayed trailing newline,
# add it back before parsing
if
self
.
should_emit_end_newline
:
raw_for_parse
=
raw_text
+
"
\n
"
else
:
raw_for_parse
=
raw_text
parsed_value
=
ast
.
literal_eval
(
raw_for_parse
)
output_arguments
=
json
.
dumps
(
parsed_value
,
ensure_ascii
=
False
)
except
Exception
:
# Fallback: output as string as-is
output_arguments
=
json
.
dumps
(
raw_text
,
ensure_ascii
=
False
)
parsed_value
=
raw_text
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
output_arguments
),
)
]
# If no delta text, check if we need to return empty delta for finish_reason
if
not
delta_text
and
not
has_eos
:
# Check if this is an EOS token after all tool calls are complete
if
delta_token_ids
and
self
.
tool_call_end_token_id
not
in
delta_token_ids
:
# Count complete tool calls
complete_calls
=
len
(
self
.
tool_call_complete_regex
.
findall
(
current_text
)
)
self
.
_emit_delta
(
delta
)
# Clean up and store
self
.
should_emit_end_newline
=
False
self
.
parameters
[
param_name
]
=
parsed_value
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
start_quote_emitted
=
False
self
.
defer_current_parameter
=
False
self
.
deferred_param_raw_value
=
""
return
param_type
=
self
.
_get_param_type
(
param_name
)
# convert complete parameter value by param_type
converted_value
=
self
.
_convert_param_value
(
param_value
,
param_type
)
# Decide whether to add end quote based on parameter type
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
# For empty string parameters, need special handling
if
not
param_value
and
not
self
.
start_quote_emitted
:
# No start quote output,
# directly output complete empty string
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
'""'
),
)
]
)
self
.
_emit_delta
(
delta
)
else
:
# Non-empty parameter value, output end quote
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
'"'
),
)
]
)
self
.
_emit_delta
(
delta
)
self
.
should_emit_end_newline
=
False
# Store converted value
self
.
parameters
[
param_name
]
=
converted_value
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
start_quote_emitted
=
False
elif
name
.
startswith
(
"function"
)
or
name
==
"function"
:
# if there are parameters, close JSON object
if
self
.
parameters
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
"}"
),
)
]
)
self
.
_emit_delta
(
delta
)
# return empty object
else
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
"{}"
),
)
]
)
self
.
_emit_delta
(
delta
)
self
.
current_function_open
=
False
self
.
current_function_name
=
(
None
# Clear function name to prevent duplicate closing
)
elif
name
==
"tool_call"
:
# Before ending tool_call,
# ensure function is closed to complete missing right brace
if
self
.
current_function_open
:
# If there are still unclosed parameters, close them first
if
self
.
current_param_name
:
self
.
_end_element
(
"parameter"
)
# Close function, ensure output '}' or '{}'
self
.
_end_element
(
"function"
)
# Final Delta
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
tool_call_index
-
1
,
id
=
self
.
current_call_id
,
type
=
"function"
,
function
=
DeltaFunctionCall
(
name
=
None
,
arguments
=
""
),
)
]
)
self
.
_emit_delta
(
delta
)
# If we have completed tool calls and populated prev_tool_call_arr
if
complete_calls
>
0
and
len
(
self
.
prev_tool_call_arr
)
>
0
:
# Check if all tool calls are closed
open_calls
=
current_text
.
count
(
self
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
tool_call_end_token
)
if
open_calls
==
0
:
# Return empty delta for finish_reason processing
return
DeltaMessage
(
content
=
""
)
return
None
# Check if there's text content to output (between tool_calls)
if
self
.
text_content_buffer
.
strip
():
text_delta
=
DeltaMessage
(
content
=
self
.
text_content_buffer
)
self
.
_emit_delta
(
text_delta
)
# Process all available content
accumulated_deltas
:
list
[
DeltaMessage
]
=
[]
self
.
_reset_xml_parser_after_tool_call
()
while
self
.
_has_unprocessed_content
(
current_text
):
# Try to process next chunk (content or tool call)
delta
=
self
.
_process_next_chunk
(
current_text
)
def setup_parser(self):
    """Configure ``self.parser`` and attach this object's event handlers.

    Sets ``buffer_text`` to ``True`` and assigns the start-element,
    end-element and character-data handlers to ``self._start_element``,
    ``self._end_element`` and ``self._char_data`` respectively.
    """
    xml_parser = self.parser
    xml_parser.buffer_text = True
    handlers = (
        ("StartElementHandler", self._start_element),
        ("EndElementHandler", self._end_element),
        ("CharacterDataHandler", self._char_data),
    )
    for attribute, callback in handlers:
        setattr(xml_parser, attribute, callback)
if
delta
is
None
:
# Cannot proceed further, need more tokens
break
def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
    """Store the tool definitions on ``self.tools``.

    Args:
        tools: Tool definitions to store, or ``None``.
    """
    self.tools = tools
# Accumulate deltas
if
isinstance
(
delta
,
list
):
accumulated_deltas
.
extend
(
delta
)
else
:
accumulated_deltas
.
append
(
delta
)
# Handle EOS: flush any remaining incomplete tool calls as content
if
has_eos
:
remaining_delta
=
self
.
_flush_remaining_content
(
current_text
)
if
remaining_delta
:
accumulated_deltas
.
append
(
remaining_delta
)
# If no remaining content but we have tool calls, return empty delta
elif
len
(
self
.
prev_tool_call_arr
)
>
0
:
# Check if all tool calls are closed
open_calls
=
current_text
.
count
(
self
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
tool_call_end_token
)
if
open_calls
==
0
:
accumulated_deltas
.
append
(
DeltaMessage
(
content
=
""
))
# Return results
return
self
.
_format_delta_result
(
accumulated_deltas
)
def
_has_unprocessed_content
(
self
,
current_text
:
str
)
->
bool
:
"""Check if there's unprocessed content in the buffer."""
return
self
.
_processed_length
<
len
(
current_text
)
def
_process_next_chunk
(
self
,
current_text
:
str
)
->
DeltaMessage
|
list
[
DeltaMessage
]
|
None
:
"""Process next chunk: either regular content or a complete tool call.
def
_extract_function_name
(
self
,
name
:
str
,
attrs
:
dict
[
str
,
str
])
->
str
|
None
:
"""Extract function name from various formats"""
if
attrs
and
"name"
in
attrs
:
return
attrs
[
"name"
]
Args:
current_text: Current accumulated text
if
"="
in
name
:
parts
=
name
.
split
(
"="
,
1
)
if
len
(
parts
)
==
2
and
parts
[
0
]
==
"function"
:
return
parts
[
1
]
Returns:
- DeltaMessage or list of DeltaMessage if processed successfully
- None if cannot proceed (need more tokens)
"""
# Find next tool call start
tool_start_idx
=
self
.
_find_tool_call_start
(
current_text
,
self
.
_processed_length
)
return
None
# Case 1: No tool call found - return remaining content
if
tool_start_idx
==
-
1
:
return
self
.
_process_content
(
current_text
,
self
.
_processed_length
,
len
(
current_text
)
)
def
_extract_parameter_name
(
self
,
name
:
str
,
attrs
:
dict
[
str
,
str
])
->
str
|
None
:
"""Extract parameter name from various formats"""
if
attrs
and
"name"
in
attrs
:
return
attrs
[
"name"
]
# Case 2: Content before tool call
if
tool_start_idx
>
self
.
_processed_length
:
return
self
.
_process_content
(
current_text
,
self
.
_processed_length
,
tool_start_idx
)
if
"="
in
name
:
parts
=
name
.
split
(
"="
,
1
)
if
len
(
parts
)
==
2
and
parts
[
0
]
==
"parameter"
:
return
parts
[
1
]
# Case 3: Tool call at current position
# Find end of the first complete tool call
tool_end_idx
=
self
.
_find_first_complete_tool_call_end
(
current_text
,
tool_start_idx
)
return
None
if
tool_end_idx
==
-
1
:
# Tool call incomplete, wait for more tokens
return
None
# Process complete tool call
return
self
.
_process_complete_tool_calls
(
current_text
,
tool_start_idx
,
tool_end_idx
)
def
_process_content
(
self
,
current_text
:
str
,
start_pos
:
int
,
end_pos
:
int
)
->
DeltaMessage
|
None
:
"""Process regular content (non-tool-call text).
def
_get_param_type
(
self
,
param_name
:
str
)
->
str
:
"""Get parameter type based on tool configuration, defaults to string
Args:
param_name: Parameter name
current_text: Current accumulated text
start_pos: Start position in buffer
end_pos: End position in buffer
Returns:
Parameter type
DeltaMessage with content if non-empty
"""
if
not
self
.
tools
or
not
self
.
current_function_name
:
return
"string"
if
start_pos
>=
end_pos
:
return
None
for
tool
in
self
.
tools
:
if
not
hasattr
(
tool
,
"type"
)
or
not
(
hasattr
(
tool
,
"function"
)
and
hasattr
(
tool
.
function
,
"name"
)
):
continue
content
=
current_text
[
start_pos
:
end_pos
]
# Check if we're between tool calls - skip whitespace
if
start_pos
>
0
:
# Check if text before start_pos ends with </tool_call>
text_before
=
current_text
[:
start_pos
]
if
(
t
ool
.
type
==
"function"
and
tool
.
function
.
name
==
self
.
current_function_name
t
ext_before
.
rstrip
().
endswith
(
self
.
tool_call_end_token
)
and
content
.
strip
()
==
""
):
if
not
hasattr
(
tool
.
function
,
"parameters"
):
return
"string"
params
=
tool
.
function
.
parameters
if
isinstance
(
params
,
dict
)
and
"properties"
in
params
:
properties
=
params
[
"properties"
]
if
param_name
in
properties
and
isinstance
(
properties
[
param_name
],
dict
):
return
self
.
repair_param_type
(
str
(
properties
[
param_name
].
get
(
"type"
,
"string"
))
)
elif
isinstance
(
params
,
dict
)
and
param_name
in
params
:
param_config
=
params
[
param_name
]
if
isinstance
(
param_config
,
dict
):
return
self
.
repair_param_type
(
str
(
param_config
.
get
(
"type"
,
"string"
))
)
break
return
"string"
# We just ended a tool call, skip whitespace between tool calls
self
.
_processed_length
=
end_pos
return
None
def
repair_param_type
(
self
,
param_type
:
str
)
->
str
:
"""Repair unknown parameter types by treating them as string
Args:
param_type: Parameter type
# Return content if non-empty
if
content
:
self
.
_processed_length
=
end_pos
return
DeltaMessage
(
content
=
content
)
Returns:
Repaired parameter type
"""
if
(
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]
or
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)
or
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
)
or
param_type
in
[
"boolean"
,
"bool"
,
"binary"
]
or
(
param_type
in
[
"object"
,
"array"
,
"arr"
,
"sequence"
]
or
param_type
.
startswith
(
"dict"
)
or
param_type
.
startswith
(
"list"
)
)
):
return
param_type
else
:
return
"string"
# Mark as processed even if empty
self
.
_processed_length
=
end_pos
return
None
def
_flush_remaining_content
(
self
,
current_text
:
str
)
->
DeltaMessage
|
None
:
"""Flush any remaining unprocessed content as regular content.
def
_convert_param_value
(
self
,
param_value
:
str
,
param_type
:
str
)
->
Any
:
"""Convert value based on parameter type
Args:
param_value: Parameter value
param_type: Parameter type
current_text: Current accumulated text
Returns:
Converted value
Used when EOS token is encountered to handle incomplete tool calls.
"""
if
param_value
.
lower
()
==
"null"
:
if
not
self
.
_has_unprocessed_content
(
current_text
)
:
return
None
param_type
=
param_type
.
strip
().
lower
()
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
return
param_value
elif
(
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)
):
try
:
return
int
(
param_value
)
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' is not an integer, degenerating to string."
,
param_value
,
)
return
param_value
elif
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
):
try
:
float_param_value
:
float
=
float
(
param_value
)
return
(
float_param_value
if
float_param_value
-
int
(
float_param_value
)
!=
0
else
int
(
float_param_value
)
)
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' is not a float, degenerating to string."
,
param_value
,
)
return
param_value
elif
param_type
in
[
"boolean"
,
"bool"
,
"binary"
]:
param_value
=
param_value
.
lower
()
return
param_value
==
"true"
else
:
return
param_value
remaining
=
current_text
[
self
.
_processed_length
:]
if
remaining
:
self
.
_processed_length
=
len
(
current_text
)
return
DeltaMessage
(
content
=
remaining
)
self
.
_processed_length
=
len
(
current_text
)
return
None
def
_format_delta_result
(
self
,
deltas
:
list
[
DeltaMessage
])
->
DeltaMessage
|
None
:
"""Format delta result for return.
Merges all deltas into a single DeltaMessage.
def
_convert_for_json_streaming
(
self
,
converted_value
:
Any
,
param_type
:
str
)
->
str
:
"""Convert converted_value based on
whether it's empty and if type is string
Args:
converted_value: Converted value
param_type: Parameter type
deltas: List of delta messages
Returns:
Converted string for streaming output
- None if empty
- Single merged DeltaMessage with all content and tool_calls
"""
# Check if value is empty, but exclude numeric 0
if
converted_value
is
None
or
converted_value
==
""
:
return
""
if
not
deltas
:
return
None
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
# String type, remove double quotes
return
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)[
1
:
-
1
]
else
:
# Non-string type, return complete JSON string
if
not
isinstance
(
converted_value
,
str
):
return
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)
else
:
return
converted_value
if
len
(
deltas
)
==
1
:
return
deltas
[
0
]
def
_reset_xml_parser_after_tool_call
(
self
):
"""
Each tool_call is treated as a separate XML document,
so we need to reset the parser after each tool_call.
"""
# Merge multiple deltas into one
merged_content_parts
=
[]
merged_tool_calls
=
[]
# recreate XML parser
self
.
parser
=
ParserCreate
()
self
.
setup_parser
()
# Reset current tool_call state
if
self
.
current_call_id
:
self
.
last_completed_call_id
=
self
.
current_call_id
self
.
current_call_id
=
None
self
.
current_function_name
=
None
self
.
current_function_open
=
False
self
.
parameters
=
{}
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
current_param_value_converted
=
""
self
.
current_param_is_first
=
False
self
.
should_emit_end_newline
=
False
self
.
start_quote_emitted
=
False
self
.
text_content_buffer
=
""
# Reset preprocessing and deferred parsing state
self
.
_pre_inside_parameter
=
False
self
.
_pre_param_buffer
=
""
self
.
_pre_current_param_name
=
None
self
.
defer_current_parameter
=
False
self
.
deferred_param_raw_value
=
""
for
delta
in
deltas
:
if
delta
.
content
:
merged_content_parts
.
append
(
delta
.
content
)
if
delta
.
tool_calls
:
merged_tool_calls
.
extend
(
delta
.
tool_calls
)
# Create merged DeltaMessage
merged_content
=
""
.
join
(
merged_content_parts
)
if
merged_content_parts
else
None
class
Step3p5ToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
TokenizerLike
):
super
().
__init__
(
tokenizer
)
self
.
parser
=
StreamingXMLT
ool
C
all
Parser
()
# Build kwargs - only include tool_calls if non-empty
kwargs
:
dict
[
str
,
Any
]
=
{
"content"
:
merged_content
}
if
merged_tool_calls
:
kwargs
[
"tool_calls"
]
=
merged_t
ool
_c
all
s
# Add missing attributes for compatibility with serving_chat.py
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
return
DeltaMessage
(
**
kwargs
)
logger
.
info
(
"vLLM Successfully import tool parser %s !"
,
self
.
__class__
.
__name__
)
def
_process_complete_tool_calls
(
self
,
current_text
:
str
,
start_pos
:
int
,
end_pos
:
int
)
->
list
[
DeltaMessage
]
|
None
:
"""Process complete tool calls and convert to delta sequence.
def
extract_tool_calls
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
self
.
parser
.
reset_streaming_state
()
# Reset tool call tracking arrays for new extraction
self
.
prev_tool_call_arr
=
[]
self
.
streamed_args_for_tool
=
[]
if
request
:
self
.
parser
.
set_tools
(
request
.
tools
)
result
=
self
.
parser
.
parse_single_streaming_chunks
(
model_output
)
if
not
result
.
tool_calls
:
return
ExtractedToolCallInformation
(
tool_calls
=
[],
tools_called
=
False
,
content
=
result
.
content
,
)
else
:
tool_calls
=
[]
for
tool_call
in
result
.
tool_calls
:
if
tool_call
.
function
and
tool_call
.
function
.
name
:
tool_calls
.
append
(
ToolCall
(
id
=
tool_call
.
id
,
type
=
tool_call
.
type
,
function
=
FunctionCall
(
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
),
)
)
Args:
current_text: Current accumulated text
start_pos: Start position (should be at <tool_call>)
end_pos: End position (after </tool_call>)
# Update tool call tracking arrays for compatibility
tool_index
=
(
tool_call
.
index
if
tool_call
.
index
is
not
None
else
len
(
self
.
prev_
tool
_
call
_arr
)
-
1
)
Returns:
List of DeltaMessage if successful, None otherwise
"""
try
:
# Extract text segment containing complete
tool
call
(s)
text_to_parse
=
current_text
[
start_pos
:
end_pos
]
# Ensure we have enough entries in our tracking arrays
while
len
(
self
.
prev_tool_call_arr
)
<=
tool_index
:
self
.
prev_tool_call_arr
.
append
({
"name"
:
""
,
"arguments"
:
""
})
while
len
(
self
.
streamed_args_for_tool
)
<=
tool_index
:
self
.
streamed_args_for_tool
.
append
(
""
)
# Parse using non-streaming method
result
=
self
.
extract_tool_calls_basic
(
text_to_parse
,
self
.
streaming_request
)
# Update tool call information
self
.
prev_tool_call_arr
[
tool_index
][
"name"
]
=
(
tool_call
.
function
.
name
)
self
.
prev_tool_call_arr
[
tool_index
][
"arguments"
]
=
(
tool_call
.
function
.
arguments
)
# Case 1: Successfully parsed tool calls
if
result
.
tools_called
and
result
.
tool_calls
:
# Note: Due to _find_first_complete_tool_call_end, we typically
# process only one tool call at a time
# but we can also process multiple tool calls below
deltas
=
self
.
_build_tool_call_deltas
(
result
.
tool_calls
,
text_to_parse
)
self
.
_update_state_after_tool_calls
(
result
.
tool_calls
,
end_pos
)
return
deltas
if
deltas
else
None
# Case 2: Parsing failed - treat as regular content
self
.
_processed_length
=
end_pos
return
[
DeltaMessage
(
content
=
text_to_parse
)]
except
Exception
as
e
:
# Exception during parsing - treat as content
logger
.
debug
(
"Failed to parse tool calls: %s, treating as content"
,
e
)
self
.
_processed_length
=
end_pos
failed_text
=
current_text
[
start_pos
:
end_pos
]
return
[
DeltaMessage
(
content
=
failed_text
)]
if
failed_text
else
None
def
_build_tool_call_deltas
(
self
,
tool_calls
:
list
[
ToolCall
],
parsed_text
:
str
)
->
list
[
DeltaMessage
]:
"""Build delta messages from parsed tool calls with interleaved content.
# Update streamed arguments
if
tool_call
.
function
.
arguments
:
self
.
streamed_args_for_tool
[
tool_index
]
=
(
tool_call
.
function
.
arguments
)
Args:
tool_calls: List of parsed tool calls
parsed_text: Original text that was parsed
return
ExtractedToolCallInformation
(
tool
_
calls
=
tool_calls
,
tools_called
=
len
(
tool_calls
)
>
0
,
content
=
result
.
content
,
)
Returns:
List of DeltaMessage with
tool
calls
and content interleaved
"""
# Extract content segments between tool calls
content_segments
=
self
.
_extract_content_between_tool_calls_list
(
parsed_text
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
if
not
previous_text
:
self
.
parser
.
reset_streaming_state
()
# Reset tool call tracking arrays for new streaming session
self
.
prev_tool_call_arr
=
[]
self
.
streamed_args_for_tool
=
[]
if
request
:
self
.
parser
.
set_tools
(
request
.
tools
)
# Model sometimes outputs separately causing delta_text to be empty.
# If there were tool_calls before and all current tool_calls have ended,
# return an empty tool_call for outer streaming output
# to correctly output tool_call field
if
not
delta_text
and
delta_token_ids
:
open_calls
=
current_text
.
count
(
self
.
parser
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
parser
.
tool_call_end_token
)
if
(
open_calls
==
0
and
self
.
parser
.
tool_call_index
>
0
or
not
self
.
parser
.
tool_call_index
and
current_text
):
return
DeltaMessage
(
content
=
""
)
return
None
# Convert all tool calls to DeltaToolCall list
delta_tool_calls
=
self
.
_convert_tool_calls_to_deltas
(
tool_calls
,
self
.
_tool_call_index
)
# Parse the delta text and get the result
result
=
self
.
parser
.
parse_single_streaming_chunks
(
delta_text
)
# Update tool call tracking arrays based on incremental parsing results
if
result
and
result
.
tool_calls
:
for
tool_call
in
result
.
tool_calls
:
if
tool_call
.
function
:
tool_index
=
(
tool_call
.
index
if
tool_call
.
index
is
not
None
else
len
(
self
.
prev_tool_call_arr
)
-
1
)
# Merge all content segments into a single string
merged_content
=
""
.
join
(
content_segments
)
# Ensure we have enough entries in our tracking arrays
while
len
(
self
.
prev_tool_call_arr
)
<=
tool_index
:
self
.
prev_tool_call_arr
.
append
({
"name"
:
""
,
"arguments"
:
""
})
while
len
(
self
.
streamed_args_for_tool
)
<=
tool_index
:
self
.
streamed_args_for_tool
.
append
(
""
)
# Return a single DeltaMessage with all tool calls and content
# Build kwargs - only include non-empty fields
kwargs
:
dict
[
str
,
Any
]
=
{}
if
merged_content
:
kwargs
[
"content"
]
=
merged_content
if
delta_tool_calls
:
kwargs
[
"tool_calls"
]
=
delta_tool_calls
# Update tool name if provided
if
tool_call
.
function
.
name
:
self
.
prev_tool_call_arr
[
tool_index
][
"name"
]
=
(
tool_call
.
function
.
name
)
# Only return DeltaMessage if we have content or tool_calls
if
kwargs
:
return
[
DeltaMessage
(
**
kwargs
)]
else
:
return
[]
# Update arguments incrementally
if
tool_call
.
function
.
arguments
is
not
None
:
# Concatenate the incremental arguments
# to the existing streamed arguments
self
.
prev_tool_call_arr
[
tool_index
][
"arguments"
]
+=
(
tool_call
.
function
.
arguments
)
self
.
streamed_args_for_tool
[
tool_index
]
+=
(
tool_call
.
function
.
arguments
)
return
result
def
_update_state_after_tool_calls
(
self
,
tool_calls
:
list
[
ToolCall
],
end_pos
:
int
)
->
None
:
"""Update internal state after processing tool calls.
def
parser_should_check_for_unstreamed_tool_arg_tokens
(
self
)
->
bool
:
"""
Skip the remaining_call calculation in serving_chat
Args
:
tool_calls: List of processed tool calls
end_pos: End position in buffer
"""
return
False
# Update processed position
self
.
_processed_length
=
end_pos
# Update tool call index
self
.
_tool_call_index
+=
len
(
tool_calls
)
# Update prev_tool_call_arr for finish_reason
self
.
_update_prev_tool_call_state
(
tool_calls
)
vllm/v1/attention/backends/triton_attn.py
View file @
84ac1d27
...
...
@@ -514,7 +514,7 @@ class TritonAttentionImpl(AttentionImpl):
q_descale
=
None
,
# Not supported
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
seq_threshold_3D
=
seq_threshold_3D
,
seq_threshold_3D
=
None
,
num_par_softmax_segments
=
num_par_softmax_segments
,
softmax_segm_output
=
softmax_segm_output
,
softmax_segm_max
=
softmax_segm_max
,
...
...
vllm/v1/core/kv_cache_utils.py
View file @
84ac1d27
...
...
@@ -957,6 +957,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def
_get_kv_cache_groups_uniform_page_size
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
)
->
list
[
KVCacheGroupSpec
]:
"""
...
...
@@ -976,6 +977,12 @@ def _get_kv_cache_groups_uniform_page_size(
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
...
...
@@ -1062,19 +1069,28 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers
/
len
(
layers
)
*
100
,
)
num_groups
=
cdiv
(
len
(
layers
),
group_size
)
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for
i
in
range
(
num_groups
):
grouped_layers
.
append
(
layers
[
i
::
num_groups
])
# To support multi-layer MTP, we need to
# keep all MTP layers in the same group
if
(
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
speculative_config
.
enable_multi_layers_mtp
):
for
i
in
range
(
0
,
len
(
layers
),
group_size
):
grouped_layers
.
append
(
layers
[
i
:
i
+
group_size
])
else
:
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for
i
in
range
(
num_groups
):
grouped_layers
.
append
(
layers
[
i
::
num_groups
])
return
create_kv_cache_group_specs
(
kv_cache_spec
,
grouped_layers
)
...
...
@@ -1259,7 +1275,9 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return
_get_kv_cache_groups_uniform_page_size
(
kv_cache_spec
)
return
_get_kv_cache_groups_uniform_page_size
(
vllm_config
=
vllm_config
,
kv_cache_spec
=
kv_cache_spec
)
def
generate_scheduler_kv_cache_config
(
...
...
vllm/v1/spec_decode/eagle.py
View file @
84ac1d27
...
...
@@ -38,7 +38,7 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
UniformTypeKVCacheSpecs
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.sampler
import
_SAMPLING_EPS
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
MultiLayerEagleMetadata
,
SpecDecodeMetadata
from
vllm.v1.spec_decode.utils
import
(
PADDING_SLOT_ID
,
compute_new_slot_mapping
,
...
...
@@ -395,6 +395,7 @@ class SpecDecodeBaseProposer:
token_indices_to_sample
:
torch
.
Tensor
|
None
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampling_metadata
:
SamplingMetadata
,
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
=
None
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
=
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
...
...
vllm/v1/spec_decode/metadata.py
View file @
84ac1d27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
import
numpy
as
np
...
...
@@ -64,3 +66,45 @@ class SpecDecodeMetadata:
bonus_logits_indices
=
bonus_logits_indices
,
logits_indices
=
logits_indices
,
)
@
dataclass
class
MultiLayerEagleMetadata
:
# [batch_size]
cached_len
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_token_ids
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num, hidden_size]
cached_hidden_states
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_slot_mappings
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_positions
:
torch
.
Tensor
|
None
=
None
@
classmethod
def
make_dummy
(
cls
,
layer_num
:
int
,
hidden_size
:
int
,
device
:
torch
.
device
,
)
->
"MultiLayerEagleMetadata"
:
cached_len
=
torch
.
zeros
((
1
),
dtype
=
torch
.
int64
,
device
=
device
)
cached_token_ids
=
torch
.
zeros
(
(
1
,
layer_num
),
dtype
=
torch
.
int32
,
device
=
device
)
cached_hidden_states
=
torch
.
zeros
(
(
1
,
layer_num
,
hidden_size
),
dtype
=
torch
.
float32
,
device
=
device
)
cached_slot_mappings
=
torch
.
zeros
(
(
1
,
layer_num
),
dtype
=
torch
.
int64
,
device
=
device
)
cached_positions
=
torch
.
zeros
(
(
1
,
layer_num
),
dtype
=
torch
.
int64
,
device
=
device
)
return
cls
(
cached_len
=
cached_len
,
cached_token_ids
=
cached_token_ids
,
cached_hidden_states
=
cached_hidden_states
,
cached_slot_mappings
=
cached_slot_mappings
,
cached_positions
=
cached_positions
,
)
vllm/v1/spec_decode/multi_layer_eagle.py
0 → 100644
View file @
84ac1d27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
import
torch
from
vllm.config
import
CUDAGraphMode
,
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backend
import
(
CommonAttentionMetadata
,
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.metadata
import
MultiLayerEagleMetadata
logger
=
init_logger
(
__name__
)
BLOCK_HIDDEN
=
128
BLOCK_TOKENS
=
128
class
MultiLayerEagleProposer
(
EagleProposer
):
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    runner=None,
):
    """Initialise the proposer and align the draft length with the layer count.

    ``layer_num`` is read from the ``n_predict`` attribute of the draft
    model's HF text config (0 when the attribute is absent). If the
    configured ``num_speculative_tokens`` differs from it, a one-time
    warning is logged and ``num_speculative_tokens`` is overwritten with
    ``layer_num``.
    """
    super().__init__(vllm_config, device, runner)

    draft_text_config = (
        self.speculative_config.draft_model_config.hf_text_config
    )
    self.layer_num: int = getattr(draft_text_config, "n_predict", 0)
    self.num_speculative_tokens: int = (
        self.speculative_config.num_speculative_tokens
    )

    # Nothing to reconcile when the two values already agree.
    if self.num_speculative_tokens == self.layer_num:
        return

    logger.warning_once(
        "For multi_layer_eagle, num_speculative_tokens "
        "does not match layer_num, adjusting to layer_num"
    )
    self.num_speculative_tokens = self.layer_num
def adjust_input(
    self,
    batch_size: int,
    target_token_ids: torch.Tensor,
    target_positions: torch.Tensor,
    target_hidden_states: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
    multi_layer_eagle_metadata: MultiLayerEagleMetadata,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
    """Compute the per-request shift and run the shift-and-cache kernel.

    Side effects: ``token_indices_to_sample`` and
    ``common_attn_metadata.seq_lens`` are modified in place. The cached
    tensors in ``multi_layer_eagle_metadata`` and the slot mapping are
    handed to ``_multi_layer_eagle_shift_and_cache`` as both inputs and
    outputs.

    Returns:
        ``(token_ids, positions, hidden_states, common_attn_metadata)``
        where the first three start as clones of the ``target_*`` inputs
        and are passed to the kernel as destinations.
    """
    max_shift = self.layer_num
    assert max_shift > 0

    # The kernel writes into these clones; the target_* tensors are its
    # sources.
    shifted_token_ids = target_token_ids.clone()
    shifted_positions = target_positions.clone()
    shifted_hidden_states = target_hidden_states.clone()

    slot_mapping = common_attn_metadata.slot_mapping
    first_token_idx = common_attn_metadata.query_start_loc[:-1]
    last_token_idx = common_attn_metadata.query_start_loc[1:] - 1

    # With 2-D positions only the first row is used to derive the shift.
    if target_positions.dim() == 2:
        shift_positions = target_positions[0]
    else:
        shift_positions = target_positions
    first_token_pos = shift_positions[first_token_idx]

    # shift = max(0, min(last_idx - sample_idx, position of first token))
    shift = torch.clamp(
        torch.minimum(
            last_token_idx - token_indices_to_sample, first_token_pos
        ),
        min=0,
    )

    # Metadata updates (matches the original reference implementation).
    # These use the shift *before* it is capped by the cached length.
    token_indices_to_sample.add_(shift)
    common_attn_metadata.seq_lens.sub_(shift)

    cached_lens = multi_layer_eagle_metadata.cached_len
    shift = torch.minimum(shift, cached_lens)

    _multi_layer_eagle_shift_and_cache(
        batch_size=batch_size,
        max_shift=max_shift,
        src_token_ids=target_token_ids,
        dst_token_ids=shifted_token_ids,
        src_positions=target_positions,
        dst_positions=shifted_positions,
        src_hidden_states=target_hidden_states,
        dst_hidden_states=shifted_hidden_states,
        src_slot_mapping=slot_mapping,
        dst_slot_mapping=slot_mapping,
        start_token_indices=first_token_idx,
        end_token_indices=last_token_idx,
        token_indices_to_sample=token_indices_to_sample,
        shift=shift,
        cached_lens=cached_lens,
        cached_prev_token_ids=multi_layer_eagle_metadata.cached_token_ids,
        cached_prev_positions=multi_layer_eagle_metadata.cached_positions,
        cached_prev_hidden_states=(
            multi_layer_eagle_metadata.cached_hidden_states
        ),
        cached_slot_mappings=multi_layer_eagle_metadata.cached_slot_mappings,
        common_attn_metadata=common_attn_metadata,
    )
    return (
        shifted_token_ids,
        shifted_positions,
        shifted_hidden_states,
        common_attn_metadata,
    )
def
initial_inputs_for_forward
(
self
,
num_tokens
:
int
,
prev_token_ids
:
torch
.
Tensor
,
prev_positions
:
torch
.
Tensor
,
prev_hidden_states
:
torch
.
Tensor
,
next_token_ids
:
torch
.
Tensor
,
token_indices_to_sample
:
torch
.
Tensor
,
spec_step_idx
:
int
=
0
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
):
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self
.
input_ids
[:
num_tokens
-
1
]
=
prev_token_ids
[
1
:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self
.
input_ids
[
token_indices_to_sample
]
=
next_token_ids
self
.
_set_positions
(
num_tokens
,
prev_positions
)
self
.
hidden_states
[:
num_tokens
]
=
prev_hidden_states
[:
num_tokens
]
if
self
.
supports_mm_inputs
:
mm_embeds
,
is_mm_embed
=
mm_embed_inputs
or
(
None
,
None
)
self
.
inputs_embeds
[:
num_tokens
]
=
self
.
model
.
embed_input_ids
(
self
.
input_ids
[:
num_tokens
],
multimodal_embeddings
=
mm_embeds
,
is_multimodal
=
is_mm_embed
,
)
else
:
self
.
inputs_embeds
[:
num_tokens
]
=
self
.
model
.
embed_input_ids
(
self
.
input_ids
[:
num_tokens
],
)
def draft_model_forward(
    self,
    num_tokens: int,
    per_layer_attn_metadata: dict[str, Any],
    token_indices_to_sample: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    common_attn_metadata: CommonAttentionMetadata,
    spec_step_idx: int = 0,
):
    """Run one draft-model forward pass and pick tokens by argmax.

    ``sampling_metadata`` is accepted but not read; the draft tokens are
    the argmax of the logits at ``token_indices_to_sample``.

    Returns:
        ``(draft_token_ids, last_hidden_states)`` where the hidden states
        are the full output of the model call.
    """
    cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = (
        self._determine_batch_execution_and_padding(num_tokens)
    )

    # The embeddings are always passed; the raw ids only when the model
    # does not take multimodal inputs.
    inputs_embeds = self.inputs_embeds[:num_input_tokens]
    input_ids = (
        None if self.supports_mm_inputs else self.input_ids[:num_input_tokens]
    )

    model_kwargs = {
        "input_ids": input_ids,
        "positions": self._get_positions(num_input_tokens),
        "hidden_states": self.hidden_states[:num_input_tokens],
        "inputs_embeds": inputs_embeds,
        "spec_step_idx": spec_step_idx,
    }

    forward_slot_mapping = self._get_slot_mapping(
        num_input_tokens, common_attn_metadata.slot_mapping
    )
    with set_forward_context(
        per_layer_attn_metadata,
        self.vllm_config,
        num_tokens=num_input_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        slot_mapping=forward_slot_mapping,
    ):
        last_hidden_states = self.model(**model_kwargs)

    logits = self.model.compute_logits(
        last_hidden_states[token_indices_to_sample],
        spec_step_idx=spec_step_idx,
    )
    return logits.argmax(dim=-1), last_hidden_states
def propose(
    self,
    # [num_tokens]
    target_token_ids: torch.Tensor,
    # [num_tokens] or [3, num_tokens] when M-RoPE is enabled
    target_positions: torch.Tensor,
    # [num_tokens, hidden_size]
    target_hidden_states: torch.Tensor,
    # [batch_size]
    next_token_ids: torch.Tensor,
    token_indices_to_sample: torch.Tensor | None,
    common_attn_metadata: CommonAttentionMetadata,
    sampling_metadata: SamplingMetadata,
    multi_layer_eagle_metadata: MultiLayerEagleMetadata | None = None,
    mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
    num_rejected_tokens_gpu: torch.Tensor | None = None,
    slot_mappings: dict[str, torch.Tensor]
    | list[dict[str, torch.Tensor]]
    | None = None,
) -> torch.Tensor:
    """Produce ``num_speculative_tokens`` draft tokens per request.

    Requires ``self.method == "mtp"``, a runner, and a non-``None``
    ``multi_layer_eagle_metadata``. ``num_rejected_tokens_gpu`` and
    ``slot_mappings`` are accepted but not read here.

    Returns:
        ``[batch_size, num_speculative_tokens]`` draft token ids (a
        ``view(-1, 1)`` of the single step's output when only one token
        is drafted).
    """
    assert self.method == "mtp"
    assert self.runner is not None
    assert multi_layer_eagle_metadata is not None

    num_tokens = target_token_ids.shape[0]
    batch_size = next_token_ids.shape[0]

    # Default to the last token of every request.
    if token_indices_to_sample is None:
        token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1

    (
        prev_token_ids,
        prev_positions,
        prev_hidden_states,
        common_attn_metadata,
    ) = self.adjust_input(
        batch_size=batch_size,
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        token_indices_to_sample=token_indices_to_sample,
        common_attn_metadata=common_attn_metadata,
        multi_layer_eagle_metadata=multi_layer_eagle_metadata,
    )

    # Build attention metadata with every draft attention group's builder;
    # the result of the last group is the one that is kept.
    attn_metadata = None
    for attn_group in self.draft_attn_groups:
        builder = attn_group.get_metadata_builder()
        attn_metadata = builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0
        )

    # At this moment, we assume all eagle layers belong to the same KV
    # cache group, thus using the same attention metadata.
    per_layer_attn_metadata = dict.fromkeys(
        self._draft_attn_layer_names, attn_metadata
    )

    num_drafts = self.num_speculative_tokens
    drafted: list[torch.Tensor] = []
    for step in range(num_drafts):
        if step > 0:
            # Later steps start from the previous step's input buffer and
            # feed in the token drafted by that step.
            prev_token_ids = self.input_ids[:num_tokens].clone()
            next_token_ids = drafted[-1].int()

        self.initial_inputs_for_forward(
            num_tokens=num_tokens,
            prev_token_ids=prev_token_ids,
            prev_positions=prev_positions,
            prev_hidden_states=prev_hidden_states,
            next_token_ids=next_token_ids,
            token_indices_to_sample=token_indices_to_sample,
            spec_step_idx=step,
            mm_embed_inputs=mm_embed_inputs,
        )
        step_token_ids, prev_hidden_states = self.draft_model_forward(
            num_tokens=num_tokens,
            per_layer_attn_metadata=per_layer_attn_metadata,
            token_indices_to_sample=token_indices_to_sample,
            sampling_metadata=sampling_metadata,
            common_attn_metadata=common_attn_metadata,
            spec_step_idx=step,
        )

        # Early exit if there is only one draft token to be generated.
        if num_drafts == 1:
            return step_token_ids.view(-1, 1)

        drafted.append(step_token_ids)

    # [batch_size, num_speculative_tokens]
    return torch.stack(drafted, dim=1)
def prepare_inputs(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    sampled_token_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
    """Reject the non-padded drafter-batch code path.

    This function is used to prepare the inputs for speculative decoding:
    it would update ``common_attn_metadata`` to account for the rejected
    tokens (and newly sampled tokens) and return the token indices that
    should be fed to the speculator. That path is not implemented for
    this proposer, so the call always fails.

    Raises:
        NotImplementedError: Always. It is a subclass of ``Exception``,
            so callers that catch ``Exception`` are unaffected.
    """
    raise NotImplementedError(
        "speculative_config.disable_padded_drafter_batch"
        " is not supported now for MultiLayerEagleProposer."
    )
@torch.inference_mode()
def dummy_run(
    self,
    num_tokens: int,
    use_cudagraphs: bool = True,
    is_graph_capturing: bool = False,
    slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None:
    """Execute the drafter once on placeholder inputs.

    The call first runs ``adjust_input`` with metadata describing a single
    request that spans all padded tokens, then invokes ``self.model`` once
    per index in ``range(self.layer_num)`` inside a forward context.

    Args:
        num_tokens: Requested token count; the padded count actually used
            comes from ``_determine_batch_execution_and_padding``.
        use_cudagraphs: Accepted for interface compatibility; not read here.
        is_graph_capturing: Accepted for interface compatibility; not read
            here.
        slot_mappings: Optional per-layer slot mappings. When it contains
            the first draft attention layer, the proposer's own mapping
            from ``_get_slot_mapping`` is used instead.
    """
    (
        cudagraph_runtime_mode,
        num_input_tokens,
        num_tokens_across_dp,
    ) = self._determine_batch_execution_and_padding(num_tokens)

    # Make sure to use EAGLE's own buffer during cudagraph capture.
    use_own_buffer = False
    if self._draft_attn_layer_names and slot_mappings is not None:
        first_draft_layer = next(iter(self._draft_attn_layer_names))
        use_own_buffer = first_draft_layer in slot_mappings
    if use_own_buffer:
        slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
    else:
        slot_mapping_dict = slot_mappings or {}

    # Placeholder inputs describing one request covering every padded token.
    dummy_token_ids = self.input_ids[:num_input_tokens]
    dummy_positions = self._get_positions(num_input_tokens)
    dummy_hidden_states = self.hidden_states[:num_input_tokens]
    last_token_index = torch.tensor(
        [num_input_tokens - 1], dtype=torch.int32, device=self.device
    )
    query_bounds = [0, num_input_tokens]
    dummy_attn_metadata = CommonAttentionMetadata(
        query_start_loc=torch.tensor(
            query_bounds, dtype=torch.int32, device=self.device
        ),
        query_start_loc_cpu=torch.tensor(
            query_bounds, dtype=torch.int32, device="cpu"
        ),
        seq_lens=torch.tensor(
            [num_input_tokens], dtype=torch.int32, device=self.device
        ),
        num_reqs=1,
        num_actual_tokens=num_input_tokens,
        max_query_len=num_input_tokens,
        max_seq_len=self.max_model_len,
        block_table_tensor=torch.tensor(
            [], dtype=torch.int32, device=self.device
        ),
        slot_mapping=self.arange[:num_input_tokens],
        logits_indices_padded=None,
        num_logits_indices=None,
        causal=True,
        encoder_seq_lens=None,
    )
    dummy_eagle_metadata = MultiLayerEagleMetadata.make_dummy(
        layer_num=self.layer_num,
        hidden_size=self.hidden_size,
        device=self.device,
    )

    # NOTE ensure the jit kernel in _adjust_input can be compiled
    self.adjust_input(
        batch_size=1,
        target_token_ids=dummy_token_ids,
        target_positions=dummy_positions,
        target_hidden_states=dummy_hidden_states,
        token_indices_to_sample=last_token_index,
        common_attn_metadata=dummy_attn_metadata,
        multi_layer_eagle_metadata=dummy_eagle_metadata,
    )

    for fwd_idx in range(self.layer_num):
        with set_forward_context(
            None,
            self.vllm_config,
            num_tokens=num_input_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            slot_mapping=slot_mapping_dict,
        ):
            # Token ids are withheld only when multimodal inputs are
            # supported; the embeddings slice is passed in either case.
            if self.supports_mm_inputs:
                input_ids = None
            else:
                input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = self.inputs_embeds[:num_input_tokens]

            forward_kwargs = dict(
                input_ids=input_ids,
                positions=self._get_positions(num_input_tokens),
                hidden_states=self.hidden_states[:num_input_tokens],
                inputs_embeds=inputs_embeds,
                spec_step_idx=fwd_idx,
            )
            self.model(**forward_kwargs)
def _multi_layer_eagle_shift_and_cache(
    *,
    batch_size: int,
    max_shift: int,
    src_token_ids: torch.Tensor,
    dst_token_ids: torch.Tensor,
    src_positions: torch.Tensor,
    dst_positions: torch.Tensor,
    src_hidden_states: torch.Tensor,
    dst_hidden_states: torch.Tensor,
    src_slot_mapping: torch.Tensor,
    dst_slot_mapping: torch.Tensor,
    start_token_indices: torch.Tensor,
    end_token_indices: torch.Tensor,
    token_indices_to_sample: torch.Tensor,
    shift: torch.Tensor,
    cached_lens: torch.Tensor,
    cached_prev_token_ids: torch.Tensor,
    cached_prev_positions: torch.Tensor,
    cached_prev_hidden_states: torch.Tensor,
    cached_slot_mappings: torch.Tensor,
    common_attn_metadata: CommonAttentionMetadata,
):
    """Launch the shift-and-gather kernels and refresh the caches.

    Four kernel launches are issued in order: token ids, slot mappings and
    positions through the 1D kernel, then hidden states through the hidden
    kernel. Afterwards ``cached_lens`` is overwritten in place with the
    number of entries stored per sequence. Returns ``None``; a zero
    ``batch_size`` returns immediately without touching any tensor.
    """
    if batch_size == 0:
        return

    assert max_shift > 0
    for required_contiguous in (
        cached_prev_positions,
        cached_prev_token_ids,
        cached_prev_hidden_states,
        cached_slot_mappings,
        src_hidden_states,
        dst_hidden_states,
    ):
        assert required_contiguous.is_contiguous()

    # If src/dst are the same tensor, shifting is unsafe without a separate src.
    if src_slot_mapping.data_ptr() == dst_slot_mapping.data_ptr():
        src_slot_mapping = src_slot_mapping.clone()

    # Cache extraction for the next call: keep at most max_shift entries
    # ending at the sampled token, never reaching before the window start.
    store_start = torch.maximum(
        start_token_indices,
        (token_indices_to_sample + 1 - max_shift),
    )
    store_lens = torch.clamp(
        token_indices_to_sample - store_start + 1,
        min=0,
        max=max_shift,
    )

    # Avoid device sync: query length == (end - start + 1) == diff of
    # query_start_loc (CPU copy).
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    max_window_len = int(query_lens.max().item())
    num_blocks = max(1, (max_window_len + BLOCK_TOKENS - 1) // BLOCK_TOKENS)

    padded_shift = triton.next_power_of_2(max_shift)

    # The three packed 1D arrays share one kernel and one launch grid.
    grid_1d = (batch_size, num_blocks)
    for src_1d, dst_1d, cache_1d in (
        (src_token_ids, dst_token_ids, cached_prev_token_ids),
        (src_slot_mapping, dst_slot_mapping, cached_slot_mappings),
        (src_positions, dst_positions, cached_prev_positions),
    ):
        _shift_and_gather_cache_1d_kernel[grid_1d](
            src_1d,
            dst_1d,
            cache_1d,
            start_token_indices,
            end_token_indices,
            shift,
            cached_lens,
            store_start,
            store_lens,
            MAX_SHIFT=max_shift,
            PADDED_SHIFT=padded_shift,
            BLOCK_TOKENS=BLOCK_TOKENS,
        )

    hidden_size = int(dst_hidden_states.shape[1])
    # Hidden blocking avoids extremely large Triton tiles (and huge cubins)
    # when hidden_size is large.
    num_hidden_blocks = max(1, (hidden_size + BLOCK_HIDDEN - 1) // BLOCK_HIDDEN)
    grid_hidden = (batch_size, num_blocks, num_hidden_blocks)
    _shift_and_gather_hidden_kernel[grid_hidden](
        src_hidden_states,
        dst_hidden_states,
        cached_prev_hidden_states,
        start_token_indices,
        end_token_indices,
        shift,
        cached_lens,
        store_start,
        store_lens,
        MAX_SHIFT=max_shift,
        PADDED_SHIFT=padded_shift,
        HIDDEN_SIZE=hidden_size,
        BLOCK_TOKENS=BLOCK_TOKENS,
        BLOCK_HIDDEN=BLOCK_HIDDEN,
        num_warps=4,
    )

    # Publish the new per-sequence cache lengths only after every launch
    # that consumes the previous lengths has been issued.
    cached_lens.copy_(store_lens)
@triton.jit
def _shift_and_gather_cache_1d_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Per-sequence "shift + gather" for packed 1D arrays (token ids, positions,
    # slot mappings, ...).
    #
    # We operate on a packed batch where each sequence (request) occupies a
    # contiguous window [start, end] (inclusive) in a flattened tensor.
    # For the next speculative step, we build a right-shifted version of each
    # window. The shift amount can differ per sequence.
    #
    # For a single sequence (0-based index i within its window):
    #   - Prefix (i < shift):
    #       dst[start + i] = cached[cached_len - shift + i]
    #   - Body (i >= shift):
    #       dst[start + i] = src[start + i - shift]
    #
    # The vacated prefix is filled from a small per-sequence cache (up to
    # MAX_SHIFT elements) that stores values from previous speculative steps.
    #
    # Example:
    #   cached_tail = [a3, a4]
    #   src_window  = [b0, b1, b2, b3, b4]
    #   shift       = 2
    #   -> dst_window = [a3, a4, b0, b1, b2]
    #
    # After dst is produced, we refresh cached_ptr[seq, :] with a suffix of dst
    # (specified by store_start / store_len) so the next call can populate its
    # prefix from cache.
    #
    # Grid: axis 0 selects the sequence, axis 1 selects a BLOCK_TOKENS-sized
    # tile of that sequence's window. PADDED_SHIFT is the length of the
    # tl.arange used to cover the cache row; lanes with m >= MAX_SHIFT are
    # masked off.
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)

    # Per-sequence scalars, one entry per sequence in each pointer.
    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
    # The prefix is read from cached[cached_len - shift + i], so the cache
    # must hold at least `shift` entries.
    assert cached_len >= shift

    # get dst indices
    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    offs = base + k  # 0-based offset of each lane inside the window
    dst_idx = start + offs

    # get dst mask
    window_len = end - start + 1
    mask = offs < window_len  # lanes past the window end do nothing

    # load from cached
    base_cached = cached_ptr + pid_seq * MAX_SHIFT  # this sequence's cache row
    cached_idx = cached_len - shift + offs
    cached_mask = offs < shift  # True for the prefix that comes from the cache
    val_cached = tl.load(base_cached + cached_idx, mask=mask & cached_mask, other=0)

    # load from src
    src_idx = start + offs - shift
    val_src = tl.load(src_ptr + src_idx, mask=mask & ~cached_mask, other=0)

    # store to dst
    val = tl.where(cached_mask, val_cached, val_src)
    tl.store(dst_ptr + dst_idx, val, mask=mask)

    # Store into the per-sequence cache.
    #
    # Cache layout: [batch_size, MAX_SHIFT] (flattened). We always write the
    # full MAX_SHIFT region (zero-padded when store_len < MAX_SHIFT) to keep the
    # cache contiguous.
    #
    # NOTE(review): this refresh is not guarded by pid_blk, so every tile of a
    # sequence performs it, and it reads dst entries that may be written by a
    # different tile of the same launch while other tiles may still be reading
    # the old cache row. Presumably callers only rely on it when a window fits
    # in a single tile - TODO confirm.
    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    store_mask = m < MAX_SHIFT
    dst_idx = store_start + m
    # Lanes with m >= store_len load the `other` value (0), which is what
    # produces the zero padding mentioned above.
    val = tl.load(dst_ptr + dst_idx, mask=store_mask & (m < store_len), other=0)
    tl.store(base_cached + m, val, mask=store_mask)
@triton.jit
def _shift_and_gather_hidden_kernel(
    src_ptr,
    dst_ptr,
    cached_ptr,
    start_ptr,
    end_ptr,
    shift_ptr,
    cached_len_ptr,
    store_start_ptr,
    store_len_ptr,
    MAX_SHIFT: tl.constexpr,
    PADDED_SHIFT: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
    BLOCK_HIDDEN: tl.constexpr,
):
    # Per-sequence "shift + gather" for hidden states.
    #
    # This kernel implements the same logical transformation as
    # _shift_and_gather_cache_1d_kernel, but operates on hidden states with
    # shape [num_tokens, hidden_size].
    #
    # Layout:
    #   - src_ptr / dst_ptr: packed hidden states [num_tokens, hidden_size]
    #   - cached_ptr: per-sequence cache [batch_size, MAX_SHIFT, hidden_size]
    #
    # For each sequence window [start, end] (inclusive) and its shift value, for
    # 0-based index i within the window:
    #   - Prefix (i < shift):
    #       dst[start + i, :] = cached[seq, cached_len - shift + i, :]
    #   - Body (i >= shift):
    #       dst[start + i, :] = src[start + i - shift, :]
    #
    # We tile over tokens (BLOCK_TOKENS) and hidden dim (BLOCK_HIDDEN) to avoid
    # extremely large Triton tiles when hidden_size is large. As in the 1D
    # kernel, we refresh cached_ptr[seq, :, :] with a suffix of dst so the next
    # call can populate its prefix from cache.
    #
    # Grid: axis 0 selects the sequence, axis 1 the token tile, axis 2 the
    # hidden-dimension tile. All row addressing uses HIDDEN_SIZE as the row
    # stride and 1 as the column stride.
    pid_seq = tl.program_id(0)
    pid_blk = tl.program_id(1)
    pid_hid = tl.program_id(2)

    # Per-sequence scalars, one entry per sequence in each pointer.
    start = tl.load(start_ptr + pid_seq).to(tl.int32)
    end = tl.load(end_ptr + pid_seq).to(tl.int32)
    shift = tl.load(shift_ptr + pid_seq).to(tl.int32)
    cached_len = tl.load(cached_len_ptr + pid_seq).to(tl.int32)
    # The prefix is read from cache row (cached_len - shift + i), so the cache
    # must hold at least `shift` rows.
    assert cached_len >= shift

    # get dst indices
    base = pid_blk * BLOCK_TOKENS
    k = tl.arange(0, BLOCK_TOKENS)
    tok_offs = base + k  # 0-based token offset of each lane inside the window
    dst_tok = start + tok_offs
    n = pid_hid * BLOCK_HIDDEN + tl.arange(0, BLOCK_HIDDEN)  # hidden columns
    dst_ptrs = dst_ptr + dst_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1

    # get dst mask
    window_len = end - start + 1
    tok_mask = tok_offs < window_len  # tokens past the window end do nothing
    n_mask = n < HIDDEN_SIZE  # guards the last, possibly partial, hidden tile
    mask = tok_mask[:, None] & n_mask[None, :]

    # load from cached
    base_cached = cached_ptr + pid_seq * HIDDEN_SIZE * MAX_SHIFT
    cached_tok = cached_len - shift + tok_offs
    cached_ptrs = base_cached + cached_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    cached_mask = tok_offs < shift  # True for the prefix that comes from cache
    val_cached = tl.load(cached_ptrs, mask=mask & cached_mask[:, None], other=0)

    # load from src
    src_tok = start + tok_offs - shift
    src_ptrs = src_ptr + src_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    val_src = tl.load(src_ptrs, mask=mask & ~cached_mask[:, None], other=0)

    # store to dst
    val = tl.where(cached_mask[:, None], val_cached, val_src)
    tl.store(dst_ptrs, val, mask=mask)

    # store to cached
    #
    # Rows with m >= store_len are masked out of both the load and the store,
    # so those cache rows keep whatever they held before.
    #
    # NOTE(review): this refresh is not guarded by pid_blk, so every token tile
    # of a sequence performs it, and it reads dst rows that may be written by a
    # different token tile of the same launch while other tiles may still be
    # reading the old cache rows. Presumably callers only rely on it when a
    # window fits in a single token tile - TODO confirm.
    store_start = tl.load(store_start_ptr + pid_seq).to(tl.int32)
    store_len = tl.load(store_len_ptr + pid_seq).to(tl.int32)
    m = tl.arange(0, PADDED_SHIFT)
    m_mask = (m < MAX_SHIFT) & (m < store_len)
    store_tok = store_start + m
    dst_ptrs = dst_ptr + store_tok[:, None] * HIDDEN_SIZE + n[None, :] * 1
    store_ptrs = base_cached + m[:, None] * HIDDEN_SIZE + n[None, :] * 1
    mask = m_mask[:, None] & n_mask[None, :]
    val = tl.load(dst_ptrs, mask=mask, other=0)
    tl.store(store_ptrs, val, mask=mask)
vllm/v1/worker/gpu_input_batch.py
View file @
84ac1d27
...
...
@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params
:
PoolingParams
|
None
=
None
pooling_states
:
PoolingStates
|
None
=
None
# for multi layer eagle proposer
cached_len
:
torch
.
Tensor
|
None
=
None
cached_token_ids
:
torch
.
Tensor
|
None
=
None
cached_hidden_states
:
torch
.
Tensor
|
None
=
None
cached_slot_mappings
:
torch
.
Tensor
|
None
=
None
cached_positions
:
torch
.
Tensor
|
None
=
None
def
__post_init__
(
self
):
self
.
num_prompt_tokens
=
length_from_prompt_token_ids_or_embeds
(
self
.
prompt_token_ids
,
self
.
prompt_embeds
...
...
@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode
:
bool
=
False
,
is_pooling_model
:
bool
=
False
,
cp_kv_cache_interleave_size
:
int
=
1
,
multi_layer_eagle_num
:
int
=
0
,
hidden_size
:
int
|
None
=
None
,
):
self
.
is_pooling_model
=
is_pooling_model
self
.
is_spec_decode
=
is_spec_decode
...
...
@@ -223,6 +232,46 @@ class InputBatch:
)
self
.
num_accepted_tokens_cpu
=
self
.
num_accepted_tokens_cpu_tensor
.
numpy
()
# Multi layer eagle
self
.
multi_layer_eagle_num
=
multi_layer_eagle_num
if
multi_layer_eagle_num
>
0
:
self
.
cached_len
=
torch
.
zeros
(
(
max_num_reqs
,),
dtype
=
torch
.
int64
,
device
=
device
)
self
.
cached_token_ids
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
cached_hidden_states
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
hidden_size
,
),
dtype
=
torch
.
float
,
device
=
device
,
)
self
.
cached_slot_mappings
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
cached_positions
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int64
,
device
=
device
,
)
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,),
dtype
=
np
.
int64
)
self
.
lora_id_to_request_ids
:
dict
[
int
,
set
[
str
]]
=
{}
...
...
@@ -437,6 +486,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self
.
num_accepted_tokens_cpu
[
req_index
]
=
1
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
req_index
]
=
request
.
cached_len
self
.
cached_token_ids
[
req_index
]
=
request
.
cached_token_ids
self
.
cached_hidden_states
[
req_index
]
=
request
.
cached_hidden_states
self
.
cached_slot_mappings
[
req_index
]
=
request
.
cached_slot_mappings
self
.
cached_positions
[
req_index
]
=
request
.
cached_positions
# Add request lora ID
if
request
.
lora_request
:
lora_id
=
request
.
lora_request
.
lora_int_id
...
...
@@ -632,6 +688,20 @@ class InputBatch:
self
.
num_accepted_tokens_cpu
[
i1
],
)
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
i1
],
self
.
cached_len
[
i2
]
=
(
self
.
cached_len
[
i2
],
self
.
cached_len
[
i1
],
)
self
.
cached_token_ids
[[
i1
,
i2
],
...]
=
self
.
cached_token_ids
[[
i2
,
i1
],
...]
self
.
cached_hidden_states
[[
i1
,
i2
],
...]
=
self
.
cached_hidden_states
[
[
i2
,
i1
],
...
]
self
.
cached_slot_mappings
[[
i1
,
i2
],
...]
=
self
.
cached_slot_mappings
[
[
i2
,
i1
],
...
]
self
.
cached_positions
[[
i1
,
i2
],
...]
=
self
.
cached_positions
[[
i2
,
i1
],
...]
swap_dict_values
(
self
.
generators
,
i1
,
i2
)
swap_dict_values
(
self
.
bad_words_token_ids
,
i1
,
i2
)
...
...
@@ -769,6 +839,23 @@ class InputBatch:
if
bad_words_token_ids
is
not
None
:
self
.
bad_words_token_ids
[
empty_index
]
=
bad_words_token_ids
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
empty_index
]
=
self
.
cached_len
[
last_req_index
]
self
.
cached_token_ids
[
empty_index
]
=
self
.
cached_token_ids
[
last_req_index
]
self
.
cached_hidden_states
[
empty_index
]
=
self
.
cached_hidden_states
[
last_req_index
]
self
.
cached_slot_mappings
[
empty_index
]
=
self
.
cached_slot_mappings
[
last_req_index
]
self
.
cached_positions
[
empty_index
]
=
self
.
cached_positions
[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
84ac1d27
...
...
@@ -164,7 +164,8 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
MultiLayerEagleMetadata
,
SpecDecodeMetadata
from
vllm.v1.spec_decode.multi_layer_eagle
import
MultiLayerEagleProposer
from
vllm.v1.spec_decode.ngram_proposer_gpu
import
(
NgramProposerGPU
,
copy_num_valid_draft_tokens
,
...
...
@@ -374,6 +375,7 @@ class ExecuteModelState(NamedTuple):
scheduler_output
:
"SchedulerOutput"
logits
:
torch
.
Tensor
spec_decode_metadata
:
SpecDecodeMetadata
|
None
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
spec_decode_common_attn_metadata
:
CommonAttentionMetadata
|
None
hidden_states
:
torch
.
Tensor
sample_hidden_states
:
torch
.
Tensor
...
...
@@ -500,6 +502,11 @@ class GPUModelRunner(
self
.
late_interaction_runner
=
LateInteractionRunner
()
self
.
use_aux_hidden_state_outputs
=
False
# multi layer eagle
self
.
enable_multi_layer_eagle
=
False
self
.
multi_layer_eagle_num
=
0
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
...
...
@@ -544,7 +551,17 @@ class GPUModelRunner(
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
if
(
self
.
speculative_config
.
enable_multi_layers_mtp
and
self
.
speculative_config
.
method
==
"mtp"
):
self
.
enable_multi_layer_eagle
=
True
self
.
drafter
=
MultiLayerEagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
self
.
multi_layer_eagle_num
=
self
.
drafter
.
layer_num
else
:
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
if
self
.
speculative_config
.
method
==
"eagle3"
:
self
.
use_aux_hidden_state_outputs
=
(
self
.
drafter
.
eagle3_use_aux_hidden_state
...
...
@@ -623,6 +640,10 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids
=
bool
(
custom_logitsprocs
),
is_pooling_model
=
self
.
is_pooling_model
,
cp_kv_cache_interleave_size
=
self
.
parallel_config
.
cp_kv_cache_interleave_size
,
multi_layer_eagle_num
=
self
.
multi_layer_eagle_num
if
self
.
enable_multi_layer_eagle
else
0
,
hidden_size
=
self
.
model_config
.
get_hidden_size
(),
)
# Separate cuda stream for overlapping transfer of sampled token ids from
...
...
@@ -1143,6 +1164,9 @@ class GPUModelRunner(
if
self
.
uses_xdrope_dim
>
0
:
self
.
_init_xdrope_positions
(
req_state
)
if
self
.
enable_multi_layer_eagle
:
self
.
_init_multi_layer_eagle_cache
(
req_state
)
reqs_to_add
.
append
(
req_state
)
# Track new requests for ngram_gpu full tensor copy
if
is_ngram_gpu
:
...
...
@@ -1442,6 +1466,24 @@ class GPUModelRunner(
req_state
.
mm_features
,
)
def _init_multi_layer_eagle_cache(self, req_state: CachedRequestState):
    """Attach zero-filled multi-layer eagle cache tensors to a request.

    Sets ``cached_len`` (one int64 element), ``cached_hidden_states``
    (``multi_layer_eagle_num`` rows of the model hidden size, in
    ``self.dtype``), ``cached_token_ids`` (int32), and ``cached_positions``
    and ``cached_slot_mappings`` (both int64) on ``req_state``. Every tensor
    is allocated on ``self.device``.
    """
    device = self.device
    num_cached = self.multi_layer_eagle_num

    req_state.cached_len = torch.zeros((1,), dtype=torch.int64, device=device)
    req_state.cached_hidden_states = torch.zeros(
        (num_cached, self.model_config.get_hidden_size()),
        dtype=self.dtype,
        device=device,
    )
    req_state.cached_token_ids = torch.zeros(
        (num_cached,), dtype=torch.int32, device=device
    )
    # Positions and slot mappings share the same shape and dtype.
    req_state.cached_positions = torch.zeros(
        (num_cached,), dtype=torch.int64, device=device
    )
    req_state.cached_slot_mappings = torch.zeros(
        (num_cached,), dtype=torch.int64, device=device
    )
def
_extract_mm_kwargs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -1672,10 +1714,11 @@ class GPUModelRunner(
)
->
tuple
[
torch
.
Tensor
,
SpecDecodeMetadata
|
None
,
MultiLayerEagleMetadata
|
None
,
]:
"""
:return: tuple[
logits_indices, spec_decode_metadata,
logits_indices, spec_decode_metadata,
multi_layer_eagle_metadata,
]
"""
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
...
...
@@ -1879,9 +1922,21 @@ class GPUModelRunner(
self
.
input_batch
,
num_scheduled_tokens
,
num_sampled_tokens
)
if
self
.
enable_multi_layer_eagle
:
multi_layer_eagle_metadata
=
MultiLayerEagleMetadata
(
cached_len
=
self
.
input_batch
.
cached_len
[:
num_reqs
],
cached_token_ids
=
self
.
input_batch
.
cached_token_ids
[:
num_reqs
],
cached_hidden_states
=
self
.
input_batch
.
cached_hidden_states
[:
num_reqs
],
cached_slot_mappings
=
self
.
input_batch
.
cached_slot_mappings
[:
num_reqs
],
cached_positions
=
self
.
input_batch
.
cached_positions
[:
num_reqs
],
)
else
:
multi_layer_eagle_metadata
=
None
return
(
logits_indices
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
)
def
_build_attention_metadata
(
...
...
@@ -3634,9 +3689,11 @@ class GPUModelRunner(
max_num_scheduled_tokens
=
int
(
num_scheduled_tokens_np
.
max
())
num_tokens_unpadded
=
scheduler_output
.
total_num_scheduled_tokens
logits_indices
,
spec_decode_metadata
=
self
.
_prepare_inputs
(
scheduler_output
,
num_scheduled_tokens_np
,
logits_indices
,
spec_decode_metadata
,
multi_layer_eagle_metadata
=
(
self
.
_prepare_inputs
(
scheduler_output
,
num_scheduled_tokens_np
,
)
)
cascade_attn_prefix_lens
=
None
...
...
@@ -3867,6 +3924,7 @@ class GPUModelRunner(
scheduler_output
,
logits
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
hidden_states
,
sample_hidden_states
,
...
...
@@ -3905,6 +3963,7 @@ class GPUModelRunner(
scheduler_output
,
logits
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
hidden_states
,
sample_hidden_states
,
...
...
@@ -3953,6 +4012,7 @@ class GPUModelRunner(
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
slot_mappings
,
)
...
...
@@ -4242,6 +4302,7 @@ class GPUModelRunner(
sample_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
spec_decode_metadata
:
SpecDecodeMetadata
|
None
,
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
,
common_attn_metadata
:
CommonAttentionMetadata
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
,
)
->
list
[
list
[
int
]]
|
torch
.
Tensor
:
...
...
@@ -4466,6 +4527,7 @@ class GPUModelRunner(
token_indices_to_sample
=
token_indices_to_sample
,
sampling_metadata
=
sampling_metadata
,
common_attn_metadata
=
common_attn_metadata
,
multi_layer_eagle_metadata
=
multi_layer_eagle_metadata
,
mm_embed_inputs
=
mm_embed_inputs
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
slot_mappings
=
slot_mappings
,
...
...
@@ -6216,6 +6278,10 @@ class GPUModelRunner(
logitsprocs
=
self
.
input_batch
.
logitsprocs
,
logitsprocs_need_output_token_ids
=
self
.
input_batch
.
logitsprocs_need_output_token_ids
,
is_pooling_model
=
self
.
is_pooling_model
,
multi_layer_eagle_num
=
self
.
multi_layer_eagle_num
if
self
.
enable_multi_layer_eagle
else
0
,
hidden_size
=
self
.
model_config
.
get_hidden_size
(),
)
assert
self
.
_init_block_sizes
==
block_sizes
,
(
...
...
vllm/version.py
View file @
84ac1d27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try
:
from
._version
import
__version__
,
__version_tuple__
__version__
=
"0.18.1"
__version_tuple__
=
(
0
,
18
,
1
)
__hcu_version__
=
f
'0.18.1+das.ori.dtk2604'
from
vllm.version
import
__version__
,
__version_tuple__
,
__hcu_version__
except
Exception
as
e
:
import
warnings
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
RuntimeWarning
,
stacklevel
=
2
)
warnings
.
warn
(
f
"Failed to read commit hash:
\n
+ str(e)"
,
RuntimeWarning
,
stacklevel
=
2
)
__version__
=
"dev"
__version_tuple__
=
(
0
,
0
,
__version__
)
def
_prev_minor_version_was
(
version_str
):
"""
Check whether a given version matches the previous minor version.
'''
Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
...
...
@@ -21,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if
__version_tuple__
[
0
:
2
]
==
(
0
,
0
):
return
True
# Note - this won't do the right thing when we release 1.0!
assert
__version_tuple__
[
0
]
==
0
#
assert __version_tuple__[0] == 0
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
def
_prev_minor_version
():
"""
For the purpose of testing, return a previous minor version number.
"""
'''
For the purpose of testing, return a previous minor version number.
'''
# In dev tree, this will return "0.-1", but that will work fine"
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment