Commit 2c35b6cd authored by zhuwenwen's avatar zhuwenwen
Browse files

Fix reasoning_content for chat_template include <think> tag as input

parent 21833462
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Optional, Union
from transformers import PreTrainedTokenizerBase
......@@ -44,6 +45,19 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
# TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.think_end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
previous_text: str,
......@@ -67,6 +81,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if self.think_start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
......@@ -85,7 +101,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids:
logger.info(delta_text)
if self.think_end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token)
......@@ -101,35 +116,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
else:
# No <think> in previous or delta, reasoning content continues.
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
# no </think> in previous or delta, reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> Tuple[Optional[str], Optional[str]]:
) -> tuple[Optional[str], Optional[str]]:
# Check if the model output contains the <think> tokens.
if (self.think_start_token not in model_output
or self.think_end_token not in model_output):
return None, model_output
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token not in model_output:
return model_output, None
else:
# Add a start token if it's missing to keep compatibility.
if self.think_start_token not in model_output:
model_output = f"{self.think_start_token}{model_output}"
# Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0]
# Remove the reasoning content from the model output
# Although deepseek's <think> token is always at the
# beginning of the line, we cannot guarantee that the
# other models will follow this convention.
# Therefore, we need to add :start_index.
start_index = model_output.find(self.think_start_token)
if start_index != -1:
end_index = start_index + len(
end_index = len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
)
model_output = model_output[:start_index] + \
model_output[end_index:]
final_output = model_output[end_index:]
if len(model_output) == 0:
if len(final_output) == 0:
return reasoning_content, None
return reasoning_content, model_output
return reasoning_content, final_output
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment