Unverified Commit e42b49bd authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files
parent 4a718e77
...@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs ...@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0 pyzmq >= 25.0.0
msgspec msgspec
gguf >= 0.17.0 gguf >= 0.17.0
mistral_common[image] >= 1.9.1 mistral_common[image] >= 1.10.0
opencv-python-headless >= 4.13.0 # required for video IO opencv-python-headless >= 4.13.0 # required for video IO
pyyaml pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
......
...@@ -95,7 +95,7 @@ transformers==4.57.5 ...@@ -95,7 +95,7 @@ transformers==4.57.5
# Pin HF Hub version # Pin HF Hub version
huggingface-hub==0.36.2 huggingface-hub==0.36.2
# Pin Mistral Common # Pin Mistral Common
mistral-common[image,audio]==1.9.1 mistral-common[image,audio]==1.10.0
# Required for Prithvi tests # Required for Prithvi tests
terratorch==1.2.2 terratorch==1.2.2
# Required for Prithvi tests # Required for Prithvi tests
......
...@@ -482,7 +482,7 @@ mbstrdecoder==1.1.3 ...@@ -482,7 +482,7 @@ mbstrdecoder==1.1.3
# typepy # typepy
mdurl==0.1.2 mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
mistral-common==1.9.1 mistral-common==1.10.0
# via -r requirements/test.in # via -r requirements/test.in
more-itertools==10.5.0 more-itertools==10.5.0
# via lm-eval # via lm-eval
......
...@@ -7,6 +7,9 @@ from typing import TYPE_CHECKING, Any, cast, overload ...@@ -7,6 +7,9 @@ from typing import TYPE_CHECKING, Any, cast, overload
from mistral_common.protocol.instruct.request import ( from mistral_common.protocol.instruct.request import (
ChatCompletionRequest as MistralChatCompletionRequest, ChatCompletionRequest as MistralChatCompletionRequest,
) )
from mistral_common.protocol.instruct.request import (
ReasoningEffort,
)
from mistral_common.protocol.instruct.tool_calls import Function, Tool from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.protocol.instruct.validator import ValidationMode from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.base import ( from mistral_common.tokens.tokenizers.base import (
...@@ -192,6 +195,15 @@ def validate_request_params(request: "ChatCompletionRequest"): ...@@ -192,6 +195,15 @@ def validate_request_params(request: "ChatCompletionRequest"):
if request.chat_template is not None or request.chat_template_kwargs is not None: if request.chat_template is not None or request.chat_template_kwargs is not None:
raise ValueError("chat_template is not supported for Mistral tokenizers.") raise ValueError("chat_template is not supported for Mistral tokenizers.")
if request.reasoning_effort and request.reasoning_effort not in list(
ReasoningEffort
):
raise ValueError(
f"reasoning_effort={request.reasoning_effort} is not supported by "
"Mistral models. Supported values are: "
f"{[e.value for e in ReasoningEffort]}."
)
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
assert isinstance(tokenizer, Tekkenizer), type(tokenizer) assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
...@@ -419,6 +431,12 @@ class MistralTokenizer(TokenizerLike): ...@@ -419,6 +431,12 @@ class MistralTokenizer(TokenizerLike):
truncation = kwargs.get("truncation", False) truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length") max_length = kwargs.get("max_length")
version_kwargs = {}
# NOTE: This is for backward compatibility.
# Transformers should be passed arguments it knows.
if self.version >= 15:
version_kwargs["reasoning_effort"] = kwargs.get("reasoning_effort")
messages, tools = _prepare_apply_chat_template_tools_and_messages( messages, tools = _prepare_apply_chat_template_tools_and_messages(
messages, tools, continue_final_message, add_generation_prompt messages, tools, continue_final_message, add_generation_prompt
) )
...@@ -433,6 +451,7 @@ class MistralTokenizer(TokenizerLike): ...@@ -433,6 +451,7 @@ class MistralTokenizer(TokenizerLike):
max_length=max_length, max_length=max_length,
return_tensors=None, return_tensors=None,
return_dict=False, return_dict=False,
**version_kwargs,
) )
def decode( def decode(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment