Unverified Commit 5db91f0a authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files

Fix some Mistral parser issues (#37209)


Signed-off-by: default avatarjuliendenize <julien.denize@mistral.ai>
parent 061980c3
...@@ -310,11 +310,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -310,11 +310,14 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers, trace_headers=trace_headers,
) )
else: else:
reasoning_ended = ( if not request.include_reasoning:
reasoning_parser.is_reasoning_end(prompt_token_ids or []) reasoning_ended = True
if reasoning_parser elif reasoning_parser:
else None reasoning_ended = reasoning_parser.is_reasoning_end(
) prompt_token_ids or []
)
else:
reasoning_ended = None
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_prompt, engine_prompt,
......
...@@ -15,8 +15,15 @@ from mistral_common.protocol.instruct.validator import ValidationMode ...@@ -15,8 +15,15 @@ from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import ( from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy, SpecialTokenPolicy,
SpecialTokens, SpecialTokens,
Tokenizer,
)
from mistral_common.tokens.tokenizers.instruct import (
InstructTokenizerBase,
InstructTokenizerV13,
)
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as MistralCommonTokenizer,
) )
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
from mistral_common.tokens.tokenizers.sentencepiece import ( from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer, SentencePieceTokenizer,
) )
...@@ -26,21 +33,20 @@ from pydantic import ValidationError ...@@ -26,21 +33,20 @@ from pydantic import ValidationError
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers.protocol import TokenizerLike
from .protocol import TokenizerLike try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import BatchEncoding from transformers import BatchEncoding
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -235,15 +241,6 @@ class MistralTokenizer(TokenizerLike): ...@@ -235,15 +241,6 @@ class MistralTokenizer(TokenizerLike):
download_dir: str | None = None, download_dir: str | None = None,
**kwargs, **kwargs,
) -> "MistralTokenizer": ) -> "MistralTokenizer":
try:
# Transformers v5
from transformers.tokenization_mistral_common import MistralCommonBackend
except ImportError:
# Transformers v4
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as MistralCommonBackend,
)
tokenizer = MistralCommonBackend.from_pretrained( tokenizer = MistralCommonBackend.from_pretrained(
path_or_repo_id, path_or_repo_id,
*args, *args,
...@@ -255,13 +252,13 @@ class MistralTokenizer(TokenizerLike): ...@@ -255,13 +252,13 @@ class MistralTokenizer(TokenizerLike):
return cls(tokenizer) return cls(tokenizer)
def __init__(self, tokenizer: "MistralCommonBackend") -> None: def __init__(self, tokenizer: MistralCommonBackend) -> None:
super().__init__() super().__init__()
self.transformers_tokenizer = tokenizer self.transformers_tokenizer: MistralCommonBackend = tokenizer
self.mistral = tokenizer.tokenizer self.mistral: MistralCommonTokenizer = tokenizer.tokenizer
self.instruct = self.mistral.instruct_tokenizer self.instruct: InstructTokenizerBase = self.mistral.instruct_tokenizer
self.tokenizer = self.instruct.tokenizer self.tokenizer: Tokenizer = self.instruct.tokenizer
mode = self.mistral._chat_completion_request_validator._mode mode = self.mistral._chat_completion_request_validator._mode
if mode != ValidationMode.test: if mode != ValidationMode.test:
...@@ -483,7 +480,11 @@ class MistralTokenizer(TokenizerLike): ...@@ -483,7 +480,11 @@ class MistralTokenizer(TokenizerLike):
return self.transformers_tokenizer.convert_tokens_to_ids(tokens) return self.transformers_tokenizer.convert_tokens_to_ids(tokens)
def convert_tokens_to_string(self, tokens: list[str]) -> str: def convert_tokens_to_string(self, tokens: list[str]) -> str:
to_decode_special_tokens = {SpecialTokens.tool_calls} to_decode_special_tokens = {
SpecialTokens.tool_calls,
SpecialTokens.begin_think,
SpecialTokens.end_think,
}
if self.is_tekken: if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
tokens = [ tokens = [
......
...@@ -241,7 +241,10 @@ class MistralToolParser(ToolParser): ...@@ -241,7 +241,10 @@ class MistralToolParser(ToolParser):
delta_token_ids: Sequence[int], delta_token_ids: Sequence[int],
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> DeltaMessage | None: ) -> DeltaMessage | None:
if self.bot_token_id not in current_token_ids: has_bot_token = (
self.bot_token_id in current_token_ids or self.bot_token in current_text
)
if not has_bot_token:
# if the tool call token is not in the tokens generated so far, # if the tool call token is not in the tokens generated so far,
# append output to contents since it's not a tool # append output to contents since it's not a tool
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
...@@ -275,7 +278,8 @@ class MistralToolParser(ToolParser): ...@@ -275,7 +278,8 @@ class MistralToolParser(ToolParser):
additional_content: str = "" additional_content: str = ""
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START: if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
# this is the first tool call # this is the first tool call
assert self.bot_token_id in delta_token_ids if self.bot_token not in delta_text:
return DeltaMessage(content=delta_text)
if not delta_text.startswith(self.bot_token): if not delta_text.startswith(self.bot_token):
additional_content += delta_text.split(self.bot_token)[0] additional_content += delta_text.split(self.bot_token)[0]
delta_text = self.bot_token + "".join( delta_text = self.bot_token + "".join(
...@@ -411,7 +415,7 @@ class MistralToolParser(ToolParser): ...@@ -411,7 +415,7 @@ class MistralToolParser(ToolParser):
index=self.current_tool_id, type="function" index=self.current_tool_id, type="function"
) )
current_tool_call_modified = False current_tool_call_modified = False
if self.bot_token_id in delta_token_ids: if self.bot_token_id in delta_token_ids or self.bot_token in delta_text:
# this is the first tool call # this is the first tool call
if not delta_text.startswith(self.bot_token): if not delta_text.startswith(self.bot_token):
content = delta_text.split(self.bot_token)[0] content = delta_text.split(self.bot_token)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment