Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
84ac1d27
Commit
84ac1d27
authored
May 09, 2026
by
zhangning3
Browse files
add step3p5 mtp3
parent
b27f1671
Changes
19
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
3486 additions
and
1495 deletions
+3486
-1495
examples/offline_inference/spec_decode.py
examples/offline_inference/spec_decode.py
+14
-1
step3p5_benchmark_test.sh
step3p5_benchmark_test.sh
+242
-0
tests/tool_parsers/test_step3p5_tool_parser.py
tests/tool_parsers/test_step3p5_tool_parser.py
+166
-68
tests/v1/spec_decode/test_mtp3.py
tests/v1/spec_decode/test_mtp3.py
+961
-0
vllm/config/speculative.py
vllm/config/speculative.py
+8
-1
vllm/entrypoints/anthropic/serving.py
vllm/entrypoints/anthropic/serving.py
+122
-31
vllm/entrypoints/openai/chat_completion/serving.py
vllm/entrypoints/openai/chat_completion/serving.py
+6
-3
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+1
-0
vllm/model_executor/models/step3p5_mtp.py
vllm/model_executor/models/step3p5_mtp.py
+13
-7
vllm/tool_parsers/abstract_tool_parser.py
vllm/tool_parsers/abstract_tool_parser.py
+6
-0
vllm/tool_parsers/step3p5_tool_parser.py
vllm/tool_parsers/step3p5_tool_parser.py
+994
-1351
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+1
-1
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+32
-14
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+2
-1
vllm/v1/spec_decode/metadata.py
vllm/v1/spec_decode/metadata.py
+44
-0
vllm/v1/spec_decode/multi_layer_eagle.py
vllm/v1/spec_decode/multi_layer_eagle.py
+701
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+87
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+72
-6
vllm/version.py
vllm/version.py
+14
-11
No files found.
examples/offline_inference/spec_decode.py
View file @
84ac1d27
...
...
@@ -57,6 +57,11 @@ def parse_args():
parser
.
add_argument
(
"--tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--enforce-eager"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--enable-chunked-prefill"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--enable-multi-layers-mtp"
,
action
=
"store_true"
,
help
=
"Enable multi-layer MTP (only effective when --method=mtp)."
,
)
parser
.
add_argument
(
"--max-model-len"
,
type
=
int
,
default
=
16384
)
parser
.
add_argument
(
"--temp"
,
type
=
float
,
default
=
0
)
parser
.
add_argument
(
"--top-p"
,
type
=
float
,
default
=
1.0
)
...
...
@@ -66,12 +71,14 @@ def parse_args():
parser
.
add_argument
(
"--model-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--eagle-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--draft-model"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--tokenizer-dir"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--custom-mm-prompts"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--gpu-memory-utilization"
,
type
=
float
,
default
=
0.9
)
parser
.
add_argument
(
"--disable-padded-drafter-batch"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--max-num-seqs"
,
type
=
int
,
default
=
None
)
parser
.
add_argument
(
"--parallel-drafting"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--allowed-local-media-path"
,
type
=
str
,
default
=
""
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
return
parser
.
parse_args
()
...
...
@@ -85,7 +92,11 @@ def main(args):
"please specify model_dir to give a mm based model"
)
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_dir
)
tokenizer_dir
=
args
.
tokenizer_dir
if
tokenizer_dir
is
None
:
tokenizer_dir
=
model_dir
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer_dir
)
if
args
.
custom_mm_prompts
:
prompts
=
llm_prompts
=
get_custom_mm_prompts
(
args
.
num_prompts
)
...
...
@@ -141,6 +152,8 @@ def main(args):
"method"
:
"mtp"
,
"num_speculative_tokens"
:
args
.
num_spec_tokens
,
}
if
args
.
enable_multi_layers_mtp
:
speculative_config
[
"enable_multi_layers_mtp"
]
=
True
else
:
raise
ValueError
(
f
"unknown method:
{
args
.
method
}
"
)
...
...
step3p5_benchmark_test.sh
0 → 100644
View file @
84ac1d27
#!/usr/bin/env bash
set
-u
DEFAULT_BATCH_SIZES
=(
1 8 16 32 64 128
)
MODEL_PATH
=
"/module/step3.5-fp8/"
SERVED_MODEL_NAME
=
"/module/step3.5-fp8/"
DATASET_NAME
=
"random"
DEFAULT_OUTPUT_LEN_DECODE
=
4096
DEFAULT_OUTPUT_LEN_PREFILL
=
1
DEFAULT_ROLE
=
"decode"
READY_CHECK_TIMEOUT
=
3
RESULT_DIR
=
"benchmark_result"
print_usage
()
{
cat
<<
'
USAGE
'
Usage:
./scripts/step3p5_benchmark_test.sh
./scripts/step3p5_benchmark_test.sh 1,8,16,32
./scripts/step3p5_benchmark_test.sh --role prefill
./scripts/step3p5_benchmark_test.sh --role decode
./scripts/step3p5_benchmark_test.sh --role both
./scripts/step3p5_benchmark_test.sh --role prefill --output-len 1
Description:
- No argument: use default batch sizes, role=decode, output-len=4096
- Optional positional argument: batch size list (comma or space separated)
- Optional flag: --role <prefill|decode|both>
- Optional flag: --output-len <N> (must be positive integer)
- role=both 时串行执行 prefill 再 decode
- Result files are saved under:
<result_dir>/prefill (when role=prefill)
<result_dir>/decode (when role=decode)
USAGE
}
parse_batch_sizes
()
{
local
raw_input
=
"
${
1
:-}
"
if
[[
-z
"
$raw_input
"
]]
;
then
BATCH_SIZES
=(
"
${
DEFAULT_BATCH_SIZES
[@]
}
"
)
return
fi
raw_input
=
"
${
raw_input
//,/
}
"
read
-r
-a
BATCH_SIZES
<<<
"
$raw_input
"
if
((
${#
BATCH_SIZES
[@]
}
==
0
))
;
then
echo
"[ERROR] batch size 列表为空。"
print_usage
exit
1
fi
local
batch_size
for
batch_size
in
"
${
BATCH_SIZES
[@]
}
"
;
do
if
!
[[
"
$batch_size
"
=
~ ^[1-9][0-9]
*
$
]]
;
then
echo
"[ERROR] 非法 batch size:
$batch_size
(必须是正整数)"
exit
1
fi
done
}
parse_role
()
{
local
role_input
=
"
${
1
:-
$DEFAULT_ROLE
}
"
if
[[
"
$role_input
"
!=
"prefill"
&&
"
$role_input
"
!=
"decode"
&&
"
$role_input
"
!=
"both"
]]
;
then
echo
"[ERROR] 非法 role:
$role_input
(必须是 prefill、decode 或 both)"
print_usage
exit
1
fi
ROLE
=
"
$role_input
"
}
parse_output_len
()
{
local
output_len_input
=
"
$1
"
if
!
[[
"
$output_len_input
"
=
~ ^[1-9][0-9]
*
$
]]
;
then
echo
"[ERROR] 非法 output-len:
$output_len_input
(必须是正整数)"
print_usage
exit
1
fi
RANDOM_OUTPUT_LEN
=
"
$output_len_input
"
}
calc_num_prompts
()
{
local
batch_size
=
"
$1
"
local
num_prompts
=
$((
batch_size
+
batch_size
/
2
))
if
((
num_prompts < 16
))
;
then
num_prompts
=
16
fi
if
((
num_prompts
>
384
))
;
then
num_prompts
=
384
fi
if
((
num_prompts < batch_size
))
;
then
num_prompts
=
$batch_size
fi
echo
"
$num_prompts
"
}
main
()
{
local
batch_arg
=
""
local
role_arg
=
"
$DEFAULT_ROLE
"
local
output_len_arg
=
""
while
((
$#
>
0
))
;
do
case
"
$1
"
in
-h
|
--help
)
print_usage
exit
0
;;
--role
|
-r
)
if
[[
-z
"
${
2
:-}
"
]]
;
then
echo
"[ERROR] --role 缺少参数。"
print_usage
exit
1
fi
role_arg
=
"
$2
"
shift
2
;;
--output-len
|
-o
)
if
[[
-z
"
${
2
:-}
"
]]
;
then
echo
"[ERROR] --output-len 缺少参数。"
print_usage
exit
1
fi
output_len_arg
=
"
$2
"
shift
2
;;
--
*
)
echo
"[ERROR] 未知参数:
$1
"
print_usage
exit
1
;;
*
)
if
[[
-n
"
$batch_arg
"
]]
;
then
echo
"[ERROR] 仅支持一个 batch size 列表参数。"
print_usage
exit
1
fi
batch_arg
=
"
$1
"
shift
;;
esac
done
parse_batch_sizes
"
$batch_arg
"
parse_role
"
$role_arg
"
if
[[
"
$ROLE
"
==
"both"
]]
;
then
local
-a
prefill_cmd
=(
"
$0
"
)
local
-a
decode_cmd
=(
"
$0
"
)
if
[[
-n
"
$batch_arg
"
]]
;
then
prefill_cmd+
=(
"
$batch_arg
"
)
decode_cmd+
=(
"
$batch_arg
"
)
fi
prefill_cmd+
=(
"--role"
"prefill"
)
decode_cmd+
=(
"--role"
"decode"
)
if
[[
-n
"
$output_len_arg
"
]]
;
then
decode_cmd+
=(
"--output-len"
"
$output_len_arg
"
)
fi
echo
"[INFO] role=both: 将串行执行 prefill 和 decode"
echo
"[INFO] step1:
${
prefill_cmd
[*]
}
"
"
${
prefill_cmd
[@]
}
"
echo
"[INFO] step2:
${
decode_cmd
[*]
}
"
"
${
decode_cmd
[@]
}
"
echo
"[INFO] role=both 执行完成。"
return
0
fi
if
[[
"
$ROLE
"
==
"prefill"
]]
;
then
if
[[
-n
"
$output_len_arg
"
&&
"
$output_len_arg
"
!=
"
$DEFAULT_OUTPUT_LEN_PREFILL
"
]]
;
then
echo
"[WARN] role=prefill 时 output-len 必须为 1,已自动覆盖为 1。"
fi
output_len_arg
=
"
$DEFAULT_OUTPUT_LEN_PREFILL
"
elif
[[
-z
"
$output_len_arg
"
]]
;
then
output_len_arg
=
"
$DEFAULT_OUTPUT_LEN_DECODE
"
fi
parse_output_len
"
$output_len_arg
"
RESULT_SUBDIR
=
"
$RESULT_DIR
/
$ROLE
"
mkdir
-p
"
$RESULT_SUBDIR
"
echo
"[INFO] 将执行
${#
BATCH_SIZES
[@]
}
组 benchmark"
echo
"[INFO] role:
$ROLE
"
echo
"[INFO] random-output-len:
$RANDOM_OUTPUT_LEN
"
echo
"[INFO] batch size 列表:
${
BATCH_SIZES
[*]
}
"
echo
"[INFO] result_dir:
$RESULT_SUBDIR
"
local
batch_size
local
num_prompts
local
failed_count
=
0
for
batch_size
in
"
${
BATCH_SIZES
[@]
}
"
;
do
num_prompts
=
"
$(
calc_num_prompts
"
$batch_size
"
)
"
echo
""
echo
"[INFO] 开始测试: role=
$ROLE
, batch_size=
$batch_size
, max_concurrency=
$batch_size
, num_prompts=
$num_prompts
, random_output_len=
$RANDOM_OUTPUT_LEN
"
if
!
vllm bench serve
\
--backend
vllm
\
--model
"
$MODEL_PATH
"
\
--served-model-name
"
$SERVED_MODEL_NAME
"
\
--dataset-name
"
$DATASET_NAME
"
\
--random-input-len
65536
\
--random-output-len
"
$RANDOM_OUTPUT_LEN
"
\
--num-prompts
"
$num_prompts
"
\
--temperature
0
\
--max-concurrency
"
$batch_size
"
\
--ready-check-timeout
"
$READY_CHECK_TIMEOUT
"
\
--result-dir
"
$RESULT_SUBDIR
"
\
--port
8018
\
--save-result
;
then
echo
"[WARN] batch_size=
$batch_size
执行失败,继续下一个。"
failed_count
=
$((
failed_count
+
1
))
else
echo
"[INFO] batch_size=
$batch_size
执行完成。"
fi
done
echo
""
if
((
failed_count
>
0
))
;
then
echo
"[WARN] 全部执行结束,但有
$failed_count
组失败。"
exit
1
fi
echo
"[INFO] 全部 benchmark 执行完成。"
}
main
"
$@
"
tests/tool_parsers/test_step3p5_tool_parser.py
View file @
84ac1d27
...
...
@@ -386,6 +386,57 @@ TX
assert
extracted_tool_calls
.
tool_calls
[
0
].
function
.
name
==
"get_current_weather"
def
test_extract_tool_calls_missing_function_name
(
step3p5_tool_parser
,
sample_tools
):
"""Tool call without function name should fallback to content."""
model_output
=
(
"<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
extracted_tool_calls
=
step3p5_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
request
)
assert
not
extracted_tool_calls
.
tools_called
assert
extracted_tool_calls
.
tool_calls
==
[]
assert
extracted_tool_calls
.
content
==
model_output
def
test_extract_tool_calls_empty_function_name
(
step3p5_tool_parser
,
sample_tools
):
"""Tool call with empty function name should fallback to content."""
model_output
=
(
"<tool_call><function=><parameter=pattern>*.py</parameter>"
"</function></tool_call>"
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
extracted_tool_calls
=
step3p5_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
request
)
assert
not
extracted_tool_calls
.
tools_called
assert
extracted_tool_calls
.
tool_calls
==
[]
assert
extracted_tool_calls
.
content
==
model_output
def
test_extract_tool_calls_empty_function_name_single_quotes
(
step3p5_tool_parser
,
sample_tools
):
"""Tool call with empty function name (single quotes) should fallback."""
model_output
=
(
"<tool_call><function=''><parameter=pattern>*.py</parameter></tool_call>"
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
extracted_tool_calls
=
step3p5_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
request
)
assert
not
extracted_tool_calls
.
tools_called
assert
extracted_tool_calls
.
tool_calls
==
[]
assert
extracted_tool_calls
.
content
==
model_output
def
test_extract_tool_calls_type_conversion
(
step3p5_tool_parser
):
"""Test parameter type conversion based on tool schema"""
tools
=
[
...
...
@@ -623,7 +674,7 @@ def test_extract_tool_calls_streaming(
expected_tool_calls
,
expected_content
,
):
"""Test
incremental streaming behavior including typed parameters
"""
"""Test
streaming returns complete tool calls when parsed.
"""
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
other_content
=
""
...
...
@@ -647,11 +698,10 @@ def test_extract_tool_calls_streaming(
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
# First chunk should have id, name, and type
if
tool_call
.
id
:
tool_states
[
idx
][
"id"
]
=
tool_call
.
id
...
...
@@ -666,8 +716,15 @@ def test_extract_tool_calls_streaming(
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
# Accumulate arguments incrementally
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
# Arguments should be complete JSON when emitted.
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Verify final content
assert
other_content
==
(
expected_content
or
""
)
# Handle None case
...
...
@@ -682,7 +739,7 @@ def test_extract_tool_calls_streaming(
assert
state
[
"type"
]
==
"function"
assert
state
[
"name"
]
==
expected_tool
.
function
.
name
# Parse
accumulated
arguments
# Parse arguments
arguments_str
=
state
[
"arguments"
]
assert
arguments_str
is
not
None
actual_args
=
json
.
loads
(
arguments_str
)
...
...
@@ -770,7 +827,7 @@ fahrenheit
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
...
...
@@ -786,7 +843,14 @@ fahrenheit
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Verify content was streamed
assert
"Let me check the weather for you:"
in
other_content
...
...
@@ -806,62 +870,69 @@ fahrenheit
assert
args
[
"unit"
]
==
"fahrenheit"
def
test_extract_tool_calls_streaming_
incremental
(
def
test_extract_tool_calls_streaming_
missing_function_name
(
step3p5_tool_parser
,
step3p5_tokenizer
,
sample_tools
):
"""Test that streaming is truly incremental"""
model_output
=
"""I'll check the weather.<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
</function>
</tool_call>"""
"""Streaming: missing function name should be treated as content."""
model_output
=
(
"<tool_call><parameter=pattern>*.py</parameter></function></tool_call>"
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
chunks
=
[]
for
delta_message
in
stream_delta_message_generator
(
step3p5_tool_parser
,
step3p5_tokenizer
,
model_output
,
request
other_content
=
""
tool_calls
=
[]
for
delta_message
in
stream_delta_message_generator_from_chunks
(
step3p5_tool_parser
,
step3p5_tokenizer
,
[
"<tool_call><parameter=pattern>"
,
"*.py</parameter>"
,
"</function></tool_call>"
,
],
request
,
):
if
delta_message
.
content
:
other_content
+=
delta_message
.
content
if
delta_message
.
tool_calls
:
tool_calls
.
extend
(
delta_message
.
tool_calls
)
assert
other_content
==
model_output
assert
tool_calls
==
[]
def
test_extract_tool_calls_streaming_empty_function_name
(
step3p5_tool_parser
,
step3p5_tokenizer
,
sample_tools
):
"""Streaming: empty function name should be treated as content."""
model_output
=
(
"<tool_call><function=''><parameter=pattern>*.py</parameter>"
"</function></tool_call>"
)
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
sample_tools
)
other_content
=
""
tool_calls
=
[]
for
delta_message
in
stream_delta_message_generator_from_chunks
(
step3p5_tool_parser
,
step3p5_tokenizer
,
[
"<tool_call><function="
,
"''><parameter=pattern>*.py</parameter>"
,
"</function></tool_call>"
,
],
request
,
):
chunks
.
append
(
delta_message
)
# Should have multiple chunks
assert
len
(
chunks
)
>
3
# First chunk(s) should be content
assert
chunks
[
0
].
content
is
not
None
assert
chunks
[
0
].
tool_calls
is
None
or
chunks
[
0
].
tool_calls
==
[]
# Should have a chunk with tool header (id, name, type)
header_found
=
False
for
chunk
in
chunks
:
if
chunk
.
tool_calls
and
chunk
.
tool_calls
[
0
].
id
:
header_found
=
True
assert
chunk
.
tool_calls
[
0
].
function
.
name
==
"get_current_weather"
assert
chunk
.
tool_calls
[
0
].
type
==
"function"
# Empty initially
assert
chunk
.
tool_calls
[
0
].
function
.
arguments
==
""
break
assert
header_found
# Should have chunks with incremental arguments
arg_chunks
=
[]
for
chunk
in
chunks
:
if
chunk
.
tool_calls
and
chunk
.
tool_calls
[
0
].
function
.
arguments
:
arg_chunks
.
append
(
chunk
.
tool_calls
[
0
].
function
.
arguments
)
# Arguments should be streamed incrementally
assert
len
(
arg_chunks
)
>
1
# Concatenated arguments should form valid JSON
full_args
=
""
.
join
(
arg_chunks
)
parsed_args
=
json
.
loads
(
full_args
)
assert
parsed_args
[
"city"
]
==
"Dallas"
assert
parsed_args
[
"state"
]
==
"TX"
if
delta_message
.
content
:
other_content
+=
delta_message
.
content
if
delta_message
.
tool_calls
:
tool_calls
.
extend
(
delta_message
.
tool_calls
)
assert
other_content
==
model_output
assert
tool_calls
==
[]
def
test_extract_tool_calls_complex_type_with_single_quote
(
step3p5_tool_parser
):
...
...
@@ -951,7 +1022,7 @@ rectangle
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
...
...
@@ -967,7 +1038,14 @@ rectangle
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Should have exactly two complete tool calls
assert
len
(
tool_states
)
==
2
,
"Should have exactly two complete tool calls"
...
...
@@ -1164,7 +1242,7 @@ rectangle
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
if
tool_call
.
id
:
...
...
@@ -1175,7 +1253,14 @@ rectangle
if
tool_call
.
function
.
name
:
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Should have exactly two complete tool calls
assert
len
(
tool_states
)
==
2
,
"Should have exactly two complete tool calls"
...
...
@@ -1266,7 +1351,7 @@ rectangle
tool_states
[
idx
]
=
{
"id"
:
None
,
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
"type"
:
None
,
}
...
...
@@ -1282,7 +1367,14 @@ rectangle
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
# Should have exactly two complete tool calls
assert
len
(
tool_states
)
==
2
,
"Should have exactly two complete tool calls"
...
...
@@ -1344,20 +1436,26 @@ rectangle""",
for
delta_message
in
stream_delta_message_generator_from_chunks
(
step3p5_tool_parser
,
step3p5_tokenizer
,
delta_text_chunks
,
request
):
print
(
delta_message
)
if
delta_message
.
tool_calls
:
for
tool_call
in
delta_message
.
tool_calls
:
idx
=
tool_call
.
index
if
idx
not
in
tool_states
:
tool_states
[
idx
]
=
{
"name"
:
None
,
"arguments"
:
""
,
"arguments"
:
None
,
}
if
tool_call
.
function
:
if
tool_call
.
function
.
name
:
tool_states
[
idx
][
"name"
]
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
is
not
None
:
tool_states
[
idx
][
"arguments"
]
+=
tool_call
.
function
.
arguments
json
.
loads
(
tool_call
.
function
.
arguments
)
if
tool_states
[
idx
][
"arguments"
]
is
None
:
tool_states
[
idx
][
"arguments"
]
=
tool_call
.
function
.
arguments
else
:
assert
(
tool_states
[
idx
][
"arguments"
]
==
tool_call
.
function
.
arguments
)
assert
len
(
tool_states
)
==
2
assert
all
(
state
[
"name"
]
for
state
in
tool_states
.
values
())
...
...
@@ -1368,7 +1466,7 @@ rectangle""",
def
test_extract_tool_calls_non_streaming_multiple_tool_calls_no_content_between
(
step3p5_tool_parser
,
sample_tools
):
"""Test non-streaming extraction with tool calls
and no content between them
.
"""Test non-streaming extraction with
multiple
tool calls.
Scenario: Model outputs "hello" + tool call + tool call.
Expected: "hello" as content, first tool call parsed (index=0),
...
...
tests/v1/spec_decode/test_mtp3.py
0 → 100644
View file @
84ac1d27
This diff is collapsed.
Click to expand it.
vllm/config/speculative.py
View file @
84ac1d27
...
...
@@ -80,6 +80,10 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
enable_multi_layers_mtp
:
bool
=
False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size
:
int
|
None
=
Field
(
default
=
None
,
ge
=
1
)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
...
...
@@ -493,7 +497,10 @@ class SpeculativeConfig:
MTPModelTypes
):
self
.
method
=
"mtp"
if
self
.
num_speculative_tokens
>
1
:
if
(
self
.
enable_multi_layers_mtp
is
False
and
self
.
num_speculative_tokens
>
1
):
logger
.
warning
(
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
...
...
vllm/entrypoints/anthropic/serving.py
View file @
84ac1d27
...
...
@@ -166,7 +166,72 @@ class AnthropicServingMessages(OpenAIServingChat):
if
isinstance
(
msg
.
content
,
str
):
openai_msg
[
"content"
]
=
msg
.
content
else
:
cls
.
_convert_message_content
(
msg
,
openai_msg
,
openai_messages
)
# Handle complex content blocks
content_parts
:
list
[
dict
[
str
,
Any
]]
=
[]
tool_calls
:
list
[
dict
[
str
,
Any
]]
=
[]
reasoning_parts
:
list
[
str
]
=
[]
for
block
in
msg
.
content
:
if
block
.
type
==
"text"
and
block
.
text
:
content_parts
.
append
({
"type"
:
"text"
,
"text"
:
block
.
text
})
elif
block
.
type
==
"image"
and
block
.
source
:
content_parts
.
append
(
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
block
.
source
.
get
(
"data"
,
""
)},
}
)
elif
block
.
type
==
"thinking"
and
block
.
thinking
is
not
None
:
reasoning_parts
.
append
(
block
.
thinking
)
elif
block
.
type
==
"tool_use"
:
# Convert tool use to function call format
tool_call
=
{
"id"
:
block
.
id
or
f
"call_
{
int
(
time
.
time
())
}
"
,
"type"
:
"function"
,
"function"
:
{
"name"
:
block
.
name
or
""
,
"arguments"
:
json
.
dumps
(
block
.
input
or
{}),
},
}
tool_calls
.
append
(
tool_call
)
elif
block
.
type
==
"tool_result"
:
if
msg
.
role
==
"user"
:
openai_messages
.
append
(
{
"role"
:
"tool"
,
"tool_call_id"
:
block
.
id
or
""
,
"content"
:
str
(
block
.
content
)
if
block
.
content
else
""
,
}
)
else
:
# Assistant tool result becomes regular text
tool_result_text
=
(
str
(
block
.
content
)
if
block
.
content
else
""
)
content_parts
.
append
(
{
"type"
:
"text"
,
"text"
:
f
"Tool result:
{
tool_result_text
}
"
,
}
)
if
reasoning_parts
:
openai_msg
[
"reasoning"
]
=
""
.
join
(
reasoning_parts
)
# Add tool calls to the message if any
if
tool_calls
:
openai_msg
[
"tool_calls"
]
=
tool_calls
# type: ignore
# Add content parts if any
if
content_parts
:
if
len
(
content_parts
)
==
1
and
content_parts
[
0
][
"type"
]
==
"text"
:
openai_msg
[
"content"
]
=
content_parts
[
0
][
"text"
]
else
:
openai_msg
[
"content"
]
=
content_parts
# type: ignore
elif
not
tool_calls
and
not
reasoning_parts
:
continue
openai_messages
.
append
(
openai_msg
)
...
...
@@ -522,49 +587,75 @@ class AnthropicServingMessages(OpenAIServingChat):
first_item
=
True
finish_reason
=
None
state
=
_ActiveBlockState
()
content_block_index
=
0
active_block_type
:
str
|
None
=
None
active_block_index
:
int
|
None
=
None
active_block_signature
:
str
|
None
=
None
signature_emitted
=
False
active_tool_use_id
:
str
|
None
=
None
# Map from tool call index to tool_use_id
tool_index_to_id
:
dict
[
int
,
str
]
=
{}
def
stop_active_block
():
nonlocal
active_block_type
,
active_block_index
,
content_block_index
nonlocal
active_block_signature
,
signature_emitted
,
active_tool_use_id
events
:
list
[
str
]
=
[]
if
state
.
block_type
is
None
:
if
active_
block_type
is
None
:
return
events
if
(
state
.
block_type
==
"thinking"
and
state
.
block_signature
is
not
None
and
not
state
.
signature_emitted
active_
block_type
==
"thinking"
and
active_
block_signature
is
not
None
and
not
signature_emitted
):
chunk
=
AnthropicStreamEvent
(
index
=
state
.
block_index
,
index
=
active_
block_index
,
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
type
=
"signature_delta"
,
signature
=
state
.
block_signature
,
signature
=
active_
block_signature
,
),
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
events
.
append
(
wrap_data_with_event
(
data
,
"content_block_delta"
))
state
.
signature_emitted
=
True
signature_emitted
=
True
stop_chunk
=
AnthropicStreamEvent
(
index
=
state
.
block_index
,
index
=
active_
block_index
,
type
=
"content_block_stop"
,
)
data
=
stop_chunk
.
model_dump_json
(
exclude_unset
=
True
)
events
.
append
(
wrap_data_with_event
(
data
,
"content_block_stop"
))
state
.
reset
()
state
.
content_block_index
+=
1
active_block_type
=
None
active_block_index
=
None
active_block_signature
=
None
signature_emitted
=
False
active_tool_use_id
=
None
content_block_index
+=
1
return
events
def
start_block
(
block
:
AnthropicContentBlock
):
nonlocal
active_block_type
,
active_block_index
,
content_block_index
nonlocal
active_block_signature
,
signature_emitted
,
active_tool_use_id
chunk
=
AnthropicStreamEvent
(
index
=
state
.
content_block_index
,
index
=
content_block_index
,
type
=
"content_block_start"
,
content_block
=
block
,
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
event
=
wrap_data_with_event
(
data
,
"content_block_start"
)
state
.
start
(
block
)
active_block_type
=
block
.
type
active_block_index
=
content_block_index
if
block
.
type
==
"thinking"
:
active_block_signature
=
uuid
.
uuid4
().
hex
signature_emitted
=
False
active_tool_use_id
=
None
elif
block
.
type
==
"tool_use"
:
active_block_signature
=
None
signature_emitted
=
True
active_tool_use_id
=
block
.
id
else
:
active_block_signature
=
None
signature_emitted
=
True
active_tool_use_id
=
None
return
event
async
for
item
in
generator
:
...
...
@@ -638,7 +729,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if
reasoning_delta
==
""
:
pass
else
:
if
state
.
block_type
!=
"thinking"
:
if
active_
block_type
!=
"thinking"
:
for
event
in
stop_active_block
():
yield
event
start_event
=
start_block
(
...
...
@@ -649,9 +740,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield
start_event
chunk
=
AnthropicStreamEvent
(
index
=
(
state
.
block_index
if
state
.
block_index
is
not
None
else
state
.
content_block_index
active_
block_index
if
active_
block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
...
...
@@ -666,7 +757,7 @@ class AnthropicServingMessages(OpenAIServingChat):
if
origin_chunk
.
choices
[
0
].
delta
.
content
==
""
:
pass
else
:
if
state
.
block_type
!=
"text"
:
if
active_
block_type
!=
"text"
:
for
event
in
stop_active_block
():
yield
event
start_event
=
start_block
(
...
...
@@ -675,9 +766,9 @@ class AnthropicServingMessages(OpenAIServingChat):
yield
start_event
chunk
=
AnthropicStreamEvent
(
index
=
(
state
.
block_index
if
state
.
block_index
is
not
None
else
state
.
content_block_index
active_
block_index
if
active_
block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
...
...
@@ -702,7 +793,7 @@ class AnthropicServingMessages(OpenAIServingChat):
else
None
)
if
(
state
.
tool_use_id
!=
tool_call
.
id
active_
tool_use_id
!=
tool_call
.
id
and
tool_name
is
not
None
):
for
event
in
stop_active_block
():
...
...
@@ -720,13 +811,13 @@ class AnthropicServingMessages(OpenAIServingChat):
if
(
tool_call
.
function
and
tool_call
.
function
.
arguments
and
state
.
tool_use_id
==
tool_call
.
id
and
active_
tool_use_id
==
tool_call
.
id
):
chunk
=
AnthropicStreamEvent
(
index
=
(
state
.
block_index
if
state
.
block_index
is
not
None
else
state
.
content_block_index
active_
block_index
if
active_
block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
...
...
@@ -745,13 +836,13 @@ class AnthropicServingMessages(OpenAIServingChat):
tool_use_id
is
not
None
and
tool_call
.
function
and
tool_call
.
function
.
arguments
and
state
.
tool_use_id
==
tool_use_id
and
active_
tool_use_id
==
tool_use_id
):
chunk
=
AnthropicStreamEvent
(
index
=
(
state
.
block_index
if
state
.
block_index
is
not
None
else
state
.
content_block_index
active_
block_index
if
active_
block_index
is
not
None
else
content_block_index
),
type
=
"content_block_delta"
,
delta
=
AnthropicDelta
(
...
...
vllm/entrypoints/openai/chat_completion/serving.py
View file @
84ac1d27
...
...
@@ -1101,10 +1101,10 @@ class OpenAIServingChat(OpenAIServing):
index
=
0
if
(
self
.
_should_check_for_unstreamed_tool_arg_tokens
(
delta_message
,
output
tool_parser
and
self
.
_should_check_for_unstreamed_tool_arg_tokens
(
delta_message
,
output
,
tool_parser
)
and
tool_parser
):
latest_delta_len
=
0
if
(
...
...
@@ -1760,6 +1760,7 @@ class OpenAIServingChat(OpenAIServing):
self
,
delta_message
:
DeltaMessage
|
None
,
output
:
CompletionOutput
,
tool_parser
:
ToolParser
|
None
=
None
,
)
->
bool
:
"""
Check to see if we should check for unstreamed tool arguments tokens.
...
...
@@ -1772,6 +1773,8 @@ class OpenAIServingChat(OpenAIServing):
# include a function that has arguments
output
.
finish_reason
is
not
None
and
self
.
enable_auto_tools
and
tool_parser
is
not
None
and
tool_parser
.
parser_should_check_for_unstreamed_tool_arg_tokens
()
and
self
.
tool_parser
and
delta_message
and
delta_message
.
tool_calls
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
84ac1d27
...
...
@@ -262,6 +262,7 @@ def select_fp8_moe_backend(
supported
,
reason
=
k_cls
.
is_supported_config
(
k_cls
,
config
,
weight_key
,
activation_key
,
activation_format
)
supported
=
True
if
supported
:
logger
.
info_once
(
_make_log_backend
(
backend
),
scope
=
"local"
)
return
backend
,
k_cls
...
...
vllm/model_executor/models/step3p5_mtp.py
View file @
84ac1d27
...
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
...
...
@@ -40,9 +41,11 @@ class SharedHead(nn.Module):
return
self
.
norm
(
hidden_states
)
@
support_torch_compile
class
Step3p5AMultiTokenPredictorLayer
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
,
)
->
None
:
...
...
@@ -52,7 +55,7 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
self
.
enorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
config
.
rms_norm_eps
)
self
.
hnorm
=
GemmaRMSNorm
(
config
.
hidden_size
,
config
.
rms_norm_eps
)
self
.
eh_proj
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
config
.
hidden_size
,
bias
=
False
)
self
.
shared
_head
=
SharedHead
(
config
=
config
,
quant_config
=
quant_config
)
self
.
lm
_head
=
SharedHead
(
config
=
config
,
quant_config
=
quant_config
)
self
.
mtp_block
=
Step3p5DecoderLayer
(
vllm_config
,
prefix
=
f
"
{
prefix
}
.mtp_block"
,
...
...
@@ -64,9 +67,12 @@ class Step3p5AMultiTokenPredictorLayer(nn.Module):
positions
:
torch
.
Tensor
,
previous_hidden_states
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
embed_tokens
:
VocabParallelEmbedding
|
None
=
None
,
spec_step_index
:
int
=
0
,
)
->
torch
.
Tensor
:
assert
inputs_embeds
is
not
None
if
inputs_embeds
is
None
:
assert
embed_tokens
is
not
None
inputs_embeds
=
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
enorm
(
inputs_embeds
)
previous_hidden_states
=
self
.
hnorm
(
previous_hidden_states
)
...
...
@@ -92,8 +98,8 @@ class Step3p5AMultiTokenPredictor(nn.Module):
self
.
layers
=
torch
.
nn
.
ModuleDict
(
{
str
(
idx
):
Step3p5AMultiTokenPredictorLayer
(
vllm_config
,
f
"
{
prefix
}
.layers.
{
idx
}
"
,
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
idx
}
"
,
)
for
idx
in
range
(
self
.
mtp_start_layer_idx
,
...
...
@@ -112,14 +118,13 @@ class Step3p5AMultiTokenPredictor(nn.Module):
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
positions
,
previous_hidden_states
,
inputs_embeds
,
self
.
embed_tokens
,
current_step_idx
,
)
...
...
@@ -131,7 +136,7 @@ class Step3p5AMultiTokenPredictor(nn.Module):
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
mtp_layer
=
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)]
logits
=
self
.
logits_processor
(
mtp_layer
.
shared
_head
.
head
,
mtp_layer
.
shared
_head
(
hidden_states
)
mtp_layer
.
lm
_head
.
head
,
mtp_layer
.
lm
_head
(
hidden_states
)
)
return
logits
...
...
@@ -257,6 +262,7 @@ class Step3p5MTP(nn.Module):
name
=
name
.
replace
(
".transformer."
,
"."
)
if
"shared_head"
in
name
:
name
=
name
.
replace
(
"shared_head.output"
,
"shared_head.head"
)
name
=
name
.
replace
(
"shared_head"
,
"lm_head"
)
if
"embed_tokens"
in
name
:
assert
(
hasattr
(
self
.
config
,
"num_nextn_predict_layers"
)
...
...
vllm/tool_parsers/abstract_tool_parser.py
View file @
84ac1d27
...
...
@@ -118,6 +118,12 @@ class ToolParser:
"AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
)
def
parser_should_check_for_unstreamed_tool_arg_tokens
(
self
)
->
bool
:
"""
Whether to check for unstreamed tool-argument tokens in serving
"""
return
True
class
ToolParserManager
:
"""
...
...
vllm/tool_parsers/step3p5_tool_parser.py
View file @
84ac1d27
This diff is collapsed.
Click to expand it.
vllm/v1/attention/backends/triton_attn.py
View file @
84ac1d27
...
...
@@ -514,7 +514,7 @@ class TritonAttentionImpl(AttentionImpl):
q_descale
=
None
,
# Not supported
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
seq_threshold_3D
=
seq_threshold_3D
,
seq_threshold_3D
=
None
,
num_par_softmax_segments
=
num_par_softmax_segments
,
softmax_segm_output
=
softmax_segm_output
,
softmax_segm_max
=
softmax_segm_max
,
...
...
vllm/v1/core/kv_cache_utils.py
View file @
84ac1d27
...
...
@@ -957,6 +957,7 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
def
_get_kv_cache_groups_uniform_page_size
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
)
->
list
[
KVCacheGroupSpec
]:
"""
...
...
@@ -976,6 +977,12 @@ def _get_kv_cache_groups_uniform_page_size(
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The generated KVCacheGroupSpecs
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
...
...
@@ -1062,19 +1069,28 @@ def _get_kv_cache_groups_uniform_page_size(
num_padding_layers
/
len
(
layers
)
*
100
,
)
num_groups
=
cdiv
(
len
(
layers
),
group_size
)
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for
i
in
range
(
num_groups
):
grouped_layers
.
append
(
layers
[
i
::
num_groups
])
# for support multi layer mtp, we need to
# make all mtp layers in the same group
if
(
vllm_config
.
speculative_config
is
not
None
and
vllm_config
.
speculative_config
.
enable_multi_layers_mtp
):
for
i
in
range
(
0
,
len
(
layers
),
group_size
):
grouped_layers
.
append
(
layers
[
i
:
i
+
group_size
])
else
:
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for
i
in
range
(
num_groups
):
grouped_layers
.
append
(
layers
[
i
::
num_groups
])
return
create_kv_cache_group_specs
(
kv_cache_spec
,
grouped_layers
)
...
...
@@ -1259,7 +1275,9 @@ def get_kv_cache_groups(
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return
_get_kv_cache_groups_uniform_page_size
(
kv_cache_spec
)
return
_get_kv_cache_groups_uniform_page_size
(
vllm_config
=
vllm_config
,
kv_cache_spec
=
kv_cache_spec
)
def
generate_scheduler_kv_cache_config
(
...
...
vllm/v1/spec_decode/eagle.py
View file @
84ac1d27
...
...
@@ -38,7 +38,7 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
UniformTypeKVCacheSpecs
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.sampler
import
_SAMPLING_EPS
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
MultiLayerEagleMetadata
,
SpecDecodeMetadata
from
vllm.v1.spec_decode.utils
import
(
PADDING_SLOT_ID
,
compute_new_slot_mapping
,
...
...
@@ -395,6 +395,7 @@ class SpecDecodeBaseProposer:
token_indices_to_sample
:
torch
.
Tensor
|
None
,
common_attn_metadata
:
CommonAttentionMetadata
,
sampling_metadata
:
SamplingMetadata
,
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
=
None
,
mm_embed_inputs
:
tuple
[
list
[
torch
.
Tensor
],
torch
.
Tensor
]
|
None
=
None
,
num_rejected_tokens_gpu
:
torch
.
Tensor
|
None
=
None
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
...
...
vllm/v1/spec_decode/metadata.py
View file @
84ac1d27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
import
numpy
as
np
...
...
@@ -64,3 +66,45 @@ class SpecDecodeMetadata:
bonus_logits_indices
=
bonus_logits_indices
,
logits_indices
=
logits_indices
,
)
@
dataclass
class
MultiLayerEagleMetadata
:
# [batch_size]
cached_len
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_token_ids
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num, hidden_size]
cached_hidden_states
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_slot_mappings
:
torch
.
Tensor
|
None
=
None
# [batch_size, layer_num]
cached_positions
:
torch
.
Tensor
|
None
=
None
@
classmethod
def
make_dummy
(
cls
,
layer_num
:
int
,
hidden_size
:
int
,
device
:
torch
.
device
,
)
->
"MultiLayerEagleMetadata"
:
cached_len
=
torch
.
zeros
((
1
),
dtype
=
torch
.
int64
,
device
=
device
)
cached_token_ids
=
torch
.
zeros
(
(
1
,
layer_num
),
dtype
=
torch
.
int32
,
device
=
device
)
cached_hidden_states
=
torch
.
zeros
(
(
1
,
layer_num
,
hidden_size
),
dtype
=
torch
.
float32
,
device
=
device
)
cached_slot_mappings
=
torch
.
zeros
(
(
1
,
layer_num
),
dtype
=
torch
.
int64
,
device
=
device
)
cached_positions
=
torch
.
zeros
(
(
1
,
layer_num
),
dtype
=
torch
.
int64
,
device
=
device
)
return
cls
(
cached_len
=
cached_len
,
cached_token_ids
=
cached_token_ids
,
cached_hidden_states
=
cached_hidden_states
,
cached_slot_mappings
=
cached_slot_mappings
,
cached_positions
=
cached_positions
,
)
vllm/v1/spec_decode/multi_layer_eagle.py
0 → 100644
View file @
84ac1d27
This diff is collapsed.
Click to expand it.
vllm/v1/worker/gpu_input_batch.py
View file @
84ac1d27
...
...
@@ -53,6 +53,13 @@ class CachedRequestState:
pooling_params
:
PoolingParams
|
None
=
None
pooling_states
:
PoolingStates
|
None
=
None
# for multi layer eagle proposer
cached_len
:
torch
.
Tensor
|
None
=
None
cached_token_ids
:
torch
.
Tensor
|
None
=
None
cached_hidden_states
:
torch
.
Tensor
|
None
=
None
cached_slot_mappings
:
torch
.
Tensor
|
None
=
None
cached_positions
:
torch
.
Tensor
|
None
=
None
def
__post_init__
(
self
):
self
.
num_prompt_tokens
=
length_from_prompt_token_ids_or_embeds
(
self
.
prompt_token_ids
,
self
.
prompt_embeds
...
...
@@ -95,6 +102,8 @@ class InputBatch:
is_spec_decode
:
bool
=
False
,
is_pooling_model
:
bool
=
False
,
cp_kv_cache_interleave_size
:
int
=
1
,
multi_layer_eagle_num
:
int
=
0
,
hidden_size
:
int
|
None
=
None
,
):
self
.
is_pooling_model
=
is_pooling_model
self
.
is_spec_decode
=
is_spec_decode
...
...
@@ -223,6 +232,46 @@ class InputBatch:
)
self
.
num_accepted_tokens_cpu
=
self
.
num_accepted_tokens_cpu_tensor
.
numpy
()
# Multi layer eagle
self
.
multi_layer_eagle_num
=
multi_layer_eagle_num
if
multi_layer_eagle_num
>
0
:
self
.
cached_len
=
torch
.
zeros
(
(
max_num_reqs
,),
dtype
=
torch
.
int64
,
device
=
device
)
self
.
cached_token_ids
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int32
,
device
=
device
,
)
self
.
cached_hidden_states
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
hidden_size
,
),
dtype
=
torch
.
float
,
device
=
device
,
)
self
.
cached_slot_mappings
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int64
,
device
=
device
,
)
self
.
cached_positions
=
torch
.
zeros
(
(
max_num_reqs
,
multi_layer_eagle_num
,
),
dtype
=
torch
.
int64
,
device
=
device
,
)
# lora related
self
.
request_lora_mapping
=
np
.
zeros
((
self
.
max_num_reqs
,),
dtype
=
np
.
int64
)
self
.
lora_id_to_request_ids
:
dict
[
int
,
set
[
str
]]
=
{}
...
...
@@ -437,6 +486,13 @@ class InputBatch:
# Speculative decoding: by default 1 token is generated.
self
.
num_accepted_tokens_cpu
[
req_index
]
=
1
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
req_index
]
=
request
.
cached_len
self
.
cached_token_ids
[
req_index
]
=
request
.
cached_token_ids
self
.
cached_hidden_states
[
req_index
]
=
request
.
cached_hidden_states
self
.
cached_slot_mappings
[
req_index
]
=
request
.
cached_slot_mappings
self
.
cached_positions
[
req_index
]
=
request
.
cached_positions
# Add request lora ID
if
request
.
lora_request
:
lora_id
=
request
.
lora_request
.
lora_int_id
...
...
@@ -632,6 +688,20 @@ class InputBatch:
self
.
num_accepted_tokens_cpu
[
i1
],
)
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
i1
],
self
.
cached_len
[
i2
]
=
(
self
.
cached_len
[
i2
],
self
.
cached_len
[
i1
],
)
self
.
cached_token_ids
[[
i1
,
i2
],
...]
=
self
.
cached_token_ids
[[
i2
,
i1
],
...]
self
.
cached_hidden_states
[[
i1
,
i2
],
...]
=
self
.
cached_hidden_states
[
[
i2
,
i1
],
...
]
self
.
cached_slot_mappings
[[
i1
,
i2
],
...]
=
self
.
cached_slot_mappings
[
[
i2
,
i1
],
...
]
self
.
cached_positions
[[
i1
,
i2
],
...]
=
self
.
cached_positions
[[
i2
,
i1
],
...]
swap_dict_values
(
self
.
generators
,
i1
,
i2
)
swap_dict_values
(
self
.
bad_words_token_ids
,
i1
,
i2
)
...
...
@@ -769,6 +839,23 @@ class InputBatch:
if
bad_words_token_ids
is
not
None
:
self
.
bad_words_token_ids
[
empty_index
]
=
bad_words_token_ids
if
self
.
multi_layer_eagle_num
>
0
:
self
.
cached_len
[
empty_index
]
=
self
.
cached_len
[
last_req_index
]
self
.
cached_token_ids
[
empty_index
]
=
self
.
cached_token_ids
[
last_req_index
]
self
.
cached_hidden_states
[
empty_index
]
=
self
.
cached_hidden_states
[
last_req_index
]
self
.
cached_slot_mappings
[
empty_index
]
=
self
.
cached_slot_mappings
[
last_req_index
]
self
.
cached_positions
[
empty_index
]
=
self
.
cached_positions
[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
84ac1d27
...
...
@@ -164,7 +164,8 @@ from vllm.v1.spec_decode.draft_model import DraftModelProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.extract_hidden_states
import
ExtractHiddenStatesProposer
from
vllm.v1.spec_decode.medusa
import
MedusaProposer
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
MultiLayerEagleMetadata
,
SpecDecodeMetadata
from
vllm.v1.spec_decode.multi_layer_eagle
import
MultiLayerEagleProposer
from
vllm.v1.spec_decode.ngram_proposer_gpu
import
(
NgramProposerGPU
,
copy_num_valid_draft_tokens
,
...
...
@@ -374,6 +375,7 @@ class ExecuteModelState(NamedTuple):
scheduler_output
:
"SchedulerOutput"
logits
:
torch
.
Tensor
spec_decode_metadata
:
SpecDecodeMetadata
|
None
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
spec_decode_common_attn_metadata
:
CommonAttentionMetadata
|
None
hidden_states
:
torch
.
Tensor
sample_hidden_states
:
torch
.
Tensor
...
...
@@ -500,6 +502,11 @@ class GPUModelRunner(
self
.
late_interaction_runner
=
LateInteractionRunner
()
self
.
use_aux_hidden_state_outputs
=
False
# multi layer eagle
self
.
enable_multi_layer_eagle
=
False
self
.
multi_layer_eagle_num
=
0
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
...
...
@@ -544,7 +551,17 @@ class GPUModelRunner(
elif
self
.
speculative_config
.
method
==
"suffix"
:
self
.
drafter
=
SuffixDecodingProposer
(
self
.
vllm_config
)
elif
self
.
speculative_config
.
use_eagle
():
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
if
(
self
.
speculative_config
.
enable_multi_layers_mtp
and
self
.
speculative_config
.
method
==
"mtp"
):
self
.
enable_multi_layer_eagle
=
True
self
.
drafter
=
MultiLayerEagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
self
.
multi_layer_eagle_num
=
self
.
drafter
.
layer_num
else
:
self
.
drafter
=
EagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
if
self
.
speculative_config
.
method
==
"eagle3"
:
self
.
use_aux_hidden_state_outputs
=
(
self
.
drafter
.
eagle3_use_aux_hidden_state
...
...
@@ -623,6 +640,10 @@ class GPUModelRunner(
logitsprocs_need_output_token_ids
=
bool
(
custom_logitsprocs
),
is_pooling_model
=
self
.
is_pooling_model
,
cp_kv_cache_interleave_size
=
self
.
parallel_config
.
cp_kv_cache_interleave_size
,
multi_layer_eagle_num
=
self
.
multi_layer_eagle_num
if
self
.
enable_multi_layer_eagle
else
0
,
hidden_size
=
self
.
model_config
.
get_hidden_size
(),
)
# Separate cuda stream for overlapping transfer of sampled token ids from
...
...
@@ -1143,6 +1164,9 @@ class GPUModelRunner(
if
self
.
uses_xdrope_dim
>
0
:
self
.
_init_xdrope_positions
(
req_state
)
if
self
.
enable_multi_layer_eagle
:
self
.
_init_multi_layer_eagle_cache
(
req_state
)
reqs_to_add
.
append
(
req_state
)
# Track new requests for ngram_gpu full tensor copy
if
is_ngram_gpu
:
...
...
@@ -1442,6 +1466,24 @@ class GPUModelRunner(
req_state
.
mm_features
,
)
def
_init_multi_layer_eagle_cache
(
self
,
req_state
:
CachedRequestState
):
req_state
.
cached_len
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
req_state
.
cached_hidden_states
=
torch
.
zeros
(
self
.
multi_layer_eagle_num
,
self
.
model_config
.
get_hidden_size
(),
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
req_state
.
cached_token_ids
=
torch
.
zeros
(
self
.
multi_layer_eagle_num
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
req_state
.
cached_positions
=
torch
.
zeros
(
self
.
multi_layer_eagle_num
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
req_state
.
cached_slot_mappings
=
torch
.
zeros
(
self
.
multi_layer_eagle_num
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
def
_extract_mm_kwargs
(
self
,
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -1672,10 +1714,11 @@ class GPUModelRunner(
)
->
tuple
[
torch
.
Tensor
,
SpecDecodeMetadata
|
None
,
MultiLayerEagleMetadata
|
None
,
]:
"""
:return: tuple[
logits_indices, spec_decode_metadata,
logits_indices, spec_decode_metadata,
multi_layer_eagle_metadata,
]
"""
total_num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
...
...
@@ -1879,9 +1922,21 @@ class GPUModelRunner(
self
.
input_batch
,
num_scheduled_tokens
,
num_sampled_tokens
)
if
self
.
enable_multi_layer_eagle
:
multi_layer_eagle_metadata
=
MultiLayerEagleMetadata
(
cached_len
=
self
.
input_batch
.
cached_len
[:
num_reqs
],
cached_token_ids
=
self
.
input_batch
.
cached_token_ids
[:
num_reqs
],
cached_hidden_states
=
self
.
input_batch
.
cached_hidden_states
[:
num_reqs
],
cached_slot_mappings
=
self
.
input_batch
.
cached_slot_mappings
[:
num_reqs
],
cached_positions
=
self
.
input_batch
.
cached_positions
[:
num_reqs
],
)
else
:
multi_layer_eagle_metadata
=
None
return
(
logits_indices
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
)
def
_build_attention_metadata
(
...
...
@@ -3634,9 +3689,11 @@ class GPUModelRunner(
max_num_scheduled_tokens
=
int
(
num_scheduled_tokens_np
.
max
())
num_tokens_unpadded
=
scheduler_output
.
total_num_scheduled_tokens
logits_indices
,
spec_decode_metadata
=
self
.
_prepare_inputs
(
scheduler_output
,
num_scheduled_tokens_np
,
logits_indices
,
spec_decode_metadata
,
multi_layer_eagle_metadata
=
(
self
.
_prepare_inputs
(
scheduler_output
,
num_scheduled_tokens_np
,
)
)
cascade_attn_prefix_lens
=
None
...
...
@@ -3867,6 +3924,7 @@ class GPUModelRunner(
scheduler_output
,
logits
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
hidden_states
,
sample_hidden_states
,
...
...
@@ -3905,6 +3963,7 @@ class GPUModelRunner(
scheduler_output
,
logits
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
hidden_states
,
sample_hidden_states
,
...
...
@@ -3953,6 +4012,7 @@ class GPUModelRunner(
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
multi_layer_eagle_metadata
,
spec_decode_common_attn_metadata
,
slot_mappings
,
)
...
...
@@ -4242,6 +4302,7 @@ class GPUModelRunner(
sample_hidden_states
:
torch
.
Tensor
,
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
,
spec_decode_metadata
:
SpecDecodeMetadata
|
None
,
multi_layer_eagle_metadata
:
MultiLayerEagleMetadata
|
None
,
common_attn_metadata
:
CommonAttentionMetadata
,
slot_mappings
:
dict
[
str
,
torch
.
Tensor
]
|
list
[
dict
[
str
,
torch
.
Tensor
]]
|
None
,
)
->
list
[
list
[
int
]]
|
torch
.
Tensor
:
...
...
@@ -4466,6 +4527,7 @@ class GPUModelRunner(
token_indices_to_sample
=
token_indices_to_sample
,
sampling_metadata
=
sampling_metadata
,
common_attn_metadata
=
common_attn_metadata
,
multi_layer_eagle_metadata
=
multi_layer_eagle_metadata
,
mm_embed_inputs
=
mm_embed_inputs
,
num_rejected_tokens_gpu
=
num_rejected_tokens_gpu
,
slot_mappings
=
slot_mappings
,
...
...
@@ -6216,6 +6278,10 @@ class GPUModelRunner(
logitsprocs
=
self
.
input_batch
.
logitsprocs
,
logitsprocs_need_output_token_ids
=
self
.
input_batch
.
logitsprocs_need_output_token_ids
,
is_pooling_model
=
self
.
is_pooling_model
,
multi_layer_eagle_num
=
self
.
multi_layer_eagle_num
if
self
.
enable_multi_layer_eagle
else
0
,
hidden_size
=
self
.
model_config
.
get_hidden_size
(),
)
assert
self
.
_init_block_sizes
==
block_sizes
,
(
...
...
vllm/version.py
View file @
84ac1d27
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try
:
from
._version
import
__version__
,
__version_tuple__
__version__
=
"0.18.1"
__version_tuple__
=
(
0
,
18
,
1
)
__hcu_version__
=
f
'0.18.1+das.ori.dtk2604'
from
vllm.version
import
__version__
,
__version_tuple__
,
__hcu_version__
except
Exception
as
e
:
import
warnings
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
RuntimeWarning
,
stacklevel
=
2
)
warnings
.
warn
(
f
"Failed to read commit hash:
\n
+ str(e)"
,
RuntimeWarning
,
stacklevel
=
2
)
__version__
=
"dev"
__version_tuple__
=
(
0
,
0
,
__version__
)
def
_prev_minor_version_was
(
version_str
):
"""
Check whether a given version matches the previous minor version.
'''
Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
...
...
@@ -21,19 +24,19 @@ def _prev_minor_version_was(version_str):
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
'''
# Match anything if this is a dev tree
if
__version_tuple__
[
0
:
2
]
==
(
0
,
0
):
return
True
# Note - this won't do the right thing when we release 1.0!
assert
__version_tuple__
[
0
]
==
0
#
assert __version_tuple__[0] == 0
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
version_str
==
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
def
_prev_minor_version
():
"""
For the purpose of testing, return a previous minor version number.
"""
'''
For the purpose of testing, return a previous minor version number.
'''
# In dev tree, this will return "0.-1", but that will work fine"
assert
isinstance
(
__version_tuple__
[
1
],
int
)
return
f
"
{
__version_tuple__
[
0
]
}
.
{
__version_tuple__
[
1
]
-
1
}
"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment