Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fc6485d2
Unverified
Commit
fc6485d2
authored
Feb 11, 2025
by
Ce Gao
Committed by
GitHub
Feb 11, 2025
Browse files
[Bugfix]: Reasoning output bug according to the chat template change (#13025)
Signed-off-by:
Ce Gao
<
cegao@tensorchord.ai
>
parent
78a141d7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
129 additions
and
45 deletions
+129
-45
examples/online_serving/openai_chat_completion_with_reasoning.py
...s/online_serving/openai_chat_completion_with_reasoning.py
+4
-4
tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
...nai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
+90
-18
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
.../openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
+35
-23
No files found.
examples/online_serving/openai_chat_completion_with_reasoning.py
View file @
fc6485d2
...
...
@@ -36,8 +36,8 @@ response = client.chat.completions.create(model=model, messages=messages)
reasoning_content
=
response
.
choices
[
0
].
message
.
reasoning_content
content
=
response
.
choices
[
0
].
message
.
content
print
(
"reasoning_content:"
,
reasoning_content
)
print
(
"content:"
,
content
)
print
(
"reasoning_content
for Round 1
:"
,
reasoning_content
)
print
(
"content
for Round 1
:"
,
content
)
# Round 2
messages
.
append
({
"role"
:
"assistant"
,
"content"
:
content
})
...
...
@@ -50,5 +50,5 @@ response = client.chat.completions.create(model=model, messages=messages)
reasoning_content
=
response
.
choices
[
0
].
message
.
reasoning_content
content
=
response
.
choices
[
0
].
message
.
content
print
(
"reasoning_content:"
,
reasoning_content
)
print
(
"content:"
,
content
)
print
(
"reasoning_content
for Round 2
:"
,
reasoning_content
)
print
(
"content
for Round 2
:"
,
content
)
tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
View file @
fc6485d2
...
...
@@ -15,32 +15,62 @@ start_token = "<think>"
end_token
=
"</think>"
SIMPLE_REASONING
=
{
"output"
:
"
<think>
This is a reasoning section</think>This is the rest"
,
"output"
:
"This is a reasoning section</think>This is the rest"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"This is the rest"
,
}
COMPLETE_REASONING
=
{
"output"
:
"
<think>
This is a reasoning section</think>"
,
"output"
:
"This is a reasoning section</think>"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
}
NO_REASONING
=
{
"output"
:
"This is
a reasoning section
"
,
"output"
:
"This is
content
"
,
"reasoning_content"
:
None
,
"content"
:
"This is a reasoning section"
,
"content"
:
"This is content"
,
}
NO_REASONING_STREAMING
=
{
"output"
:
"This is a reasoning section"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
}
MULTIPLE_LINES
=
{
"output"
:
"
<think>
This
\n
That</think>This is the rest
\n
That"
,
"output"
:
"This
\n
That</think>This is the rest
\n
That"
,
"reasoning_content"
:
"This
\n
That"
,
"content"
:
"This is the rest
\n
That"
,
}
SHORTEST_REASONING_NO_STREAMING
=
{
"output"
:
"<
think><
/think>This is the rest"
,
"output"
:
"</think>This is the rest"
,
"reasoning_content"
:
""
,
"content"
:
"This is the rest"
,
}
SHORTEST_REASONING
=
{
"output"
:
"<think></think>This is the rest"
,
"output"
:
"</think>This is the rest"
,
"reasoning_content"
:
None
,
"content"
:
"This is the rest"
,
}
REASONING_WITH_THINK
=
{
"output"
:
"<think>This is a reasoning section</think>This is the rest"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
"This is the rest"
,
}
COMPLETE_REASONING_WITH_THINK
=
{
"output"
:
"<think>This is a reasoning section</think>"
,
"reasoning_content"
:
"This is a reasoning section"
,
"content"
:
None
,
}
MULTIPLE_LINES_WITH_THINK
=
{
"output"
:
"<think>This
\n
That</think>This is the rest
\n
That"
,
"reasoning_content"
:
"This
\n
That"
,
"content"
:
"This is the rest
\n
That"
,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK
=
{
"output"
:
"</think>This is the rest"
,
"reasoning_content"
:
""
,
"content"
:
"This is the rest"
,
}
SHORTEST_REASONING_WITH_THINK
=
{
"output"
:
"</think>This is the rest"
,
"reasoning_content"
:
None
,
"content"
:
"This is the rest"
,
}
...
...
@@ -49,37 +79,37 @@ TEST_CASES = [
pytest
.
param
(
False
,
SIMPLE_REASONING
,
id
=
"simple_
st
rea
m
ing"
,
id
=
"simple_rea
son
ing"
,
),
pytest
.
param
(
True
,
SIMPLE_REASONING
,
id
=
"simple_streaming"
,
id
=
"simple_
reasoning_
streaming"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING
,
id
=
"complete_
st
rea
m
ing"
,
id
=
"complete_rea
son
ing"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING
,
id
=
"complete_streaming"
,
id
=
"complete_
reasoning_
streaming"
,
),
pytest
.
param
(
False
,
NO_REASONING
,
id
=
"no_
st
rea
ming
"
,
id
=
"no_rea
soning_token
"
,
),
pytest
.
param
(
True
,
NO_REASONING
,
id
=
"no_streaming"
,
NO_REASONING
_STREAMING
,
id
=
"no_
reasoning_token_
streaming"
,
),
pytest
.
param
(
False
,
MULTIPLE_LINES
,
id
=
"multiple_lines
_streaming
"
,
id
=
"multiple_lines"
,
),
pytest
.
param
(
True
,
...
...
@@ -89,23 +119,65 @@ TEST_CASES = [
pytest
.
param
(
True
,
SHORTEST_REASONING
,
id
=
"shortest
_streaming
"
,
id
=
"shortest"
,
),
pytest
.
param
(
False
,
SHORTEST_REASONING_NO_STREAMING
,
id
=
"shortest_streaming"
,
),
pytest
.
param
(
False
,
REASONING_WITH_THINK
,
id
=
"reasoning_with_think"
,
),
pytest
.
param
(
True
,
REASONING_WITH_THINK
,
id
=
"reasoning_with_think_streaming"
,
),
pytest
.
param
(
False
,
COMPLETE_REASONING_WITH_THINK
,
id
=
"complete_reasoning_with_think"
,
),
pytest
.
param
(
True
,
COMPLETE_REASONING_WITH_THINK
,
id
=
"complete_reasoning_with_think_streaming"
,
),
pytest
.
param
(
False
,
MULTIPLE_LINES_WITH_THINK
,
id
=
"multiple_lines_with_think"
,
),
pytest
.
param
(
True
,
MULTIPLE_LINES_WITH_THINK
,
id
=
"multiple_lines_with_think_streaming"
,
),
pytest
.
param
(
False
,
SHORTEST_REASONING_NO_STREAMING_WITH_THINK
,
id
=
"shortest_with_think"
,
),
pytest
.
param
(
True
,
SHORTEST_REASONING_WITH_THINK
,
id
=
"shortest_with_think_streaming"
,
),
]
# Global tokenizer initialization to avoid repeated loading
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/opt-125m"
)
tokenizer
.
add_tokens
([
start_token
,
end_token
])
@
pytest
.
mark
.
parametrize
(
"streaming, param_dict"
,
TEST_CASES
)
def
test_reasoning
(
streaming
:
bool
,
param_dict
:
dict
,
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/opt-125m"
)
tokenizer
.
add_tokens
([
start_token
,
end_token
])
output
=
tokenizer
.
tokenize
(
param_dict
[
"output"
])
# decode everything to tokens
output_tokens
:
List
[
str
]
=
[
...
...
vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
View file @
fc6485d2
...
...
@@ -67,6 +67,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
]):
return
None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if
self
.
think_start_token_id
in
previous_token_ids
:
if
self
.
think_end_token_id
in
delta_token_ids
:
# <think> in previous, </think> in delta,
...
...
@@ -85,7 +87,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
elif
self
.
think_start_token_id
in
delta_token_ids
:
logger
.
info
(
delta_text
)
if
self
.
think_end_token_id
in
delta_token_ids
:
# <think> in delta, </think> in delta, extract reasoning content
start_index
=
delta_text
.
find
(
self
.
think_start_token
)
...
...
@@ -101,35 +102,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
else
:
# No <think> in previous or delta, reasoning content continues.
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if
self
.
think_end_token_id
in
delta_token_ids
:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index
=
delta_text
.
find
(
self
.
think_end_token
)
reasoning_content
=
delta_text
[:
end_index
]
content
=
delta_text
[
end_index
+
len
(
self
.
think_end_token
):]
return
DeltaMessage
(
reasoning_content
=
reasoning_content
,
content
=
content
if
content
else
None
)
elif
self
.
think_end_token_id
in
previous_token_ids
:
# </think> in previous, thinking content ends
return
DeltaMessage
(
content
=
delta_text
)
else
:
# no </think> in previous or delta, reasoning content continues
return
DeltaMessage
(
reasoning_content
=
delta_text
)
def
extract_reasoning_content
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
Tuple
[
Optional
[
str
],
Optional
[
str
]]:
# Check if the model output contains the <think> tokens.
if
(
self
.
think_start_token
not
in
model_output
or
self
.
think_end_token
not
in
model_output
):
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if
self
.
think_end_token
not
in
model_output
:
return
None
,
model_output
else
:
# Add a start token if it's missing to keep compatibility.
if
self
.
think_start_token
not
in
model_output
:
model_output
=
f
"
{
self
.
think_start_token
}{
model_output
}
"
# Use a regex to find the reasoning content
reasoning_content
=
self
.
reasoning_regex
.
findall
(
model_output
)[
0
]
# Remove the reasoning content from the model output
# Although deepseek's <think> token is always at the
# beginning of the line, we cannot guarantee that the
# other models will follow this convention.
# Therefore, we need to add :start_index.
start_index
=
model_output
.
find
(
self
.
think_start_token
)
if
start_index
!=
-
1
:
end_index
=
start_index
+
len
(
end_index
=
len
(
f
"
{
self
.
think_start_token
}{
reasoning_content
}{
self
.
think_end_token
}
"
)
model_output
=
model_output
[:
start_index
]
+
\
model_output
[
end_index
:]
final_output
=
model_output
[
end_index
:]
if
len
(
mode
l_output
)
==
0
:
if
len
(
fina
l_output
)
==
0
:
return
reasoning_content
,
None
return
reasoning_content
,
mode
l_output
return
reasoning_content
,
fina
l_output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment