Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
751c492c
"vscode:/vscode.git/clone" did not exist on "ac4cc84efe37e31c613bdd4c40b7ceb99e9e403c"
Commit
751c492c
authored
Jul 29, 2025
by
zhuwenwen
Browse files
GLM-4.5 Model Support
parent
98958aed
Changes
14
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1991 additions
and
9 deletions
+1991
-9
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+1
-1
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+1
-0
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/registry.py
tests/models/registry.py
+7
-0
tests/tool_use/test_glm4_moe_tool_parser.py
tests/tool_use/test_glm4_moe_tool_parser.py
+405
-0
vllm/config.py
vllm/config.py
+12
-3
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+17
-5
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
+402
-0
vllm/model_executor/models/glm4_moe.py
vllm/model_executor/models/glm4_moe.py
+685
-0
vllm/model_executor/models/glm4_moe_mtp.py
vllm/model_executor/models/glm4_moe_mtp.py
+307
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-0
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+2
-0
vllm/reasoning/glm4_moe_reasoning_parser.py
vllm/reasoning/glm4_moe_reasoning_parser.py
+148
-0
vllm/worker/worker.py
vllm/worker/worker.py
+1
-0
No files found.
benchmarks/kernels/benchmark_moe.py
View file @
751c492c
...
@@ -628,7 +628,7 @@ def main(args: argparse.Namespace):
...
@@ -628,7 +628,7 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
in
(
"DeepseekV3ForCausalLM"
,
"DeepseekV2ForCausalLM"
):
elif
config
.
architectures
[
0
]
in
(
"DeepseekV3ForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"Glm4MoeForCausalLM"
):
E
=
config
.
n_routed_experts
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
intermediate_size
=
config
.
moe_intermediate_size
...
...
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
751c492c
...
@@ -318,6 +318,7 @@ def main(args: argparse.Namespace):
...
@@ -318,6 +318,7 @@ def main(args: argparse.Namespace):
elif
(
elif
(
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
config
.
architectures
[
0
]
==
"DeepseekV3ForCausalLM"
or
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
config
.
architectures
[
0
]
==
"DeepseekV2ForCausalLM"
or
config
.
architectures
[
0
]
==
"Glm4MoeForCausalLM"
):
):
E
=
config
.
n_routed_experts
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
topk
=
config
.
num_experts_per_tok
...
...
docs/models/supported_models.md
View file @
751c492c
...
@@ -567,6 +567,7 @@ Specified using `--task generate`.
...
@@ -567,6 +567,7 @@ Specified using `--task generate`.
|
`Gemma3ForConditionalGeneration`
| Gemma 3 | T + I
<sup>
+
</sup>
|
`google/gemma-3-4b-it`
,
`google/gemma-3-27b-it`
, etc. | ✅︎ | ✅︎ | ⚠️ |
|
`Gemma3ForConditionalGeneration`
| Gemma 3 | T + I
<sup>
+
</sup>
|
`google/gemma-3-4b-it`
,
`google/gemma-3-27b-it`
, etc. | ✅︎ | ✅︎ | ⚠️ |
|
`GLM4VForCausalLM`
<sup>
^
</sup>
| GLM-4V | T + I |
`THUDM/glm-4v-9b`
,
`THUDM/cogagent-9b-20241220`
etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GLM4VForCausalLM`
<sup>
^
</sup>
| GLM-4V | T + I |
`THUDM/glm-4v-9b`
,
`THUDM/cogagent-9b-20241220`
etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4vForConditionalGeneration`
| GLM-4.1V-Thinking | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`THUDM/GLM-4.1V-9B-Thinking`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4vForConditionalGeneration`
| GLM-4.1V-Thinking | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`THUDM/GLM-4.1V-9B-Thinking`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Glm4MoeForCausalLM`
| GLM-4.5 | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`THUDM/GLM-4.5`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`GraniteSpeechForConditionalGeneration`
| Granite Speech | T + A |
`ibm-granite/granite-speech-3.3-8b`
| ✅︎ | ✅︎ | ✅︎ |
|
`GraniteSpeechForConditionalGeneration`
| Granite Speech | T + A |
`ibm-granite/granite-speech-3.3-8b`
| ✅︎ | ✅︎ | ✅︎ |
|
`H2OVLChatModel`
| H2OVL | T + I
<sup>
E+
</sup>
|
`h2oai/h2ovl-mississippi-800m`
,
`h2oai/h2ovl-mississippi-2b`
, etc. | | ✅︎ | ✅︎
\*
|
|
`H2OVLChatModel`
| H2OVL | T + I
<sup>
E+
</sup>
|
`h2oai/h2ovl-mississippi-800m`
,
`h2oai/h2ovl-mississippi-2b`
, etc. | | ✅︎ | ✅︎
\*
|
|
`Idefics3ForConditionalGeneration`
| Idefics3 | T + I |
`HuggingFaceM4/Idefics3-8B-Llama3`
etc. | ✅︎ | | ✅︎ |
|
`Idefics3ForConditionalGeneration`
| Idefics3 | T + I |
`HuggingFaceM4/Idefics3-8B-Llama3`
etc. | ✅︎ | | ✅︎ |
...
...
tests/models/registry.py
View file @
751c492c
...
@@ -351,6 +351,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -351,6 +351,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
,
trust_remote_code
=
True
,
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]}),
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]}),
# noqa: E501
"Glm4vForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"THUDM/GLM-4.1V-9B-Thinking"
),
min_transformers_version
=
"4.53"
),
# noqa: E501
"Glm4vForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"THUDM/GLM-4.1V-9B-Thinking"
),
min_transformers_version
=
"4.53"
),
# noqa: E501
"Glm4MoeForCausalLM"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"THUDM/GLM-4.5"
),
min_transformers_version
=
"4.54"
,
is_available_online
=
False
),
# noqa: E501
"H2OVLChatModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"h2oai/h2ovl-mississippi-800m"
),
"H2OVLChatModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"h2oai/h2ovl-mississippi-800m"
),
extras
=
{
"2b"
:
os
.
path
.
join
(
models_path_prefix
,
"h2oai/h2ovl-mississippi-2b"
)},
# noqa: E501
extras
=
{
"2b"
:
os
.
path
.
join
(
models_path_prefix
,
"h2oai/h2ovl-mississippi-2b"
)},
# noqa: E501
max_transformers_version
=
"4.48"
,
# noqa: E501
max_transformers_version
=
"4.48"
,
# noqa: E501
...
@@ -460,6 +463,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
...
@@ -460,6 +463,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
is_available_online
=
False
,
is_available_online
=
False
,
speculative_model
=
os
.
path
.
join
(
models_path_prefix
,
"openbmb/MiniCPM-2B-sft-bf16"
),
speculative_model
=
os
.
path
.
join
(
models_path_prefix
,
"openbmb/MiniCPM-2B-sft-bf16"
),
tokenizer
=
os
.
path
.
join
(
models_path_prefix
,
"openbmb/MiniCPM-2B-sft-bf16"
)),
tokenizer
=
os
.
path
.
join
(
models_path_prefix
,
"openbmb/MiniCPM-2B-sft-bf16"
)),
"Glm4MoeMTPModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"THUDM/GLM-4.5"
),
speculative_model
=
os
.
path
.
join
(
models_path_prefix
,
"THUDM/GLM-4.5"
),
min_transformers_version
=
"4.54"
,
is_available_online
=
False
),
"MiMoMTPModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"XiaomiMiMo/MiMo-7B-RL"
),
"MiMoMTPModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"XiaomiMiMo/MiMo-7B-RL"
),
trust_remote_code
=
True
,
trust_remote_code
=
True
,
speculative_model
=
os
.
path
.
join
(
models_path_prefix
,
"XiaomiMiMo/MiMo-7B-RL"
))
speculative_model
=
os
.
path
.
join
(
models_path_prefix
,
"XiaomiMiMo/MiMo-7B-RL"
))
...
...
tests/tool_use/test_glm4_moe_tool_parser.py
0 → 100644
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import
json
import
pytest
from
vllm.entrypoints.openai.protocol
import
FunctionCall
,
ToolCall
from
vllm.entrypoints.openai.tool_parsers
import
Glm4MoeModelToolParser
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
# NOTE(review): this skips the entire module unconditionally at import time,
# so none of the tests below are collected or run. The reason is not stated
# in this file -- confirm whether the skip is still needed.
pytest.skip("skip glm4_moe parser test", allow_module_level=True)

# Tokenizer name passed to get_tokenizer() by the module-scoped fixture below.
MODEL = "THUDM/GLM-4.5"
@pytest.fixture(scope="module")
def glm4_moe_tokenizer():
    """Load the tokenizer named by ``MODEL`` once for the whole module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def glm4_moe_tool_parser(glm4_moe_tokenizer):
    """Build a fresh ``Glm4MoeModelToolParser`` for each test.

    A new parser per test keeps the parser's streaming state from leaking
    between tests; the tokenizer itself is shared via the module fixture.
    """
    parser = Glm4MoeModelToolParser(glm4_moe_tokenizer)
    return parser
def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Assert that two lists of tool calls describe the same calls.

    Ids are only required to be non-empty strings (their values are not
    compared); the call type must be ``"function"``; names must match
    exactly; arguments are compared after JSON decoding so that whitespace
    or formatting differences in the serialized form do not matter.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for got, want in zip(actual_tool_calls, expected_tool_calls):
        call_id = got.id
        assert isinstance(call_id, str)
        assert call_id  # a str is truthy exactly when it is non-empty

        assert got.type == "function"
        assert got.function.name == want.function.name

        # Decode both sides: equal JSON objects may serialize differently.
        got_args = json.loads(got.function.arguments)
        want_args = json.loads(want.function.arguments)
        assert got_args == want_args
def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
    """Plain text without a tool-call tag is returned untouched as content."""
    plain_text = "This is a test"

    outcome = glm4_moe_tool_parser.extract_tool_calls(
        plain_text, request=None)  # type: ignore[arg-type]

    assert not outcome.tools_called
    assert outcome.tool_calls == []
    assert outcome.content == plain_text
def _expected_call(name: str, **arguments: str) -> ToolCall:
    """Build the expected ``ToolCall`` for ``name`` with JSON-encoded kwargs."""
    return ToolCall(
        function=FunctionCall(name=name, arguments=json.dumps(arguments)))


@pytest.mark.parametrize(
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Dallas</arg_value>
<arg_key>state</arg_key>
<arg_value>TX</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>""",
            [
                _expected_call("get_current_weather",
                               city="Dallas",
                               state="TX",
                               unit="fahrenheit"),
            ],
            None,
        ),
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Dallas</arg_value>
<arg_key>state</arg_key>
<arg_value>TX</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>
<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Orlando</arg_value>
<arg_key>state</arg_key>
<arg_value>FL</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call>""",
            [
                _expected_call("get_current_weather",
                               city="Dallas",
                               state="TX",
                               unit="fahrenheit"),
                _expected_call("get_current_weather",
                               city="Orlando",
                               state="FL",
                               unit="fahrenheit"),
            ],
            None,
        ),
        (
            """I'll help you check the weather. <tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Seattle</arg_value>
<arg_key>state</arg_key>
<arg_value>WA</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>""",
            [
                _expected_call("get_current_weather",
                               city="Seattle",
                               state="WA",
                               unit="celsius"),
            ],
            "I'll help you check the weather.",
        ),
        (
            """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>New York</arg_value>
<arg_key>state</arg_key>
<arg_value>NY</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>""",
            [
                _expected_call("get_current_weather",
                               city="New York",
                               state="NY",
                               unit="celsius"),
            ],
            None,
        ),
        (
            """I will help you get the weather.<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>""",
            [
                _expected_call("get_weather",
                               city="Beijing",
                               date="2025-08-01"),
            ],
            "I will help you get the weather.",
        ),
    ],
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
        "tool_call_with_mixed_args",
        "tool_call_with_chinese_content",
    ],
)
def test_extract_tool_calls(glm4_moe_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    """Complete tool calls are extracted together with any leading content.

    Each case supplies the raw model output, the tool calls expected from
    it, and the text expected in ``content`` (``None`` when nothing precedes
    the first ``<tool_call>`` tag).
    """
    outcome = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert outcome.tools_called
    assert_tool_calls(outcome.tool_calls, expected_tool_calls)
    assert outcome.content == expected_content
def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser):
    """A ``<think>`` section before the tool call is kept as part of content."""
    # Everything before the tool call, including the thinking block, is what
    # the parser is expected to hand back as content.
    preamble = """<think>I want to get the weather.</think>
I will help you get the weather."""
    model_output = preamble + """
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>"""

    outcome = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert outcome.tools_called
    assert len(outcome.tool_calls) == 1
    assert outcome.tool_calls[0].function.name == "get_weather"
    assert outcome.content == preamble
def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser):
    """An unterminated ``<arg_key>`` must not make the parser raise."""
    broken_output = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Seattle</arg_value>
<arg_key>incomplete_arg
<arg_value>value</arg_value>
</tool_call>"""

    outcome = glm4_moe_tool_parser.extract_tool_calls(
        broken_output, request=None)  # type: ignore[arg-type]

    # The test deliberately does not pin whether a partial call is salvaged
    # or dropped; it only requires a well-formed result object.
    assert isinstance(outcome.tools_called, bool)
    assert isinstance(outcome.tool_calls, list)
def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser):
    """A tool call without any arguments yields an empty JSON object."""
    model_output = """<tool_call>get_current_time
</tool_call>"""

    outcome = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert outcome.tools_called
    assert len(outcome.tool_calls) == 1

    only_call = outcome.tool_calls[0].function
    assert only_call.name == "get_current_time"
    # No <arg_key>/<arg_value> pairs -> serialized as "{}".
    assert only_call.arguments == "{}"
def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser):
    """Two tool calls interleaved with prose are both extracted, in order."""
    model_output = """I will help you get the weather info.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>
meaningwhile, I will also check the weather in Shanghai.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Shanghai</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>
</tool_call>"""

    outcome = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert outcome.tools_called
    assert len(outcome.tool_calls) == 2

    # Calls must come back in document order: Beijing first, then Shanghai.
    for call, city in zip(outcome.tool_calls, ("Beijing", "Shanghai")):
        assert call.function.name == "get_weather"
        decoded = json.loads(call.function.arguments)
        assert decoded["city"] == city
        assert decoded["date"] == "2025-08-01"

    # Only the text ahead of the first tool call is reported as content.
    assert outcome.content == "I will help you get the weather info."
def test_streaming_basic_functionality(glm4_moe_tool_parser):
    """Smoke test: streaming a complete tool call must not raise."""
    parser = glm4_moe_tool_parser

    # Start from a clean streaming state.
    parser.current_tool_name_sent = False
    parser.prev_tool_call_arr = []
    parser.current_tool_id = -1
    parser.streamed_args_for_tool = []

    full_call = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
</tool_call>"""

    # Use the parser's own token ids when it has them; otherwise fall back
    # to placeholder ids.
    start_id = parser.tool_call_start_token_id or 12345
    end_id = parser.tool_call_end_token_id or 12346

    result = parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=full_call,
        delta_text="</tool_call>",
        previous_token_ids=[],
        current_token_ids=[start_id, end_id],
        delta_token_ids=[end_id],
        request=None,
    )

    # The exact delta depends on streaming state; accept "nothing to emit"
    # or any object that looks like a delta message.
    assert result is None or any(
        hasattr(result, attr) for attr in ('tool_calls', 'content'))
def test_streaming_no_tool_calls(glm4_moe_tool_parser):
    """Without tool-call tokens the streamed delta is passed through as content."""
    already_sent = "This is just regular text"
    new_chunk = " without any tool calls."

    result = glm4_moe_tool_parser.extract_tool_calls_streaming(
        previous_text=already_sent,
        current_text="This is just regular text without any tool calls.",
        delta_text=new_chunk,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The delta text itself must come back as the message content.
    assert result is not None
    assert hasattr(result, 'content')
    assert result.content == new_chunk
def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser):
    """Text streamed before any tool-call token id is emitted as content.

    The token-id lists are intentionally empty, so the parser never sees the
    tool-call start token id and must forward ``delta_text`` verbatim, even
    though the text itself ends with a literal ``<tool_call>`` tag.
    """
    # Reset streaming state
    glm4_moe_tool_parser.current_tool_name_sent = False
    glm4_moe_tool_parser.prev_tool_call_arr = []
    glm4_moe_tool_parser.current_tool_id = -1
    glm4_moe_tool_parser.streamed_args_for_tool = []

    # Keep the three text arguments mutually consistent: the current text is
    # always the previous text followed by the delta, as in real streaming.
    previous_text = "I will help you"
    delta_text = " get the weather.<tool_call>"
    current_text = previous_text + delta_text

    result = glm4_moe_tool_parser.extract_tool_calls_streaming(
        previous_text=previous_text,
        current_text=current_text,
        delta_text=delta_text,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Should return content when no tool call tokens are detected
    assert result is not None
    assert hasattr(result, 'content')
    assert result.content == delta_text
def test_extract_tool_calls_special_characters(glm4_moe_tool_parser):
    """Argument values with punctuation, quotes and non-ASCII text round-trip.

    The values deliberately avoid ``<`` so that the test exercises unicode
    and JSON escaping only, independent of how the parser treats angle
    brackets inside ``<arg_value>``.
    """
    recipient = "张三"
    message = "你好!\"It's\" a nice day & 100% sunny 🌞"
    priority = "high"

    model_output = f"""<tool_call>send_message
<arg_key>recipient</arg_key>
<arg_value>{recipient}</arg_value>
<arg_key>message</arg_key>
<arg_value>{message}</arg_value>
<arg_key>priority</arg_key>
<arg_value>{priority}</arg_value>
</tool_call>"""

    extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called
    assert len(extracted_tool_calls.tool_calls) == 1
    assert extracted_tool_calls.tool_calls[0].function.name == "send_message"

    # Decoding the JSON must give back the original strings unchanged,
    # including the embedded double quotes and the non-ASCII characters.
    args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
    assert args["recipient"] == recipient
    assert args["message"] == message
    assert args["priority"] == priority
def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
    """A tool call missing its closing tag is treated as ordinary content."""
    truncated_output = """<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2025-08-01</arg_value>"""

    outcome = glm4_moe_tool_parser.extract_tool_calls(
        truncated_output, request=None)  # type: ignore[arg-type]

    # Nothing is extracted; the raw text is handed back unchanged.
    assert not outcome.tools_called
    assert outcome.tool_calls == []
    assert outcome.content == truncated_output
\ No newline at end of file
vllm/config.py
View file @
751c492c
...
@@ -1260,7 +1260,8 @@ class ModelConfig:
...
@@ -1260,7 +1260,8 @@ class ModelConfig:
self
,
parallel_config
:
"ParallelConfig"
)
->
tuple
[
int
,
int
]:
self
,
parallel_config
:
"ParallelConfig"
)
->
tuple
[
int
,
int
]:
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.distributed.utils
import
get_pp_indices
if
(
self
.
hf_text_config
.
model_type
==
"deepseek_mtp"
if
(
self
.
hf_text_config
.
model_type
==
"deepseek_mtp"
or
self
.
hf_config
.
model_type
==
"mimo_mtp"
):
or
self
.
hf_config
.
model_type
==
"mimo_mtp"
or
self
.
hf_config
.
model_type
==
"glm4_moe_mtp"
):
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
total_num_hidden_layers
=
getattr
(
self
.
hf_text_config
,
"num_nextn_predict_layers"
,
0
)
"num_nextn_predict_layers"
,
0
)
else
:
else
:
...
@@ -2595,7 +2596,15 @@ class SpeculativeConfig:
...
@@ -2595,7 +2596,15 @@ class SpeculativeConfig:
"n_predict"
:
n_predict
,
"n_predict"
:
n_predict
,
"architectures"
:
[
"MiMoMTPModel"
]
"architectures"
:
[
"MiMoMTPModel"
]
})
})
return
hf_config
if
hf_config
.
architectures
[
0
]
==
"Glm4MoeForCausalLM"
:
hf_config
.
model_type
=
"glm4_moe_mtp"
n_predict
=
getattr
(
hf_config
,
"num_nextn_predict_layers"
,
None
)
hf_config
.
update
({
"num_hidden_layers"
:
0
,
"n_predict"
:
n_predict
,
"architectures"
:
[
"Glm4MoeMTPModel"
]
})
return
hf_config
return
hf_config
...
@@ -2706,7 +2715,7 @@ class SpeculativeConfig:
...
@@ -2706,7 +2715,7 @@ class SpeculativeConfig:
"mlp_speculator"
):
"mlp_speculator"
):
self
.
method
=
"mlp_speculator"
self
.
method
=
"mlp_speculator"
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
elif
(
self
.
draft_model_config
.
hf_config
.
model_type
==
"deepseek_mtp"
):
"deepseek_mtp"
,
"glm4_moe_mtp"
):
self
.
method
=
"deepseek_mtp"
self
.
method
=
"deepseek_mtp"
if
self
.
num_speculative_tokens
>
1
:
if
self
.
num_speculative_tokens
>
1
:
logger
.
warning
(
logger
.
warning
(
...
...
vllm/entrypoints/openai/tool_parsers/__init__.py
View file @
751c492c
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
.abstract_tool_parser
import
ToolParser
,
ToolParserManager
from
.abstract_tool_parser
import
ToolParser
,
ToolParserManager
from
.deepseekv3_tool_parser
import
DeepSeekV3ToolParser
from
.deepseekv3_tool_parser
import
DeepSeekV3ToolParser
from
.glm4_moe_tool_parser
import
Glm4MoeModelToolParser
from
.granite_20b_fc_tool_parser
import
Granite20bFCToolParser
from
.granite_20b_fc_tool_parser
import
Granite20bFCToolParser
from
.granite_tool_parser
import
GraniteToolParser
from
.granite_tool_parser
import
GraniteToolParser
from
.hermes_tool_parser
import
Hermes2ProToolParser
from
.hermes_tool_parser
import
Hermes2ProToolParser
...
@@ -17,9 +18,20 @@ from .pythonic_tool_parser import PythonicToolParser
...
@@ -17,9 +18,20 @@ from .pythonic_tool_parser import PythonicToolParser
from
.xlam_tool_parser
import
xLAMToolParser
from
.xlam_tool_parser
import
xLAMToolParser
__all__
=
[
__all__
=
[
"ToolParser"
,
"ToolParserManager"
,
"Granite20bFCToolParser"
,
"ToolParser"
,
"GraniteToolParser"
,
"Hermes2ProToolParser"
,
"MistralToolParser"
,
"ToolParserManager"
,
"Internlm2ToolParser"
,
"Llama3JsonToolParser"
,
"JambaToolParser"
,
"Granite20bFCToolParser"
,
"Llama4PythonicToolParser"
,
"PythonicToolParser"
,
"Phi4MiniJsonToolParser"
,
"GraniteToolParser"
,
"DeepSeekV3ToolParser"
,
"xLAMToolParser"
,
"MinimaxToolParser"
"Hermes2ProToolParser"
,
"MistralToolParser"
,
"Internlm2ToolParser"
,
"Llama3JsonToolParser"
,
"JambaToolParser"
,
"Llama4PythonicToolParser"
,
"PythonicToolParser"
,
"Phi4MiniJsonToolParser"
,
"DeepSeekV3ToolParser"
,
"xLAMToolParser"
,
"MinimaxToolParser"
,
"Glm4MoeModelToolParser"
,
]
]
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
0 → 100644
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# code modified from deepseekv3_tool_parser.py
from
collections.abc
import
Sequence
from
typing
import
Union
import
regex
as
re
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
ToolParserManager
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
@
ToolParserManager
.
register_module
(
"glm4_moe"
)
class
Glm4MoeModelToolParser
(
ToolParser
):
def __init__(self, tokenizer: AnyTokenizer):
    """Set up streaming state, tag tokens and regexes for GLM-4.5 tool calls.

    The model emits tool calls as
    ``<tool_call>NAME`` followed by zero or more
    ``<arg_key>K</arg_key><arg_value>V</arg_value>`` pairs and a closing
    ``</tool_call>``.

    Args:
        tokenizer: tokenizer forwarded to the ``ToolParser`` constructor.

    Raises:
        ValueError: if ``self.model_tokenizer`` is falsy after the base
            constructor has run.
    """
    super().__init__(tokenizer)

    # Mutable state used while streaming a response.
    self.current_tool_name_sent = False
    self.prev_tool_call_arr: list[dict] = []
    self.current_tool_id = -1
    self.streamed_args_for_tool: list[str] = []

    self.tool_call_start_token = "<tool_call>"
    self.tool_call_end_token = "</tool_call>"
    self.tool_calls_start_token = self.tool_call_start_token

    # Function name: everything up to the first newline or "<".
    function_name = r"(?P<function_name>[^\n<]+)"

    # Argument value: any text up to (not including) the closing
    # </arg_value> tag. The previous pattern "[^<]*" rejected every value
    # containing "<" (comparisons, HTML/XML snippets, generics), which made
    # the whole tool call fail to match. The negative lookahead still stops
    # the value at the first closing tag, so a value can never swallow the
    # following argument pairs.
    arg_value = r"(?:(?!</arg_value>).)*"

    # Zero or more <arg_key>/<arg_value> pairs, captured as one group.
    arg_pairs = (r"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*"
                 r"<arg_value>" + arg_value + r"</arg_value>\s*)*)")

    # Regex for one complete tool call in the XML-based format.
    self.tool_call_regex = re.compile(
        r"<tool_call>\s*" + function_name + r"\s*" + arg_pairs +
        r"\s*</tool_call>",
        re.DOTALL,
    )

    # Regex for parsing individual arguments
    self.arg_regex = re.compile(
        r"<arg_key>(?P<key>[^<]+)</arg_key>\s*"
        r"<arg_value>(?P<value>" + arg_value + r")</arg_value>",
        re.DOTALL,
    )

    # Streaming regex: same as the full pattern but without the enclosing
    # <tool_call> tags.
    self.stream_tool_call_portion_regex = re.compile(
        function_name + r"\s*" + arg_pairs,
        re.DOTALL,
    )

    # For streaming, we also need a regex to match just the function name
    self.stream_tool_call_name_regex = re.compile(
        function_name,
        re.DOTALL,
    )

    # NOTE(review): model_tokenizer and vocab are not set in this method;
    # presumably the ToolParser base class provides them -- confirm.
    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction.")

    # dict.get() leaves these as None when the tags are not in the vocab.
    self.tool_call_start_token_id = self.vocab.get(
        self.tool_call_start_token)
    self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
def _parse_arguments(self, args_text: str) -> str:
    """Convert ``<arg_key>``/``<arg_value>`` pairs into a JSON object string.

    Keys and values are whitespace-stripped and every value is kept as a
    string. When a key occurs more than once the last occurrence wins.
    Empty or whitespace-only input gives ``"{}"``; non-ASCII characters are
    written unescaped.
    """
    # json is not imported at module level in this file, so import it here.
    import json

    if not (args_text and args_text.strip()):
        return "{}"

    parsed = {
        raw_key.strip(): raw_value.strip()
        for raw_key, raw_value in self.arg_regex.findall(args_text)
    }
    return json.dumps(parsed, ensure_ascii=False)
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract every complete tool call from a finished model response.

    Args:
        model_output: full text generated by the model.
        request: the originating chat request (not read by this method).

    Returns:
        When at least one complete ``<tool_call>...</tool_call>`` section
        matches: the parsed calls (ids ``call_0``, ``call_1``, ...) plus the
        stripped text preceding the first tool-call tag, or ``None`` if that
        text is empty. Otherwise -- no tag, no match, or any exception while
        parsing -- ``tools_called=False`` with the whole output as content.
    """

    def _as_plain_content() -> ExtractedToolCallInformation:
        # Fallback result: report no tools and return the text unchanged.
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    start_tag = self.tool_calls_start_token

    # Cheap pre-check so the regex is skipped for ordinary text.
    if start_tag not in model_output:
        return _as_plain_content()

    try:
        found = self.tool_call_regex.findall(model_output)
        logger.debug("function_call_matches: %s", found)

        if not found:
            return _as_plain_content()

        # Each match is a (function name, raw XML arguments) pair.
        parsed_calls = [
            ToolCall(
                id=f"call_{position}",
                type='function',
                function=FunctionCall(
                    name=raw_name.strip(),
                    arguments=self._parse_arguments(raw_args_xml),
                ),
            ) for position, (raw_name, raw_args_xml) in enumerate(found)
        ]

        # The tag is known to be present, so the head of the partition is
        # exactly the text before its first occurrence.
        leading_text = model_output.partition(start_tag)[0].strip()

        return ExtractedToolCallInformation(
            tools_called=bool(parsed_calls),
            tool_calls=parsed_calls,
            content=leading_text or None,
        )

    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return _as_plain_content()
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
Union
[
DeltaMessage
,
None
]:
logger
.
debug
(
"delta_text: %s"
,
delta_text
)
logger
.
debug
(
"delta_token_ids: %s"
,
delta_token_ids
)
# check to see if we should be streaming a tool call - is there a
if
self
.
tool_call_start_token_id
not
in
current_token_ids
:
logger
.
debug
(
"No tool call tokens found!"
)
return
DeltaMessage
(
content
=
delta_text
)
delta_text
=
delta_text
.
replace
(
self
.
tool_calls_start_token
,
""
).
replace
(
self
.
tool_call_end_token
,
""
)
try
:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count
=
previous_token_ids
.
count
(
self
.
tool_call_start_token_id
)
prev_tool_end_count
=
previous_token_ids
.
count
(
self
.
tool_call_end_token_id
)
cur_tool_start_count
=
current_token_ids
.
count
(
self
.
tool_call_start_token_id
)
cur_tool_end_count
=
current_token_ids
.
count
(
self
.
tool_call_end_token_id
)
tool_call_portion
=
None
text_portion
=
None
# case: if we're generating text, OR rounding out a tool call
if
(
cur_tool_start_count
==
cur_tool_end_count
and
prev_tool_end_count
==
cur_tool_end_count
and
self
.
tool_call_end_token
not
in
delta_text
):
logger
.
debug
(
"Generating text content! skipping tool parsing."
)
return
DeltaMessage
(
content
=
delta_text
)
if
self
.
tool_call_end_token
in
delta_text
:
logger
.
debug
(
"tool_call_end_token in delta_text"
)
full_text
=
current_text
+
delta_text
tool_call_portion
=
full_text
.
split
(
self
.
tool_call_start_token
)[
-
1
].
split
(
self
.
tool_call_end_token
)[
0
].
rstrip
()
delta_text
=
delta_text
.
split
(
self
.
tool_call_end_token
)[
0
].
rstrip
()
text_portion
=
delta_text
.
split
(
self
.
tool_call_end_token
)[
-
1
].
lstrip
()
# case -- we're starting a new tool call
if
(
cur_tool_start_count
>
cur_tool_end_count
and
cur_tool_start_count
>
prev_tool_start_count
):
if
len
(
delta_token_ids
)
>
1
:
tool_call_portion
=
current_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
else
:
tool_call_portion
=
None
delta
=
None
text_portion
=
None
# set cursors and state appropriately
self
.
current_tool_id
+=
1
self
.
current_tool_name_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"Starting on a new tool %s"
,
self
.
current_tool_id
)
# case -- we're updating an existing tool call
elif
(
cur_tool_start_count
>
cur_tool_end_count
and
cur_tool_start_count
==
prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion
=
current_text
.
split
(
self
.
tool_call_start_token
)[
-
1
]
text_portion
=
None
# case -- the current tool call is being closed.
elif
(
cur_tool_start_count
==
cur_tool_end_count
and
cur_tool_end_count
>=
prev_tool_end_count
):
if
self
.
prev_tool_call_arr
is
None
or
len
(
self
.
prev_tool_call_arr
)
==
0
:
logger
.
debug
(
"attempting to close tool call, but no tool call"
)
return
None
diff
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
if
diff
:
diff
=
(
diff
.
encode
(
"utf-8"
).
decode
(
"unicode_escape"
)
if
diff
is
str
else
diff
)
if
'"}'
not
in
delta_text
:
return
None
end_loc
=
delta_text
.
rindex
(
'"}'
)
diff
=
delta_text
[:
end_loc
]
+
'"}'
logger
.
debug
(
"Finishing tool and found diff that had not "
"been streamed yet: %s"
,
diff
,
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
diff
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
diff
).
model_dump
(
exclude_none
=
True
),
)
])
# case -- otherwise we're just generating text
else
:
text
=
delta_text
.
replace
(
self
.
tool_call_start_token
,
""
)
text
=
text
.
replace
(
self
.
tool_call_end_token
,
""
)
delta
=
DeltaMessage
(
tool_calls
=
[],
content
=
text
)
return
delta
current_tool_call
=
dict
()
if
tool_call_portion
:
current_tool_call_matches
=
(
self
.
stream_tool_call_portion_regex
.
match
(
tool_call_portion
))
if
current_tool_call_matches
:
tool_id
,
tool_args
=
(
current_tool_call_matches
.
groups
())
tool_name
=
tool_id
.
split
(
'.'
)[
1
].
split
(
':'
)[
0
]
current_tool_call
[
'id'
]
=
tool_id
current_tool_call
[
"name"
]
=
tool_name
current_tool_call
[
"arguments"
]
=
tool_args
else
:
current_tool_call_name_matches
=
(
self
.
stream_tool_call_name_regex
.
match
(
tool_call_portion
))
if
current_tool_call_name_matches
:
tool_id_str
,
=
current_tool_call_name_matches
.
groups
()
tool_name
=
tool_id_str
.
split
(
'.'
)[
1
].
split
(
':'
)[
0
]
current_tool_call
[
'id'
]
=
tool_id_str
current_tool_call
[
"name"
]
=
tool_name
current_tool_call
[
"arguments"
]
=
""
else
:
logger
.
debug
(
"Not enough token"
)
return
None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if
not
self
.
current_tool_name_sent
:
if
current_tool_call
is
None
:
return
None
function_name
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"name"
)
tool_id
=
current_tool_call
.
get
(
"id"
)
if
function_name
:
self
.
current_tool_name_sent
=
True
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
tool_id
,
function
=
DeltaFunctionCall
(
name
=
function_name
).
model_dump
(
exclude_none
=
True
),
)
])
else
:
return
None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if
tool_call_portion
is
None
:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta
=
(
DeltaMessage
(
content
=
delta_text
)
if
text_portion
is
not
None
else
None
)
return
delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger
.
debug
(
"Trying to parse current tool call with ID %s"
,
self
.
current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if
len
(
self
.
prev_tool_call_arr
)
<=
self
.
current_tool_id
:
self
.
prev_tool_call_arr
.
append
({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
logger
.
debug
(
"diffing old arguments: %s"
,
prev_arguments
)
logger
.
debug
(
"against new ones: %s"
,
cur_arguments
)
# case -- no arguments have been created yet. skip sending a delta.
if
not
cur_arguments
and
not
prev_arguments
:
logger
.
debug
(
"Skipping text %s - no arguments"
,
delta_text
)
delta
=
None
# case -- prev arguments are defined, but none are now.
# probably impossible, but not a fatal error - just keep going
elif
not
cur_arguments
and
prev_arguments
:
logger
.
error
(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta
=
None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif
cur_arguments
and
not
prev_arguments
:
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
cur_arguments
).
model_dump
(
exclude_none
=
True
),
)
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
=
cur_arguments
# last case -- we have an update to existing arguments.
elif
cur_arguments
and
prev_arguments
:
if
(
isinstance
(
delta_text
,
str
)
and
cur_arguments
!=
prev_arguments
and
len
(
cur_arguments
)
>
len
(
prev_arguments
)
and
cur_arguments
.
startswith
(
prev_arguments
)):
delta_arguments
=
cur_arguments
[
len
(
prev_arguments
):]
logger
.
debug
(
"got diff %s"
,
delta_text
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
delta_arguments
).
model_dump
(
exclude_none
=
True
),
)
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
=
cur_arguments
else
:
delta
=
None
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if
self
.
current_tool_id
==
len
(
self
.
prev_tool_call_arr
)
-
1
:
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
current_tool_call
else
:
self
.
prev_tool_call_arr
.
append
(
current_tool_call
)
return
delta
except
Exception
:
logger
.
exception
(
"Error trying to handle streaming tool call."
)
return
None
# do not stream a delta. skip this token ID.
\ No newline at end of file
vllm/model_executor/models/glm4_moe.py
0 → 100644
View file @
751c492c
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/glm4_moe_mtp.py
0 → 100644
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The ZhipuAI Team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4.5 MTP model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.glm4_moe
import
Glm4MoeDecoderLayer
,
get_spec_layer_idx_from_weight_name
from
.interfaces
import
SupportsPP
from
.utils
import
maybe_prefix
class SharedHead(nn.Module):
    """RMSNorm and LM head attached to one MTP layer.

    Calling the module applies only the RMSNorm. The ``head`` module is kept
    as an attribute so callers can use it separately (in this file
    ``Glm4MoeMultiTokenPredictor.compute_logits`` hands it to the logits
    processor together with the normalised hidden states).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Create the ``norm`` and ``head`` sub-modules from ``config``.

        Args:
            config: Supplies ``hidden_size``, ``rms_norm_eps`` and
                ``vocab_size``.
            quant_config: Forwarded unchanged to ``ParallelLMHead``.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.head = ParallelLMHead(
            config.vocab_size,
            hidden_size,
            quant_config=quant_config,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after ``norm``; ``head`` is not applied."""
        normed = self.norm(hidden_states)
        return normed
class Glm4MoeMultiTokenPredictorLayer(nn.Module):
    """A single multi-token-prediction (MTP) layer.

    The layer normalises the token embeddings (``enorm``) and the previous
    hidden states (``hnorm``), concatenates them on the last dimension,
    projects the result back to ``hidden_size`` with ``eh_proj`` and runs it
    through one ``Glm4MoeDecoderLayer`` (``mtp_block``). ``shared_head`` is
    created here but is not called by ``forward``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the sub-modules of the layer.

        Args:
            config: Supplies ``hidden_size`` and ``rms_norm_eps``; also passed
                on to ``SharedHead`` and ``Glm4MoeDecoderLayer``.
            prefix: Passed to ``Glm4MoeDecoderLayer`` as its ``prefix``.
            cache_config: Passed to ``Glm4MoeDecoderLayer``.
            quant_config: Passed to ``SharedHead`` and
                ``Glm4MoeDecoderLayer``.
        """
        super().__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.enorm = RMSNorm(width, eps=norm_eps)
        self.hnorm = RMSNorm(width, eps=norm_eps)
        # Takes the concatenation [embeddings, previous hidden] (2 * width)
        # and maps it back to width.
        self.eh_proj = nn.Linear(width * 2, width, bias=False)
        self.shared_head = SharedHead(
            config=config,
            quant_config=quant_config,
        )
        self.mtp_block = Glm4MoeDecoderLayer(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer.

        Args:
            input_ids: Not read by this method.
            positions: Used to mask ``inputs_embeds`` and passed to
                ``mtp_block``.
            previous_hidden_states: Normalised with ``hnorm`` before being
                concatenated with the embeddings.
            inputs_embeds: Required despite the ``None`` default. It is
                modified in place: entries selected by ``positions == 0``
                are set to 0.
            spec_step_index: Not read by this method.

        Returns:
            The sum of the residual and the hidden states returned by
            ``mtp_block``.
        """
        assert inputs_embeds is not None

        # Zero the inputs at position 0 (in place); MTP does not need them.
        at_first_position = positions == 0
        inputs_embeds[at_first_position] = 0

        normed_embeds = self.enorm(inputs_embeds)
        normed_previous = self.hnorm(previous_hidden_states)
        fused = torch.cat([normed_embeds, normed_previous], dim=-1)
        projected = self.eh_proj(fused)

        block_out, residual = self.mtp_block(
            positions=positions,
            hidden_states=projected,
            residual=None,
        )
        return residual + block_out
class Glm4MoeMultiTokenPredictor(nn.Module):
    """Container for the MTP layers, their embedding and logits processor.

    Layers are stored in a ``ModuleDict`` keyed by the string form of their
    absolute layer index, starting at ``config.num_hidden_layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create ``num_nextn_predict_layers`` MTP layers plus shared parts.

        Args:
            vllm_config: Source of the HF config, cache config and quant
                config.
            prefix: Prepended to ``.layers.<idx>`` to form each layer prefix.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = hf_config.num_hidden_layers
        self.num_mtp_layers = hf_config.num_nextn_predict_layers

        # Key every layer by its exact layer index so it can be matched with
        # the layer index found in weight names.
        mtp_layers = {}
        for idx in range(self.mtp_start_layer_idx,
                         self.mtp_start_layer_idx + self.num_mtp_layers):
            mtp_layers[str(idx)] = Glm4MoeMultiTokenPredictorLayer(
                hf_config,
                f"{prefix}.layers.{idx}",
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)

    def _layer_for_step(self, spec_step_idx: int):
        """Return ``(step, layer)``; the step wraps modulo the layer count."""
        step = spec_step_idx % self.num_mtp_layers
        return step, self.layers[str(self.mtp_start_layer_idx + step)]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        When ``inputs_embeds`` is ``None`` it is computed from ``input_ids``
        with ``embed_tokens``. The wrapped step index (not the raw
        ``spec_step_idx``) is what gets passed to the layer.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        step, layer = self._layer_for_step(spec_step_idx)
        return layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            step,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits with the shared head of the selected MTP layer.

        The layer's ``shared_head`` is applied to ``hidden_states`` and the
        result is given to ``logits_processor`` together with
        ``shared_head.head`` and ``sampling_metadata``.
        """
        _, layer = self._layer_for_step(spec_step_idx)
        shared_head = layer.shared_head
        return self.logits_processor(
            shared_head.head,
            shared_head(hidden_states),
            sampling_metadata,
        )
class Glm4MoeMTP(nn.Module, SupportsPP):
    """Top-level GLM-4.5 MTP module.

    Wraps a ``Glm4MoeMultiTokenPredictor`` (``self.model``) and knows how to
    load the checkpoint weights that belong to the speculative layers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Store the HF config and build the wrapped predictor."""
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        model_prefix = maybe_prefix(prefix, "model")
        self.model = Glm4MoeMultiTokenPredictor(
            vllm_config=vllm_config,
            prefix=model_prefix,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Delegate to ``self.model``; ``intermediate_tensors`` is not read."""
        return self.model(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            spec_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        """Delegate to ``self.model.compute_logits``."""
        logits = self.model.compute_logits(
            hidden_states,
            sampling_metadata,
            spec_step_idx,
        )
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load the checkpoint weights that belong to the MTP layers.

        Weights for which ``get_spec_layer_idx_from_weight_name`` returns
        ``None`` are skipped. Every other name is first rewritten with
        ``_rewrite_spec_layer_name`` and then matched, in this order, against
        the stacked-parameter mapping, the expert mapping, and finally the
        plain parameter dictionary.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (rewritten) parameter names that were loaded.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params = dict(self.named_parameters())
        loaded: set[str] = set()

        for name, tensor in weights:
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue
            name = self._rewrite_spec_layer_name(spec_layer, name)

            # NOTE: the for/else nesting and the in-loop reassignment of
            # ``name`` are deliberate; ``continue`` inside an ``else`` clause
            # continues the outer loop over ``weights``.
            for fused_name, ckpt_name, shard_id in stacked_params_mapping:
                # Not a stacked weight, or an expert weight such as
                # mlp.experts[0].gate_proj. Experts are handled by
                # expert_params_mapping below, so they must be skipped BEFORE
                # the name is rewritten: otherwise it would become
                # mlp.experts[0].gate_up_proj here and then
                # mlp.experts[0].gate_gate_up_proj below, which breaks load.
                if ckpt_name not in name or ("mlp.experts." in name
                                             and name not in params):
                    continue
                name = name.replace(ckpt_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params:
                    continue

                param = params[name]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                for (fused_name, ckpt_name, expert_id,
                     shard_id) in expert_params_mapping:
                    if ckpt_name not in name:
                        continue
                    name = name.replace(ckpt_name, fused_name)

                    param = params[name]
                    param.weight_loader(param,
                                        tensor,
                                        name,
                                        shard_id=shard_id,
                                        expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params:
                        continue

                    # According to the DeepSeek-V3 Technical Report, MTP
                    # modules share the embedding layer, so weights outside
                    # ".layers" are loaded only from the first MTP layer.
                    if (spec_layer != self.model.mtp_start_layer_idx
                            and ".layers" not in name):
                        continue

                    param = params[name]
                    loader = getattr(param, "weight_loader",
                                     default_weight_loader)
                    loader(param, tensor)
            loaded.add(name)
        return loaded

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
        and rename shared layer weights to be top level.
        """
        spec_layer_weight_names = [
            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
        ]
        shared_weight_names = ["embed_tokens"]

        # First MTP-specific component (in list order) mentioned in the name.
        matched = next(
            (candidate
             for candidate in spec_layer_weight_names if candidate in name),
            None,
        )

        if matched is None:
            # Everything else belongs to the transformer layer block.
            return name.replace(f"model.layers.{spec_layer}.",
                                f"model.layers.{spec_layer}.mtp_block.")
        if matched in shared_weight_names:
            # Shared weights live at the top level of the model.
            return name.replace(f"model.layers.{spec_layer}.", "model.")
        return name
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
751c492c
...
@@ -66,6 +66,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -66,6 +66,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma3nForConditionalGeneration"
:
(
"gemma3n"
,
"Gemma3nForConditionalGeneration"
),
# noqa: E501
"Gemma3nForConditionalGeneration"
:
(
"gemma3n"
,
"Gemma3nForConditionalGeneration"
),
# noqa: E501
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"GlmForCausalLM"
:
(
"glm"
,
"GlmForCausalLM"
),
"Glm4ForCausalLM"
:
(
"glm4"
,
"Glm4ForCausalLM"
),
"Glm4ForCausalLM"
:
(
"glm4"
,
"Glm4ForCausalLM"
),
"Glm4MoeForCausalLM"
:
(
"glm4_moe"
,
"Glm4MoeForCausalLM"
),
"GPT2LMHeadModel"
:
(
"gpt2"
,
"GPT2LMHeadModel"
),
"GPT2LMHeadModel"
:
(
"gpt2"
,
"GPT2LMHeadModel"
),
"GPTBigCodeForCausalLM"
:
(
"gpt_bigcode"
,
"GPTBigCodeForCausalLM"
),
"GPTBigCodeForCausalLM"
:
(
"gpt_bigcode"
,
"GPTBigCodeForCausalLM"
),
"GPTJForCausalLM"
:
(
"gpt_j"
,
"GPTJForCausalLM"
),
"GPTJForCausalLM"
:
(
"gpt_j"
,
"GPTJForCausalLM"
),
...
@@ -245,6 +246,7 @@ _SPECULATIVE_DECODING_MODELS = {
...
@@ -245,6 +246,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleMiniCPMForCausalLM"
:
(
"minicpm_eagle"
,
"EagleMiniCPMForCausalLM"
),
"EagleMiniCPMForCausalLM"
:
(
"minicpm_eagle"
,
"EagleMiniCPMForCausalLM"
),
"Eagle3LlamaForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
"Eagle3LlamaForCausalLM"
:
(
"llama_eagle3"
,
"Eagle3LlamaForCausalLM"
),
"DeepSeekMTPModel"
:
(
"deepseek_mtp"
,
"DeepSeekMTP"
),
"DeepSeekMTPModel"
:
(
"deepseek_mtp"
,
"DeepSeekMTP"
),
"Glm4MoeMTPModel"
:
(
"glm4_moe_mtp"
,
"Glm4MoeMTP"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MedusaModel"
:
(
"medusa"
,
"Medusa"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
"MLPSpeculatorPreTrainedModel"
:
(
"mlp_speculator"
,
"MLPSpeculator"
),
}
}
...
...
vllm/reasoning/__init__.py
View file @
751c492c
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
from
.abs_reasoning_parsers
import
ReasoningParser
,
ReasoningParserManager
from
.abs_reasoning_parsers
import
ReasoningParser
,
ReasoningParserManager
from
.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
from
.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
from
.glm4_moe_reasoning_parser
import
Glm4MoeModelReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
from
.qwen3_reasoning_parser
import
Qwen3ReasoningParser
from
.qwen3_reasoning_parser
import
Qwen3ReasoningParser
...
@@ -12,4 +13,5 @@ __all__ = [
...
@@ -12,4 +13,5 @@ __all__ = [
"DeepSeekR1ReasoningParser"
,
"DeepSeekR1ReasoningParser"
,
"GraniteReasoningParser"
,
"GraniteReasoningParser"
,
"Qwen3ReasoningParser"
,
"Qwen3ReasoningParser"
,
"Glm4MoeModelReasoningParser"
,
]
]
vllm/reasoning/glm4_moe_reasoning_parser.py
0 → 100644
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
logger
=
init_logger
(
__name__
)
@ReasoningParserManager.register_module("glm4_moe")
class Glm4MoeModelReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Glm4MoeModel model.

    Reasoning text in the model output is delimited by <think> and </think>.
    This parser separates the text enclosed by those tokens
    (reasoning_content) from the rest of the output (content). The model
    provides a strict switch to disable reasoning output via the
    'enable_thinking=False' parameter.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Resolve the token ids of <think> and </think> from the vocab.

        Raises:
            ValueError: if no model tokenizer is available.
            RuntimeError: if either token is missing from the vocabulary.
        """
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        resolved_ids = (self.think_start_token_id, self.think_end_token_id)
        if any(token_id is None for token_id in resolved_ids):
            raise RuntimeError(
                "Glm4MoeModel reasoning parser could not locate "
                "think start/end tokens in the tokenizer!")

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True when the </think> token id occurs in ``input_ids``."""
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the token ids that follow the first </think> token.

        Returns an empty list when </think> is absent or is the last id.
        """
        end_id = self.think_end_token_id
        if end_id in input_ids[:-1]:
            first_content = input_ids.index(end_id) + 1
            return input_ids[first_content:]
        return []

    def _split_delta_at_think_end(self, delta_text: str,
                                  reasoning_begin: int) -> DeltaMessage:
        """Split ``delta_text`` at </think> into reasoning and content.

        The reasoning part is ``delta_text[reasoning_begin:end]``; text after
        the end token becomes ``content`` (``None`` when empty).
        """
        end_index = delta_text.find(self.think_end_token)
        reasoning_content = delta_text[reasoning_begin:end_index]
        content = delta_text[end_index + len(self.think_end_token):]
        return DeltaMessage(reasoning_content=reasoning_content,
                            content=content or None)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Decisions are made on token ids; text is only sliced when needed.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        start_id = self.think_start_token_id
        end_id = self.think_end_token_id

        # A delta that is exactly one think marker carries nothing to emit.
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (start_id,
                                                                end_id):
            return None

        if start_id in previous_token_ids:
            if end_id in delta_token_ids:
                # <think> in previous, </think> in delta:
                # the delta finishes the reasoning and may start the content.
                return self._split_delta_at_think_end(delta_text, 0)
            if end_id in previous_token_ids:
                # <think> and </think> both in previous:
                # reasoning is over, this delta is regular content.
                return DeltaMessage(content=delta_text)
            # <think> in previous, no </think> yet: reasoning continues.
            return DeltaMessage(reasoning_content=delta_text)

        if start_id in delta_token_ids:
            if end_id in delta_token_ids:
                # <think> and </think> both in delta: take the text between.
                start_index = delta_text.find(self.think_start_token)
                reasoning_begin = start_index + len(self.think_start_token)
                return self._split_delta_at_think_end(delta_text,
                                                      reasoning_begin)
            # <think> in delta, no </think> in delta: reasoning continues.
            return DeltaMessage(reasoning_content=delta_text)

        # No <think> seen so far: thinking is disabled, just content.
        return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Extract reasoning content from the model output.

        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        ``request`` is not read by this method.

        Returns:
            tuple[Optional[str], Optional[str]]: reasoning content and content
        """
        start_token = self.think_start_token
        end_token = self.think_end_token

        # Both markers must be present, otherwise everything is content.
        if not (start_token in model_output and end_token in model_output):
            return None, model_output

        # Drop the first <think> and whatever precedes it.
        before_start, found_start, after_start = model_output.partition(
            start_token)
        remainder = after_start if found_start else before_start

        # </think> may have occurred only before <think>; then there is no
        # reasoning section in the remainder.
        if end_token not in remainder:
            return None, remainder

        reasoning_content, _, content = remainder.partition(end_token)
        return reasoning_content, content or None
\ No newline at end of file
vllm/worker/worker.py
View file @
751c492c
...
@@ -76,6 +76,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -76,6 +76,7 @@ class Worker(LocalOrDistributedWorkerBase):
"mlp_speculator"
,
"mlp_speculator"
,
"eagle"
,
"eagle"
,
"deepseek_mtp"
,
"deepseek_mtp"
,
"glm4_moe_mtp"
,
"mimo_mtp"
))
\
"mimo_mtp"
))
\
else
{
"return_hidden_states"
:
True
}
else
{
"return_hidden_states"
:
True
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment