Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b140416a
Unverified
Commit
b140416a
authored
Jul 11, 2025
by
Asher
Committed by
GitHub
Jul 10, 2025
Browse files
[Model] Add reason parser for Hunyuan A13B Model. (#20625)
Signed-off-by:
Asher Zhang
<
asherszhang@tencent.com
>
parent
5b8366b6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
402 additions
and
0 deletions
+402
-0
tests/reasoning/test_hunyuan_reasoning_parser.py
tests/reasoning/test_hunyuan_reasoning_parser.py
+162
-0
vllm/reasoning/__init__.py
vllm/reasoning/__init__.py
+2
-0
vllm/reasoning/hunyuan_a13b_reasoning_parser.py
vllm/reasoning/hunyuan_a13b_reasoning_parser.py
+238
-0
No files found.
tests/reasoning/test_hunyuan_reasoning_parser.py
0 → 100644
View file @
b140416a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
transformers
import
AutoTokenizer
from
tests.reasoning.utils
import
run_reasoning_extraction
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
parser_name
=
"hunyuan_a13b"
START_REASONING
=
"<think>
\n
"
START_RESPONSE
=
"
\n
</think>
\n
<answer>
\n
"
END_RESPONSE
=
"
\n
</answer>"
NO_REASONING_QUICK_THROUGHT
=
{
"output"
:
f
"
{
START_REASONING
}{
START_RESPONSE
}
This is the rest
{
END_RESPONSE
}
"
,
#noqa: E501
"reasoning_content"
:
None
,
"content"
:
"This is the rest"
,
}
SIMPLE_REASONING
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section
{
START_RESPONSE
}
This is the rest
{
END_RESPONSE
}
"
,
#noqa: E501
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"This is the rest"
,
}
COMPLETE_REASONING
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section
{
START_RESPONSE
}
"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
}
NO_REASONING
=
{
"output"
:
"This is content"
,
"reasoning_content"
:
None
,
"content"
:
"This is content"
,
}
MULTIPLE_LINES
=
{
"output"
:
f
"
{
START_REASONING
}
This
\n
That
{
START_RESPONSE
}
This is the rest
\n
That"
,
"reasoning_content"
:
"This
\n
That"
,
"content"
:
"This is the rest
\n
That"
,
}
REASONING_WITH_THINK
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section
{
START_RESPONSE
}
This is the rest"
,
#noqa: E501
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"This is the rest"
,
}
COMPLETE_REASONING_WITH_THINK
=
{
"output"
:
f
"
{
START_REASONING
}
This is a reasoning section
{
START_RESPONSE
}
"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
}
MULTIPLE_LINES_WITH_THINK
=
{
"output"
:
f
"
{
START_REASONING
}
This
\n
That
{
START_RESPONSE
}
This is the rest
\n
That"
,
"reasoning_content"
:
"This
\n
That"
,
"content"
:
"This is the rest
\n
That"
,
}
TEST_CASES
=
[
pytest
.
param
(
False
,
SIMPLE_REASONING
,
id
=
"simple_reasoning"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING
,
id
=
"complete_reasoning"
,
),
pytest
.
param
(
False
,
NO_REASONING
,
id
=
"no_reasoning"
,
),
pytest
.
param
(
False
,
NO_REASONING_QUICK_THROUGHT
,
id
=
"no_reasoning_quick"
),
pytest
.
param
(
False
,
MULTIPLE_LINES
,
id
=
"multiple_lines"
,
),
pytest
.
param
(
False
,
REASONING_WITH_THINK
,
id
=
"reasoning_with_think"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING_WITH_THINK
,
id
=
"complete_reasoning_with_think"
,
),
pytest
.
param
(
False
,
MULTIPLE_LINES_WITH_THINK
,
id
=
"multiple_lines_with_think"
,
),
pytest
.
param
(
True
,
SIMPLE_REASONING
,
id
=
"simple_reasoning_streaming"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING
,
id
=
"complete_reasoning_streaming"
,
),
pytest
.
param
(
True
,
NO_REASONING
,
id
=
"no_reasoning_streaming"
,
),
pytest
.
param
(
True
,
NO_REASONING_QUICK_THROUGHT
,
id
=
"no_reasoning_quick_stream"
),
pytest
.
param
(
True
,
MULTIPLE_LINES
,
id
=
"multiple_lines_streaming"
,
),
pytest
.
param
(
True
,
REASONING_WITH_THINK
,
id
=
"reasoning_with_think_streaming"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING_WITH_THINK
,
id
=
"complete_reasoning_with_think_streaming"
,
),
pytest
.
param
(
True
,
MULTIPLE_LINES_WITH_THINK
,
id
=
"multiple_lines_with_think_streaming"
,
),
]
# Global tokenizer initialization to avoid repeated loading
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"tencent/Hunyuan-A13B-Instruct"
,
trust_remote_code
=
True
)
@
pytest
.
mark
.
parametrize
(
"streaming, param_dict"
,
TEST_CASES
)
def
test_reasoning
(
streaming
:
bool
,
param_dict
:
dict
,
):
output
=
tokenizer
.
tokenize
(
param_dict
[
"output"
])
# decode everything to tokens
output_tokens
:
list
[
str
]
=
[
tokenizer
.
convert_tokens_to_string
([
token
])
for
token
in
output
]
parser
:
ReasoningParser
=
ReasoningParserManager
.
get_reasoning_parser
(
parser_name
)(
tokenizer
)
reasoning
,
content
=
run_reasoning_extraction
(
parser
,
output_tokens
,
streaming
=
streaming
)
assert
reasoning
==
param_dict
[
"reasoning_content"
]
assert
content
==
param_dict
[
"content"
]
vllm/reasoning/__init__.py
View file @
b140416a
...
...
@@ -4,6 +4,7 @@
from
.abs_reasoning_parsers
import
ReasoningParser
,
ReasoningParserManager
from
.deepseek_r1_reasoning_parser
import
DeepSeekR1ReasoningParser
from
.granite_reasoning_parser
import
GraniteReasoningParser
from
.hunyuan_a13b_reasoning_parser
import
HunyuanA13BReasoningParser
from
.qwen3_reasoning_parser
import
Qwen3ReasoningParser
__all__
=
[
...
...
@@ -11,5 +12,6 @@ __all__ = [
"ReasoningParserManager"
,
"DeepSeekR1ReasoningParser"
,
"GraniteReasoningParser"
,
"HunyuanA13BReasoningParser"
,
"Qwen3ReasoningParser"
,
]
vllm/reasoning/hunyuan_a13b_reasoning_parser.py
0 → 100644
View file @
b140416a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
re
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
)
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParser
,
ReasoningParserManager
logger
=
init_logger
(
__name__
)
@
ReasoningParserManager
.
register_module
(
"hunyuan_a13b"
)
class
HunyuanA13BReasoningParser
(
ReasoningParser
):
"""
Reasoning parser for Hunyuan A13B Model
HunyuanReasoningParser
This class implements a reasoning parser specifically designed
for the Hunyuan A13B Model. It is responsible for parsing and
extracting structured reasoning and answer segments from model
outputs that follow a specific pattern.
Key Features:
- For non-stream output , Recognizes and extracts reasoning ("think")
and answer ("answer") sections from text using regular expressions.
- For stream process, it require a token id sequences to change the
reasoning state and other state so it maintains internal state to
manage parsing across multiple token.
think start: "<think>
\n
": [14023, 771, 397]
think ends: "
\n
</think>
\n
<answer>
\n
": [198, 524, 27963, 397, 27, 9399, 397]
response ends: "
\n
</answer>": [524, 9399, 29]
"""
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
super
().
__init__
(
tokenizer
)
self
.
think_start_expr
=
r
"<think>\n"
self
.
think_end_expr
=
r
"\n</think>\n"
self
.
response_start_expr
=
r
"\n</think>\n<answer>\n"
self
.
response_end_expr
=
r
"\n</answer>"
self
.
full_match_reasoning_regex
=
re
.
compile
(
rf
"(?:
{
self
.
think_start_expr
}
(.*?)
{
self
.
response_start_expr
}
)?(.*?)
{
self
.
response_end_expr
}
"
,
re
.
DOTALL
)
self
.
half_match_reasoning_regex
=
re
.
compile
(
rf
"
{
self
.
think_start_expr
}
(.*?)
{
self
.
response_start_expr
}
(.*)"
,
re
.
DOTALL
)
self
.
think_start_ids
=
[
14023
,
771
,
397
]
self
.
think_start_ids_fast
=
[
14023
,
771
,
1363
]
self
.
response_start_ids
=
[
198
,
524
,
27963
,
397
,
27
,
9399
,
397
]
self
.
response_start_ids_fast
=
[
524
,
27963
,
397
,
27
,
9399
,
397
]
self
.
response_end_ids
=
[
198
,
524
,
9399
,
29
]
self
.
fast_think_ids
=
[
14023
,
771
,
1363
,
524
,
27963
,
397
,
27
,
9399
,
397
]
# when state change, send out all the buffered text in last state
self
.
buffered_text
=
[]
self
.
buffered_ids
=
[]
self
.
current_state
=
"reasoning"
self
.
all_states
=
[
"reasoning"
,
"response"
]
self
.
current_state
=
"idle"
self
.
expected_sequence
=
self
.
think_start_ids
# this sequence only for the think start, it has two way to start.
self
.
expected_sequence_side
=
self
.
think_start_ids_fast
self
.
sequence_index
=
0
self
.
token_buffer
=
[]
self
.
text_buffer
=
""
def
is_reasoning_end
(
self
,
input_ids
:
list
[
int
])
->
bool
:
return
self
.
current_state
==
"response"
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
tuple
[
Optional
[
str
],
Optional
[
str
]]:
"""Extract the reasoning content & content sections, respectively.
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionRequest): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
reasoning content and non-reasoning content.
"""
re_match
=
self
.
full_match_reasoning_regex
.
findall
(
model_output
)
if
re_match
:
reasoning_content
,
response_content
=
re_match
[
0
]
if
len
(
reasoning_content
)
==
0
:
reasoning_content
=
None
if
len
(
response_content
)
==
0
:
response_content
=
None
return
reasoning_content
,
response_content
fallback_regex
=
self
.
half_match_reasoning_regex
fallback_match
=
fallback_regex
.
findall
(
model_output
)
if
fallback_match
:
reasoning_content
,
response_content
=
fallback_match
[
0
]
if
response_content
.
endswith
(
self
.
response_end_expr
):
response_content
=
response_content
[:
-
len
(
self
.
response_end_expr
)]
if
len
(
reasoning_content
)
==
0
:
reasoning_content
=
None
if
len
(
response_content
)
==
0
:
response_content
=
None
return
reasoning_content
,
response_content
return
None
,
model_output
def
_is_strict_increasing_subsequence
(
self
,
subsequence
:
Sequence
[
int
],
sequence
:
Sequence
[
int
])
->
bool
:
if
not
subsequence
:
return
False
sub_idx
=
0
for
num
in
sequence
:
if
sub_idx
<
len
(
subsequence
)
and
num
==
subsequence
[
sub_idx
]:
sub_idx
+=
1
return
sub_idx
==
len
(
subsequence
)
def
extract_reasoning_content_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
)
->
Union
[
DeltaMessage
,
None
]:
"""Extract content using token ID sequence state machine"""
# Define sequences
think_start_sequence
=
self
.
think_start_ids
response_start_sequence
=
self
.
response_start_ids
response_end_sequence
=
self
.
response_end_ids
assert
(
len
(
delta_token_ids
)
==
1
)
# Process each token in the delta
token
=
delta_token_ids
[
0
]
def
check_token_with_sequence
(
token
):
if
self
.
current_state
==
"idle"
or
self
.
current_state
==
"think"
:
return
(
token
==
self
.
expected_sequence
[
self
.
sequence_index
]
or
token
==
\
self
.
expected_sequence_side
[
self
.
sequence_index
])
else
:
return
token
==
self
.
expected_sequence
[
self
.
sequence_index
]
def
check_last_token
(
token
):
if
self
.
current_state
==
"idle"
or
self
.
current_state
==
"think"
:
# only return true if it's judge using a side sequence.
if
(
self
.
sequence_index
-
1
<
len
(
self
.
expected_sequence_side
)
and
token
==
self
.
expected_sequence_side
[
self
.
sequence_index
-
1
]):
return
self
.
sequence_index
==
len
(
self
.
expected_sequence_side
)
else
:
return
self
.
sequence_index
==
len
(
self
.
expected_sequence
)
else
:
return
self
.
sequence_index
==
len
(
self
.
expected_sequence
)
# Check if token matches expected sequence
token_in_state_seq
=
check_token_with_sequence
(
token
)
if
token_in_state_seq
:
# Store matching token
self
.
token_buffer
.
append
(
token
)
self
.
text_buffer
+=
delta_text
self
.
sequence_index
+=
1
## state change from idle->think->response->idle
# Check if sequence fully matched
if
check_last_token
(
token
):
# State transition
if
self
.
current_state
==
"idle"
:
self
.
current_state
=
"think"
self
.
expected_sequence
=
response_start_sequence
self
.
expected_sequence_side
=
self
.
response_start_ids_fast
elif
self
.
current_state
==
"think"
:
self
.
current_state
=
"response"
self
.
expected_sequence
=
response_end_sequence
elif
self
.
current_state
==
"response"
:
self
.
current_state
=
"idle"
self
.
expected_sequence
=
think_start_sequence
self
.
expected_sequence_side
=
self
.
think_start_ids_fast
# Reset matching state
self
.
sequence_index
=
0
self
.
token_buffer
=
[]
self
.
text_buffer
=
""
# Do not send content for state transition texts.
else
:
# Sequence broken - handle buffered content
if
self
.
token_buffer
and
len
(
self
.
token_buffer
)
>
0
:
# Send buffered tokens
buffered_content
=
self
.
text_buffer
+
delta_text
# Reset matching state
self
.
sequence_index
=
0
self
.
token_buffer
=
[]
self
.
text_buffer
=
""
# Return content based on current state
if
self
.
current_state
==
"think"
:
return
DeltaMessage
(
reasoning_content
=
buffered_content
,
content
=
None
)
else
:
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
buffered_content
)
else
:
# No buffered content, send normally
if
self
.
current_state
==
"think"
:
return
DeltaMessage
(
reasoning_content
=
delta_text
,
content
=
None
)
else
:
return
DeltaMessage
(
reasoning_content
=
None
,
content
=
delta_text
)
# If no content to send in this delta
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment