Unverified Commit b8d9e4a3 authored by Tao Hui's avatar Tao Hui Committed by GitHub
Browse files

[Model] Add optional parameter to reasoning parser constructor (#25554)


Signed-off-by: default avatartaohui <taohui3@gmail.com>
Signed-off-by: default avatarTao Hui <taohui3@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 13cc7f53
......@@ -34,7 +34,7 @@ class ReasoningParser:
It is used to extract reasoning content from the model output.
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
self.model_tokenizer = tokenizer
@cached_property
......
......@@ -35,8 +35,8 @@ class BaseThinkingReasoningParser(ReasoningParser):
"""The token that ends reasoning content."""
raise NotImplementedError
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
if not self.model_tokenizer:
raise ValueError(
......
......@@ -26,8 +26,8 @@ class Glm4MoeModelReasoningParser(ReasoningParser):
from the model's output.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
......
......@@ -24,8 +24,8 @@ class GptOssReasoningParser(ReasoningParser):
is only used for detecting the end of the reasoning content.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.reasoning_end_token_ids = self.model_tokenizer.encode(
"<|start|>assistant<|channel|>final<|message|>")
......
......@@ -24,8 +24,8 @@ class GraniteReasoningParser(ReasoningParser):
and "Here is my response:" to separate its thinking / response outputs.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
# NOTE: There have been some observed occurrences of quantized
# instances of the current models using "Here's" instead of "Here is",
......
......@@ -40,8 +40,8 @@ class HunyuanA13BReasoningParser(ReasoningParser):
response ends: "\n</answer>": [524, 9399, 29]
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.think_start_expr = r"<think>\n"
self.think_end_expr = r"\n</think>\n"
......
......@@ -21,12 +21,12 @@ class MistralReasoningParser(DeepSeekR1ReasoningParser):
text. This parser extracts the reasoning content from the model output.
"""
def __init__(self, tokenizer: MistralTokenizer):
def __init__(self, tokenizer: MistralTokenizer, *args, **kwargs):
if not isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"The tokenizer must be an instance of MistralTokenizer.")
ReasoningParser.__init__(self, tokenizer)
ReasoningParser.__init__(self, tokenizer, *args, **kwargs)
if not self.model_tokenizer:
raise ValueError(
......
......@@ -24,8 +24,8 @@ class Step3ReasoningParser(ReasoningParser):
text. This parser extracts all content before </think> as reasoning content.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
self.think_end_token = "</think>"
self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment