Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
782505ed
Unverified
Commit
782505ed
authored
Oct 13, 2025
by
CSWYF3634076
Committed by
GitHub
Oct 13, 2025
Browse files
[Model] Add reasoning_parser and tool_parser for Ernie45 thinking (#25027)
Signed-off-by:
wangyafeng
<
wangyafeng@baidu.com
>
parent
98f30b8c
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
870 additions
and
0 deletions
+870
-0
docs/features/reasoning_outputs.md
docs/features/reasoning_outputs.md
+2
-0
tests/reasoning/test_ernie45_reasoning_parser.py
tests/reasoning/test_ernie45_reasoning_parser.py
+124
-0
tests/tool_use/test_ernie45_moe_tool_parser.py
tests/tool_use/test_ernie45_moe_tool_parser.py
+359
-0
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+2
-0
vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
+212
-0
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+2
-0
vllm/reasoning/ernie45_reasoning_parser.py
vllm/reasoning/ernie45_reasoning_parser.py
+169
-0
No files found.
docs/features/reasoning_outputs.md
View file @
782505ed
...
@@ -11,6 +11,8 @@ vLLM currently supports the following reasoning models:
...
@@ -11,6 +11,8 @@ vLLM currently supports the following reasoning models:
| Model Series | Parser Name | Structured Output Support | Tool Calling |
| Model Series | Parser Name | Structured Output Support | Tool Calling |
|--------------|-------------|------------------|-------------|
|--------------|-------------|------------------|-------------|
|
[
DeepSeek R1 series
](
https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d
)
|
`deepseek_r1`
|
`json`
,
`regex`
| ❌ |
|
[
DeepSeek R1 series
](
https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d
)
|
`deepseek_r1`
|
`json`
,
`regex`
| ❌ |
|
[
ERNIE-4.5-VL series
](
https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT
)
|
`ernie45`
|
`json`
,
`regex`
| ❌ |
|
[
ERNIE-4.5-21B-A3B-Thinking
](
https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking
)
|
`ernie45`
|
`json`
,
`regex`
| ✅ |
|
[
QwQ-32B
](
https://huggingface.co/Qwen/QwQ-32B
)
|
`deepseek_r1`
|
`json`
,
`regex`
| ✅ |
|
[
QwQ-32B
](
https://huggingface.co/Qwen/QwQ-32B
)
|
`deepseek_r1`
|
`json`
,
`regex`
| ✅ |
|
[
IBM Granite 3.2 language models
](
https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a
)
|
`granite`
| ❌ | ❌ |
|
[
IBM Granite 3.2 language models
](
https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a
)
|
`granite`
| ❌ | ❌ |
|
[
Qwen3 series
](
https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f
)
|
`qwen3`
|
`json`
,
`regex`
| ✅ |
|
[
Qwen3 series
](
https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f
)
|
`qwen3`
|
`json`
,
`regex`
| ✅ |
...
...
tests/reasoning/test_ernie45_reasoning_parser.py
0 → 100644
View file @
782505ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
transformers
import
AutoTokenizer
from
tests.reasoning.utils
import
run_reasoning_extraction
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
parser_name
=
"ernie45"
REASONING_MODEL_NAME
=
"baidu/ERNIE-4.5-21B-A3B-Thinking"
@
pytest
.
fixture
(
scope
=
"module"
)
def
ernie45_tokenizer
():
return
AutoTokenizer
.
from_pretrained
(
REASONING_MODEL_NAME
)
# 带 </think>,非stream
WITH_THINK
=
{
"output"
:
"abc</think>def"
,
"reasoning_content"
:
"abc"
,
"content"
:
"def"
,
}
# 带 </think>,stream
WITH_THINK_STREAM
=
{
"output"
:
"abc</think>def"
,
"reasoning_content"
:
"abc"
,
"content"
:
"def"
,
}
# without </think>, all is reasoning_content
WITHOUT_THINK
=
{
"output"
:
"abc"
,
"reasoning_content"
:
"abc"
,
"content"
:
None
,
}
# without </think>, all is reasoning_content
WITHOUT_THINK_STREAM
=
{
"output"
:
"abc"
,
"reasoning_content"
:
"abc"
,
"content"
:
None
,
}
COMPLETE_REASONING
=
{
"output"
:
"abc</think>"
,
"reasoning_content"
:
"abc"
,
"content"
:
None
,
}
MULTILINE_REASONING
=
{
"output"
:
"abc
\n
ABC</think>def
\n
DEF"
,
"reasoning_content"
:
"abc
\n
ABC"
,
"content"
:
"def
\n
DEF"
,
}
TEST_CASES
=
[
pytest
.
param
(
False
,
WITH_THINK
,
id
=
"with_think"
,
),
pytest
.
param
(
True
,
WITH_THINK_STREAM
,
id
=
"with_think_stream"
,
),
pytest
.
param
(
False
,
WITHOUT_THINK
,
id
=
"without_think"
,
),
pytest
.
param
(
True
,
WITHOUT_THINK_STREAM
,
id
=
"without_think_stream"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING
,
id
=
"complete_reasoning"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING
,
id
=
"complete_reasoning_stream"
,
),
pytest
.
param
(
False
,
MULTILINE_REASONING
,
id
=
"multiline_reasoning"
,
),
pytest
.
param
(
True
,
MULTILINE_REASONING
,
id
=
"multiline_reasoning_stream"
,
),
]
@
pytest
.
mark
.
parametrize
(
"streaming, param_dict"
,
TEST_CASES
)
def
test_reasoning
(
streaming
:
bool
,
param_dict
:
dict
,
ernie45_tokenizer
,
):
output
=
ernie45_tokenizer
.
tokenize
(
param_dict
[
"output"
])
output_tokens
:
list
[
str
]
=
[]
for
token
in
output
:
one_token
=
ernie45_tokenizer
.
convert_tokens_to_string
([
token
])
if
one_token
:
output_tokens
.
append
(
one_token
)
parser
:
ReasoningParser
=
ReasoningParserManager
.
get_reasoning_parser
(
parser_name
)(
ernie45_tokenizer
)
reasoning
,
content
=
run_reasoning_extraction
(
parser
,
output_tokens
,
streaming
=
streaming
)
print
()
assert
reasoning
==
param_dict
[
"reasoning_content"
]
assert
content
==
param_dict
[
"content"
]
tests/tool_use/test_ernie45_moe_tool_parser.py
0 → 100644
View file @
782505ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import
json
from
collections.abc
import
Generator
import
pytest
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
,
FunctionCall
,
ToolCall
,
)
from
vllm.entrypoints.openai.tool_parsers
import
Ernie45ToolParser
from
vllm.transformers_utils.detokenizer_utils
import
detokenize_incrementally
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
get_tokenizer
# Use a common model that is likely to be available
MODEL
=
"baidu/ERNIE-4.5-21B-A3B-Thinking"
@
pytest
.
fixture
(
scope
=
"module"
)
def
ernie45_tokenizer
():
return
get_tokenizer
(
tokenizer_name
=
MODEL
,
trust_remote_code
=
True
)
@
pytest
.
fixture
def
ernie45_tool_parser
(
ernie45_tokenizer
):
return
Ernie45ToolParser
(
ernie45_tokenizer
)
def
assert_tool_calls
(
actual_tool_calls
:
list
[
ToolCall
],
expected_tool_calls
:
list
[
ToolCall
]
):
assert
len
(
actual_tool_calls
)
==
len
(
expected_tool_calls
)
for
actual_tool_call
,
expected_tool_call
in
zip
(
actual_tool_calls
,
expected_tool_calls
):
assert
isinstance
(
actual_tool_call
.
id
,
str
)
assert
len
(
actual_tool_call
.
id
)
>
0
assert
actual_tool_call
.
type
==
"function"
assert
actual_tool_call
.
function
.
name
==
expected_tool_call
.
function
.
name
# Compare arguments as JSON objects to handle formatting differences
actual_args
=
json
.
loads
(
actual_tool_call
.
function
.
arguments
)
expected_args
=
json
.
loads
(
expected_tool_call
.
function
.
arguments
)
assert
actual_args
==
expected_args
def
test_extract_tool_calls_no_tools
(
ernie45_tool_parser
):
model_output
=
"This is a test"
extracted_tool_calls
=
ernie45_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
None
)
# type: ignore[arg-type]
assert
not
extracted_tool_calls
.
tools_called
assert
extracted_tool_calls
.
tool_calls
==
[]
assert
extracted_tool_calls
.
content
==
model_output
@
pytest
.
mark
.
parametrize
(
ids
=
[
"single_tool_call"
,
"multiple_tool_calls"
,
"tool_call_with_content_before"
,
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
(
"""<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_current_temperature"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Beijing"
,
}
),
)
)
],
None
,
),
(
"""<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_current_temperature"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Beijing"
,
}
),
)
),
ToolCall
(
function
=
FunctionCall
(
name
=
"get_temperature_unit"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Guangzhou"
,
"unit"
:
"c"
,
}
),
)
),
],
None
,
),
(
"""I need to call two tools to handle these two issues separately.
</think>
<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_current_temperature"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Beijing"
,
}
),
)
),
ToolCall
(
function
=
FunctionCall
(
name
=
"get_temperature_unit"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Guangzhou"
,
"unit"
:
"c"
,
}
),
)
),
],
"I need to call two tools to handle these two issues separately.
\n
</think>"
,
),
],
)
def
test_extract_tool_calls
(
ernie45_tool_parser
,
model_output
,
expected_tool_calls
,
expected_content
):
extracted_tool_calls
=
ernie45_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
None
)
# type: ignore[arg-type]
assert
extracted_tool_calls
.
tools_called
assert_tool_calls
(
extracted_tool_calls
.
tool_calls
,
expected_tool_calls
)
assert
extracted_tool_calls
.
content
==
expected_content
def
stream_delta_message_generator
(
ernie45_tool_parser
:
Ernie45ToolParser
,
ernie45_tokenizer
:
AnyTokenizer
,
model_output
:
str
,
request
:
ChatCompletionRequest
|
None
=
None
,
)
->
Generator
[
DeltaMessage
,
None
,
None
]:
all_token_ids
=
ernie45_tokenizer
.
encode
(
model_output
,
add_special_tokens
=
False
)
previous_text
=
""
previous_tokens
=
None
prefix_offset
=
0
read_offset
=
0
for
i
,
delta_token
in
enumerate
(
all_token_ids
):
delta_token_ids
=
[
delta_token
]
previous_token_ids
=
all_token_ids
[:
i
]
current_token_ids
=
all_token_ids
[:
i
+
1
]
(
new_tokens
,
delta_text
,
new_prefix_offset
,
new_read_offset
)
=
(
detokenize_incrementally
(
tokenizer
=
ernie45_tokenizer
,
all_input_ids
=
current_token_ids
,
prev_tokens
=
previous_tokens
,
prefix_offset
=
prefix_offset
,
read_offset
=
read_offset
,
skip_special_tokens
=
False
,
spaces_between_special_tokens
=
True
,
)
)
current_text
=
previous_text
+
delta_text
delta_message
=
ernie45_tool_parser
.
extract_tool_calls_streaming
(
previous_text
,
current_text
,
delta_text
,
previous_token_ids
,
current_token_ids
,
delta_token_ids
,
request
=
request
,
)
if
delta_message
:
yield
delta_message
previous_text
=
current_text
previous_tokens
=
(
previous_tokens
+
new_tokens
if
previous_tokens
else
new_tokens
)
prefix_offset
=
new_prefix_offset
read_offset
=
new_read_offset
@
pytest
.
mark
.
parametrize
(
ids
=
[
"single_tool_call"
,
"multiple_tool_calls"
,
"tool_call_with_content_before"
,
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
(
"""<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_current_temperature"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Beijing"
,
}
),
)
)
],
None
,
),
(
"""<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_current_temperature"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Beijing"
,
}
),
)
),
ToolCall
(
function
=
FunctionCall
(
name
=
"get_temperature_unit"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Guangzhou"
,
"unit"
:
"c"
,
}
),
)
),
],
None
,
),
(
"""I need to call two tools to handle these two issues separately.
</think>
<tool_call>
{"name": "get_current_temperature", "arguments": {"location": "Beijing"}}
</tool_call>
<tool_call>
{"name": "get_temperature_unit", "arguments": {"location": "Guangzhou", "unit": "c"}}
</tool_call>
"""
,
[
ToolCall
(
function
=
FunctionCall
(
name
=
"get_current_temperature"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Beijing"
,
}
),
)
),
ToolCall
(
function
=
FunctionCall
(
name
=
"get_temperature_unit"
,
arguments
=
json
.
dumps
(
{
"location"
:
"Guangzhou"
,
"unit"
:
"c"
,
}
),
)
),
],
"I need to call two tools to handle these two issues separately.
\n
</think>"
,
),
],
)
def
test_extract_tool_calls_streaming_incremental
(
ernie45_tool_parser
,
ernie45_tokenizer
,
model_output
,
expected_tool_calls
,
expected_content
,
):
"""Verify the Ernie45 Parser streaming behavior by verifying each chunk is as expected."""
# noqa: E501
request
=
ChatCompletionRequest
(
model
=
MODEL
,
messages
=
[],
tools
=
[])
tool_calls_dict
=
{}
for
delta_message
in
stream_delta_message_generator
(
ernie45_tool_parser
,
ernie45_tokenizer
,
model_output
,
request
):
if
(
delta_message
.
role
is
None
and
delta_message
.
content
is
None
and
delta_message
.
reasoning_content
is
None
and
len
(
delta_message
.
tool_calls
)
==
0
):
continue
tool_calls
=
delta_message
.
tool_calls
for
tool_call_chunk
in
tool_calls
:
index
=
tool_call_chunk
.
index
if
index
not
in
tool_calls_dict
:
if
tool_call_chunk
.
function
.
arguments
is
None
:
tool_call_chunk
.
function
.
arguments
=
""
tool_calls_dict
[
index
]
=
tool_call_chunk
else
:
tool_calls_dict
[
index
].
function
.
arguments
+=
tool_call_chunk
.
function
.
arguments
actual_tool_calls
=
list
(
tool_calls_dict
.
values
())
assert
len
(
actual_tool_calls
)
>
0
# check tool call format
assert_tool_calls
(
actual_tool_calls
,
expected_tool_calls
)
vllm/entrypoints/openai/tool_parsers/__init__.py
View file @
782505ed
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
.abstract_tool_parser
import
ToolParser
,
ToolParserManager
from
.abstract_tool_parser
import
ToolParser
,
ToolParserManager
from
.deepseekv3_tool_parser
import
DeepSeekV3ToolParser
from
.deepseekv3_tool_parser
import
DeepSeekV3ToolParser
from
.deepseekv31_tool_parser
import
DeepSeekV31ToolParser
from
.deepseekv31_tool_parser
import
DeepSeekV31ToolParser
from
.ernie45_tool_parser
import
Ernie45ToolParser
from
.glm4_moe_tool_parser
import
Glm4MoeModelToolParser
from
.glm4_moe_tool_parser
import
Glm4MoeModelToolParser
from
.granite_20b_fc_tool_parser
import
Granite20bFCToolParser
from
.granite_20b_fc_tool_parser
import
Granite20bFCToolParser
from
.granite_tool_parser
import
GraniteToolParser
from
.granite_tool_parser
import
GraniteToolParser
...
@@ -42,6 +43,7 @@ __all__ = [
...
@@ -42,6 +43,7 @@ __all__ = [
"Phi4MiniJsonToolParser"
,
"Phi4MiniJsonToolParser"
,
"DeepSeekV3ToolParser"
,
"DeepSeekV3ToolParser"
,
"DeepSeekV31ToolParser"
,
"DeepSeekV31ToolParser"
,
"Ernie45ToolParser"
,
"xLAMToolParser"
,
"xLAMToolParser"
,
"MinimaxToolParser"
,
"MinimaxToolParser"
,
"KimiK2ToolParser"
,
"KimiK2ToolParser"
,
...
...
vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py
0 → 100644
View file @
782505ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
collections.abc
import
Sequence
import
regex
as
re
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
,
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
ToolParserManager
,
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
@
ToolParserManager
.
register_module
(
"ernie45"
)
class
Ernie45ToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
"""
Ernie thinking model format:
abc
\n
</think>
\n\n\n
<tool_call>
\n
def
\n
</tool_call>
\n
"""
super
().
__init__
(
tokenizer
)
self
.
current_tool_name_sent
=
False
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
self
.
current_tool_id
=
-
1
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
self
.
think_end_token
=
"</think>"
self
.
response_start_token
:
str
=
"<response>"
self
.
response_end_token
:
str
=
"</response>"
self
.
tool_call_start_token
=
"<tool_call>"
self
.
tool_call_end_token
=
"</tool_call>"
self
.
tool_calls_start_token
=
self
.
tool_call_start_token
self
.
newline_token
:
str
=
"<0x0A>"
self
.
tool_call_regex
=
re
.
compile
(
r
"<tool_call>\s*(?P<json>\{.*?\})\s*</tool_call>"
,
re
.
DOTALL
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self
.
think_end_token_id
=
self
.
vocab
.
get
(
self
.
think_end_token
)
self
.
response_start_token_id
=
self
.
vocab
.
get
(
self
.
response_start_token
)
self
.
response_end_token_id
=
self
.
vocab
.
get
(
self
.
response_end_token
)
self
.
tool_call_start_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_start_token
)
self
.
tool_call_end_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_end_token
)
self
.
newline_token_id
=
self
.
vocab
.
get
(
self
.
newline_token
)
self
.
parser_token_ids
=
[
self
.
think_end_token_id
,
self
.
response_start_token_id
,
self
.
response_end_token_id
,
]
self
.
_buffer
=
""
def
extract_tool_calls
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
# sanity check; avoid unnecessary processing
if
self
.
tool_calls_start_token
not
in
model_output
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
else
:
try
:
tool_call_json_list
=
self
.
tool_call_regex
.
findall
(
model_output
)
tool_calls
=
[]
for
tool_call_json
in
tool_call_json_list
:
tool_call_dict
=
json
.
loads
(
tool_call_json
)
args_str
=
json
.
dumps
(
tool_call_dict
.
get
(
"arguments"
,
{}),
ensure_ascii
=
False
)
tool_calls
.
append
(
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
tool_call_dict
.
get
(
"name"
,
""
),
arguments
=
args_str
,
),
)
)
content
=
model_output
[
:
model_output
.
find
(
self
.
tool_calls_start_token
)
].
rstrip
(
"
\n
"
)
return
ExtractedToolCallInformation
(
tools_called
=
True
,
tool_calls
=
tool_calls
,
content
=
content
if
content
else
None
,
)
except
Exception
:
logger
.
exception
(
"Error in extracting tool call from response."
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
self
.
_buffer
+=
delta_text
cur_text
=
self
.
_buffer
start_idx
=
cur_text
.
find
(
self
.
tool_call_start_token
)
if
start_idx
==
-
1
:
self
.
_buffer
=
""
# At least one toolcall has been completed
if
self
.
current_tool_id
>
0
:
cur_text
=
""
if
self
.
current_tool_id
==
-
1
and
all
(
token_id
==
self
.
newline_token_id
for
token_id
in
previous_token_ids
):
cur_text
=
cur_text
.
strip
(
"
\n
"
)
# handle <response> </response> when tool_call is not triggered
# cur_text === delta_text
content
=
cur_text
if
self
.
response_start_token_id
in
delta_token_ids
:
content
=
content
.
lstrip
(
"
\n
"
)
response_start_idx
=
content
.
find
(
self
.
response_start_token
)
content
=
content
[
response_start_idx
+
len
(
self
.
response_start_token
)
:]
# if have </response>, remove it
response_end_idx
=
content
.
rfind
(
self
.
response_end_token
)
if
response_end_idx
!=
-
1
:
content
=
content
[:
response_end_idx
]
elif
self
.
response_end_token_id
in
delta_token_ids
:
response_end_idx
=
content
.
rfind
(
self
.
response_end_token
)
content
=
content
[:
response_end_idx
]
# remove \n after </think> or <response> or </response>
if
(
len
(
previous_token_ids
)
>
0
and
previous_token_ids
[
-
1
]
in
self
.
parser_token_ids
)
and
(
len
(
delta_token_ids
)
>
0
and
delta_token_ids
[
0
]
==
self
.
newline_token_id
):
content
=
content
.
lstrip
(
"
\n
"
)
return
DeltaMessage
(
content
=
content
if
content
else
None
)
logger
.
debug
(
"cur_text = %s"
,
cur_text
)
end_idx
=
cur_text
.
find
(
self
.
tool_call_end_token
)
if
end_idx
!=
-
1
:
if
self
.
current_tool_id
==
-
1
:
self
.
current_tool_id
=
0
self
.
prev_tool_call_arr
=
[]
self
.
streamed_args_for_tool
=
[]
while
len
(
self
.
prev_tool_call_arr
)
<=
self
.
current_tool_id
:
self
.
prev_tool_call_arr
.
append
({})
while
len
(
self
.
streamed_args_for_tool
)
<=
self
.
current_tool_id
:
self
.
streamed_args_for_tool
.
append
(
""
)
extracted_tool_calls
=
self
.
extract_tool_calls
(
cur_text
[:
end_idx
+
len
(
self
.
tool_call_end_token
)],
request
)
if
len
(
extracted_tool_calls
.
tool_calls
)
==
0
:
logger
.
warning
(
"Failed to extract any tool calls."
)
return
None
tool_call
=
extracted_tool_calls
.
tool_calls
[
0
]
self
.
prev_tool_call_arr
[
self
.
current_tool_id
]
=
{
"name"
:
tool_call
.
function
.
name
,
"arguments"
:
json
.
loads
(
tool_call
.
function
.
arguments
),
}
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
=
(
tool_call
.
function
.
arguments
)
delta
=
DeltaMessage
(
content
=
extracted_tool_calls
.
content
,
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
id
=
tool_call
.
id
,
type
=
tool_call
.
type
,
function
=
DeltaFunctionCall
(
name
=
tool_call
.
function
.
name
,
arguments
=
tool_call
.
function
.
arguments
,
),
)
],
)
self
.
current_tool_id
+=
1
self
.
_buffer
=
cur_text
[
end_idx
+
len
(
self
.
tool_call_end_token
)
:]
return
delta
self
.
_buffer
=
cur_text
[
start_idx
:]
content
=
cur_text
[:
start_idx
].
rstrip
(
"
\n
"
)
return
DeltaMessage
(
content
=
content
if
content
else
None
)
vllm/reasoning/__init__.py
View file @
782505ed
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
.abs_reasoning_parsers
import
ReasoningParser
,
ReasoningParserManager
from
.abs_reasoning_parsers
import
ReasoningParser
,
ReasoningParserManager
from
.basic_parsers
import
BaseThinkingReasoningParser
from
.basic_parsers
import
BaseThinkingReasoningParser
from
.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
from
.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
from
.ernie45_reasoning_parser
import
Ernie45ReasoningParser
from
.glm4_moe_reasoning_parser
import
Glm4MoeModelReasoningParser
from
.glm4_moe_reasoning_parser
import
Glm4MoeModelReasoningParser
from
.gptoss_reasoning_parser
import
GptOssReasoningParser
from
.gptoss_reasoning_parser
import
GptOssReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
...
@@ -19,6 +20,7 @@ __all__ = [
...
@@ -19,6 +20,7 @@ __all__ = [
"BaseThinkingReasoningParser"
,
"BaseThinkingReasoningParser"
,
"ReasoningParserManager"
,
"ReasoningParserManager"
,
"DeepSeekR1ReasoningParser"
,
"DeepSeekR1ReasoningParser"
,
"Ernie45ReasoningParser"
,
"GraniteReasoningParser"
,
"GraniteReasoningParser"
,
"HunyuanA13BReasoningParser"
,
"HunyuanA13BReasoningParser"
,
"Qwen3ReasoningParser"
,
"Qwen3ReasoningParser"
,
...
...
vllm/reasoning/ernie45_reasoning_parser.py
0 → 100644
View file @
782505ed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
,
DeltaMessage
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.reasoning.basic_parsers
import
BaseThinkingReasoningParser
logger
=
init_logger
(
__name__
)
@
ReasoningParserManager
.
register_module
(
"ernie45"
)
class
Ernie45ReasoningParser
(
BaseThinkingReasoningParser
):
"""
Reasoning parser for Ernie45 thinking model.
The Ernie45 thinking model ouput format is
abc
\n
</think>
\n\n
<response>
\n
def
\n
</response>
\n
or abc
\n
</think>
\n
def
"""
response_start_token
:
str
=
"<response>"
response_end_token
:
str
=
"</response>"
newline_token
:
str
=
"<0x0A>"
@
property
def
start_token
(
self
)
->
str
:
"""The token that starts reasoning content."""
return
"<think>"
@
property
def
end_token
(
self
)
->
str
:
"""The token that ends reasoning content."""
return
"</think>"
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
super
().
__init__
(
tokenizer
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction."
)
self
.
start_token_id
=
self
.
vocab
.
get
(
self
.
start_token
)
self
.
end_token_id
=
self
.
vocab
.
get
(
self
.
end_token
)
self
.
response_start_token_id
=
self
.
vocab
.
get
(
self
.
response_start_token
)
self
.
response_end_token_id
=
self
.
vocab
.
get
(
self
.
response_end_token
)
self
.
newline_token_id
=
self
.
vocab
.
get
(
self
.
newline_token
)
self
.
parser_token_ids
=
[
self
.
end_token_id
,
self
.
response_end_token_id
]
if
self
.
start_token_id
is
None
or
self
.
end_token_id
is
None
:
raise
RuntimeError
(
"Ernie45 reasoning parser could not locate think start/end "
"tokens in the tokenizer!"
)
def
extract_reasoning_content_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
DeltaMessage
|
None
:
"""
Extract reasoning content from a delta message.
Handles streaming output where previous + delta = current.
Uses token IDs for faster processing.
The Ernie45 thinking model ouput format is
abc
\n
</think>
\n\n
<response>
\n
def
\n
</response>
\n
or abc
\n
</think>
\n
def
- 'abc' goes to reasoning_content
- 'def' goes to content
"""
# Skip single special tokens
if
len
(
delta_token_ids
)
==
1
and
(
delta_token_ids
[
0
]
in
[
self
.
start_token_id
,
self
.
end_token_id
,
self
.
response_start_token_id
,
self
.
response_end_token_id
,
]
):
return
None
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
if
self
.
end_token_id
in
delta_token_ids
:
# </think> in delta with more tokens,
# extract reasoning content and content
think_end_index
=
delta_text
.
find
(
self
.
end_token
)
reasoning_content
=
delta_text
[:
think_end_index
]
content
=
delta_text
[
think_end_index
+
len
(
self
.
end_token
)
:]
content
=
content
.
lstrip
(
"
\n
"
)
response_start_idx
=
content
.
find
(
self
.
response_start_token
)
response_end_idx
=
content
.
rfind
(
self
.
response_end_token
)
if
response_start_idx
!=
-
1
:
content
=
content
[
response_start_idx
+
len
(
self
.
response_start_token
)
:]
if
response_end_idx
!=
-
1
:
content
=
content
[:
response_end_idx
]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
,
)
elif
self
.
end_token_id
in
previous_token_ids
:
# </think> in previous, thinking content ends
content
=
delta_text
if
self
.
response_start_token_id
in
delta_token_ids
:
content
=
content
.
lstrip
(
"
\n
"
)
response_start_idx
=
content
.
find
(
self
.
response_start_token
)
content
=
content
[
response_start_idx
+
len
(
self
.
response_start_token
)
:]
# if have </response>, remove it
response_end_idx
=
content
.
rfind
(
self
.
response_end_token
)
if
response_end_idx
!=
-
1
:
content
=
content
[:
response_end_idx
]
elif
self
.
response_end_token_id
in
delta_token_ids
:
response_end_idx
=
content
.
rfind
(
self
.
response_end_token
)
content
=
content
[:
response_end_idx
]
# remove \n after </think> or </response>
if
previous_token_ids
[
-
1
]
in
self
.
parser_token_ids
and
(
len
(
delta_token_ids
)
>
0
and
delta_token_ids
[
0
]
==
self
.
newline_token_id
):
content
=
content
.
lstrip
(
"
\n
"
)
# remove \n after </think>\n
if
(
len
(
previous_token_ids
)
>
1
and
previous_token_ids
[
-
2
]
==
self
.
end_token_id
)
and
(
len
(
delta_token_ids
)
>
0
and
delta_token_ids
[
0
]
==
self
.
newline_token_id
):
content
=
content
.
lstrip
(
"
\n
"
)
return
DeltaMessage
(
content
=
content
if
content
else
None
)
else
:
# no </think> in previous or delta, reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
str
|
None
,
str
|
None
]:
"""
Extract reasoning content from the model output.
The Ernie45 thinking model ouput format is
abc
\n
</think>
\n\n\n
<response>
\n
def
\n
</response>
\n
or abc
\n
</think>
\n
def
- 'abc' goes to reasoning_content
- 'def' goes to content
Returns:
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
reasoning_content
,
content
=
super
().
extract_reasoning_content
(
model_output
,
request
)
if
content
:
start_idx
=
content
.
find
(
self
.
response_start_token
)
end_idx
=
content
.
rfind
(
self
.
response_end_token
)
# Simultaneously existing and in the correct order
if
start_idx
!=
-
1
and
end_idx
!=
-
1
and
start_idx
<
end_idx
:
content
=
content
[
start_idx
+
len
(
self
.
response_start_token
)
:
end_idx
]
final_content
=
content
or
None
return
reasoning_content
,
final_content
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment