Unverified Commit 57430fc9 authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files

Default model load/config/tokenizer to `mistral` format if relevant files exist (#28659)


Signed-off-by: default avatarJulien Denize <julien.denize@mistral.ai>
Signed-off-by: default avatarJulien Denize <40604584+juliendenize@users.noreply.github.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
parent c68c7b40
...@@ -142,7 +142,7 @@ Flags: `--tool-call-parser hermes` ...@@ -142,7 +142,7 @@ Flags: `--tool-call-parser hermes`
Supported models: Supported models:
* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed) * `mistralai/Mistral-7B-Instruct-v0.3` (confirmed)
* Additional mistral function-calling models are compatible as well. * Additional Mistral function-calling models are compatible as well.
Known issues: Known issues:
...@@ -158,12 +158,25 @@ Known issues: ...@@ -158,12 +158,25 @@ Known issues:
Recommended flags: Recommended flags:
1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend: 1. To use the official Mistral AI's format:
`--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral` `--tool-call-parser mistral`
2. To use the default Transformers tokenization backend: 2. To use the Transformers format when available:
`--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
`--tokenizer_mode hf --config_format hf --load_format hf --tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
!!! note
Models officially released by Mistral AI have two possible formats:
1. The official format that is used by default with `auto` or `mistral` arguments:
`--tokenizer_mode mistral --config_format mistral --load_format mistral`
This format uses [mistral-common](https://github.com/mistralai/mistral-common), Mistral AI's tokenizer backend.
2. The Transformers format, when available, that is used with `hf` arguments:
`--tokenizer_mode hf --config_format hf --load_format hf --chat-template examples/tool_chat_template_mistral_parallel.jinja`
### Llama Models (`llama3_json`) ### Llama Models (`llama3_json`)
......
...@@ -208,7 +208,7 @@ def test_mistral_format( ...@@ -208,7 +208,7 @@ def test_mistral_format(
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
tokenizer_mode="auto", tokenizer_mode="hf",
load_format="safetensors", load_format="safetensors",
config_format="hf", config_format="hf",
) as hf_format_model: ) as hf_format_model:
......
...@@ -50,12 +50,24 @@ def test_hf_model_weights_mapper(model_arch: str): ...@@ -50,12 +50,24 @@ def test_hf_model_weights_mapper(model_arch: str):
model_info.check_available_online(on_fail="skip") model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
is_mistral_model = model_arch in [
"Mistral3ForConditionalGeneration",
"PixtralForConditionalGeneration",
"VoxtralForConditionalGeneration",
]
if not is_mistral_model or model_info.tokenizer_mode == "mistral":
tokenizer_mode = model_info.tokenizer_mode
else:
tokenizer_mode = "hf"
model_id = model_info.default model_id = model_info.default
model_config = ModelConfig( model_config = ModelConfig(
model_id, model_id,
tokenizer=model_info.tokenizer or model_id, tokenizer=model_info.tokenizer or model_id,
tokenizer_mode=model_info.tokenizer_mode, tokenizer_mode=tokenizer_mode,
config_format="hf",
revision=model_info.revision, revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
......
...@@ -259,6 +259,9 @@ def validate_generated_texts( ...@@ -259,6 +259,9 @@ def validate_generated_texts(
tensor_parallel_size=vllm_tp_size, tensor_parallel_size=vllm_tp_size,
enforce_eager=False, enforce_eager=False,
default_torch_num_threads=1, default_torch_num_threads=1,
tokenizer_mode="hf",
load_format="hf",
config_format="hf",
) as llm: ) as llm:
vllm_outputs = llm.generate_greedy(prompts, max_tokens) vllm_outputs = llm.generate_greedy(prompts, max_tokens)
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
......
...@@ -128,6 +128,12 @@ CONFIGS: dict[str, ServerConfig] = { ...@@ -128,6 +128,12 @@ CONFIGS: dict[str, ServerConfig] = {
"arguments": [ "arguments": [
"--enforce-eager", "--enforce-eager",
"--no-enable-prefix-caching", "--no-enable-prefix-caching",
"--tokenizer_mode",
"hf",
"--load_format",
"hf",
"--config_format",
"hf",
"--tool-call-parser", "--tool-call-parser",
"mistral", "mistral",
"--chat-template", "--chat-template",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, call, patch
import pytest
from vllm.transformers_utils.config import list_filtered_repo_files
@pytest.mark.parametrize(
    "allow_patterns,expected_relative_files",
    [
        (
            ["*.json", "correct*.txt"],
            ["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
        ),
    ],
)
def test_list_filtered_repo_files(
    allow_patterns: list[str], expected_relative_files: list[str]
):
    """Only files whose basename matches an allow-pattern are returned.

    A throwaway directory tree stands in for the repository; its file listing
    is served through a patched `list_repo_files`, so nothing is fetched.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Build the fake repository: matching and non-matching files, both at
        # the top level and inside a subfolder.
        root = Path(tmp_dir)
        (root / "subfolder").mkdir()
        for relative_name in (
            "json_file.json",
            "correct_2.txt",
            "uncorrect.txt",
            "uncorrect.jpeg",
            "subfolder/correct.txt",
            "subfolder/uncorrect_sub.txt",
        ):
            (root / relative_name).touch()

        # Repo-relative listing of every regular file in the tree.
        repo_listing = [
            str(entry.relative_to(root))
            for entry in root.glob("**/*")
            if entry.is_file()
        ]

        # Replace the listing call made by the function under test.
        with patch(
            "vllm.transformers_utils.config.list_repo_files",
            MagicMock(return_value=repo_listing),
        ) as mock_list_repo_files:
            filtered = list_filtered_repo_files(
                tmp_dir, allow_patterns, "revision", "model", "token"
            )
            out_files = sorted(filtered)

        assert out_files == sorted(expected_relative_files)
        # Exactly one lookup, with the positional arguments forwarded as the
        # matching keyword arguments.
        mock_list_repo_files.assert_called_once_with(
            repo_id=tmp_dir,
            revision="revision",
            repo_type="model",
            token="token",
        )
...@@ -2,7 +2,11 @@ ...@@ -2,7 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.transformers_utils.utils import is_cloud_storage, is_gcs, is_s3 from vllm.transformers_utils.utils import (
is_cloud_storage,
is_gcs,
is_s3,
)
def test_is_gcs(): def test_is_gcs():
......
...@@ -46,11 +46,15 @@ EAGLE_SPEC_CONFIG = { ...@@ -46,11 +46,15 @@ EAGLE_SPEC_CONFIG = {
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), # FIXME: Since "auto" will use Mistral tokenizer and these backends do not support
# it, we skip these tests for now.
# ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
# ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "hf", None),
pytest.param( pytest.param(
"mistralai/Ministral-8B-Instruct-2410", "mistralai/Ministral-8B-Instruct-2410",
"lm-format-enforcer", "lm-format-enforcer",
"auto", "hf",
None, None,
marks=pytest.mark.skip( marks=pytest.mark.skip(
reason=( reason=(
...@@ -80,7 +84,7 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ...@@ -80,7 +84,7 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
# ("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), # ("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "hf", NGRAM_SPEC_CONFIG),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", EAGLE_SPEC_CONFIG), ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", EAGLE_SPEC_CONFIG),
] ]
...@@ -151,6 +155,8 @@ def test_structured_output( ...@@ -151,6 +155,8 @@ def test_structured_output(
), ),
seed=120, seed=120,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
load_format="auto" if not model_name.startswith("mistralai/") else "hf",
config_format="auto" if not model_name.startswith("mistralai/") else "hf",
speculative_config=speculative_config, speculative_config=speculative_config,
) )
...@@ -720,6 +726,8 @@ def test_structured_output_auto_mode( ...@@ -720,6 +726,8 @@ def test_structured_output_auto_mode(
max_model_len=1024, max_model_len=1024,
structured_outputs_config=dict(backend="auto"), structured_outputs_config=dict(backend="auto"),
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
load_format="auto",
config_format="auto",
) )
sampling_params = SamplingParams( sampling_params = SamplingParams(
......
...@@ -81,7 +81,7 @@ TaskOption = Literal[ ...@@ -81,7 +81,7 @@ TaskOption = Literal[
"transcription", "transcription",
"draft", "draft",
] ]
TokenizerMode = Literal["auto", "slow", "mistral", "custom"] TokenizerMode = Literal["auto", "hf", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[ LogprobsMode = Literal[
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
...@@ -130,7 +130,8 @@ class ModelConfig: ...@@ -130,7 +130,8 @@ class ModelConfig:
name or path will be used.""" name or path will be used."""
tokenizer_mode: TokenizerMode = "auto" tokenizer_mode: TokenizerMode = "auto"
"""Tokenizer mode:\n """Tokenizer mode:\n
- "auto" will use the fast tokenizer if available.\n - "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n
- "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n - "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n - "mistral" will always use the tokenizer from `mistral_common`.\n
- "custom" will use --tokenizer to select the preregistered tokenizer.""" - "custom" will use --tokenizer to select the preregistered tokenizer."""
...@@ -241,8 +242,8 @@ class ModelConfig: ...@@ -241,8 +242,8 @@ class ModelConfig:
first one.""" first one."""
config_format: str | ConfigFormat = "auto" config_format: str | ConfigFormat = "auto"
"""The format of the model config to load:\n """The format of the model config to load:\n
- "auto" will try to load the config in hf format if available else it - "auto" will try to load the config in hf format if available after trying
will try to load in mistral format.\n to load in mistral format.\n
- "hf" will load the config in hf format.\n - "hf" will load the config in hf format.\n
- "mistral" will load the config in mistral format.""" - "mistral" will load the config in mistral format."""
hf_token: bool | str | None = None hf_token: bool | str | None = None
......
...@@ -30,6 +30,7 @@ logger = init_logger(__name__) ...@@ -30,6 +30,7 @@ logger = init_logger(__name__)
# if a new load format is added here # if a new load format is added here
LoadFormats = Literal[ LoadFormats = Literal[
"auto", "auto",
"hf",
"bitsandbytes", "bitsandbytes",
"dummy", "dummy",
"fastsafetensors", "fastsafetensors",
...@@ -45,6 +46,7 @@ LoadFormats = Literal[ ...@@ -45,6 +46,7 @@ LoadFormats = Literal[
] ]
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = { _LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
"auto": DefaultModelLoader, "auto": DefaultModelLoader,
"hf": DefaultModelLoader,
"bitsandbytes": BitsAndBytesModelLoader, "bitsandbytes": BitsAndBytesModelLoader,
"dummy": DummyModelLoader, "dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader, "fastsafetensors": DefaultModelLoader,
......
...@@ -31,6 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -31,6 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator, safetensors_weights_iterator,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import list_filtered_repo_files
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -96,8 +97,25 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -96,8 +97,25 @@ class DefaultModelLoader(BaseModelLoader):
load_format = self.load_config.load_format load_format = self.load_config.load_format
use_safetensors = False use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
# For the 'auto' format, first check whether Mistral-format weight files are present.
# This loads Mistral models in their official format by default.
if load_format == "auto": if load_format == "auto":
load_format = (
"mistral"
if len(
list_filtered_repo_files(
model_name_or_path=model_name_or_path,
allow_patterns=["consolidated*.safetensors"],
revision=revision,
)
)
> 0
else "hf"
)
# Some quantized models use .pt files for storing the weights.
if load_format == "hf":
allow_patterns = ["*.safetensors", "*.bin"] allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors" or load_format == "fastsafetensors": elif load_format == "safetensors" or load_format == "fastsafetensors":
use_safetensors = True use_safetensors = True
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch
import json import json
import os import os
import time import time
...@@ -355,6 +356,41 @@ def list_repo_files( ...@@ -355,6 +356,41 @@ def list_repo_files(
return with_retry(lookup_files, "Error retrieving file list") return with_retry(lookup_files, "Error retrieving file list")
def list_filtered_repo_files(
    model_name_or_path: str,
    allow_patterns: list[str],
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    """List the repository files whose basename matches any allow-pattern.

    Args:
        model_name_or_path: Repository identifier (or path) forwarded to
            `list_repo_files` as `repo_id`.
        allow_patterns: `fnmatch`-style patterns. Each one is matched against
            the file's basename only, so a pattern containing a directory
            component never matches.
        revision: Forwarded unchanged to `list_repo_files`.
        repo_type: Forwarded unchanged to `list_repo_files`.
        token: Forwarded unchanged to `list_repo_files`.

    Returns:
        The matching file paths exactly as reported by `list_repo_files`,
        ordered by pattern and then by listing order, with each path appearing
        at most once. An empty list is returned when the listing cannot be
        retrieved.
    """
    try:
        all_files = list_repo_files(
            repo_id=model_name_or_path,
            revision=revision,
            token=token,
            repo_type=repo_type,
        )
    except Exception:
        # Best-effort lookup: a failed listing is reported and treated as
        # "no matching files" instead of being propagated to the caller.
        logger.error(
            "Error retrieving file list. Please ensure your `model_name_or_path`, "
            "`repo_type`, `token` and `revision` arguments are correctly set. "
            "Returning an empty list."
        )
        return []

    # Filter patterns on filenames (basename only, not the full relative path).
    matches = (
        file
        for pattern in allow_patterns
        for file in all_files
        if fnmatch.fnmatch(os.path.basename(file), pattern)
    )
    # dict.fromkeys keeps first-seen order while dropping the repeats produced
    # when one file satisfies several patterns.
    return list(dict.fromkeys(matches))
def file_exists( def file_exists(
repo_id: str, repo_id: str,
file_name: str, file_name: str,
...@@ -619,10 +655,14 @@ def get_config( ...@@ -619,10 +655,14 @@ def get_config(
if config_format == "auto": if config_format == "auto":
try: try:
if is_gguf or file_or_path_exists(model, HF_CONFIG_NAME, revision=revision): # First check for Mistral to avoid defaulting to
config_format = "hf" # Transformers implementation.
elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
config_format = "mistral" config_format = "mistral"
elif is_gguf or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision
):
config_format = "hf"
else: else:
raise ValueError( raise ValueError(
"Could not detect config format for no config file found. " "Could not detect config format for no config file found. "
......
...@@ -118,7 +118,7 @@ def _remap_general_mistral_args(config: dict) -> dict: ...@@ -118,7 +118,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
"model_type": ("model_type", "transformer"), "model_type": ("model_type", "transformer"),
"hidden_act": ("activation", "silu"), "hidden_act": ("activation", "silu"),
"tie_word_embeddings": ("tied_embeddings", False), "tie_word_embeddings": ("tied_embeddings", False),
"max_seq_len": ("max_seq_len", 128_000), "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
"max_position_embeddings": ("max_position_embeddings", 128_000), "max_position_embeddings": ("max_position_embeddings", 128_000),
} }
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import contextlib import contextlib
import copy import copy
import importlib.util
import os import os
import warnings
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias from typing import TYPE_CHECKING, Any, TypeAlias
...@@ -15,7 +15,10 @@ from typing_extensions import assert_never ...@@ -15,7 +15,10 @@ from typing_extensions import assert_never
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config from vllm.transformers_utils.config import (
get_sentence_transformer_tokenizer_config,
list_filtered_repo_files,
)
from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
...@@ -182,25 +185,29 @@ def get_tokenizer( ...@@ -182,25 +185,29 @@ def get_tokenizer(
kwargs["gguf_file"] = Path(tokenizer_name).name kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent tokenizer_name = Path(tokenizer_name).parent
# if tokenizer is from official mistral org # if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai" # first to use official Mistral tokenizer if possible.
if is_from_mistral_org and tokenizer_mode != "mistral": mistral_common_installed = importlib.util.find_spec("mistral_common") is not None
warnings.warn( if tokenizer_mode == "auto" and mistral_common_installed:
"It is strongly recommended to run mistral models with " allow_patterns = ["tekken.json", "tokenizer.model.v*"]
'`--tokenizer-mode "mistral"` to ensure correct ' files_list = list_filtered_repo_files(
"encoding and decoding.", model_name_or_path=str(tokenizer_name),
FutureWarning, allow_patterns=allow_patterns,
stacklevel=2, revision=revision,
) )
if len(files_list) > 0:
tokenizer_mode = "mistral"
tokenizer: AnyTokenizer tokenizer: AnyTokenizer
if tokenizer_mode == "mistral": if tokenizer_mode == "mistral":
logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}")
tokenizer = MistralTokenizer.from_pretrained( tokenizer = MistralTokenizer.from_pretrained(
str(tokenizer_name), revision=revision str(tokenizer_name), revision=revision
) )
elif tokenizer_mode == "custom": elif tokenizer_mode == "custom":
from vllm.transformers_utils.tokenizer_base import TokenizerRegistry from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}")
tokenizer = TokenizerRegistry.get_tokenizer( tokenizer = TokenizerRegistry.get_tokenizer(
str(tokenizer_name), str(tokenizer_name),
*args, *args,
...@@ -210,6 +217,7 @@ def get_tokenizer( ...@@ -210,6 +217,7 @@ def get_tokenizer(
) )
else: else:
try: try:
logger.debug_once(f"Loading AutoTokenizer from {tokenizer_name}")
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, tokenizer_name,
*args, *args,
......
...@@ -20,6 +20,7 @@ from vllm.multimodal.utils import argsort_mm_positions ...@@ -20,6 +20,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
...@@ -300,12 +301,24 @@ class Processor: ...@@ -300,12 +301,24 @@ class Processor:
# allows <|special_token|> and similar, see # allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars. # Without tokenizer these are disallowed in grammars.
if isinstance(self.tokenizer, MistralTokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_guidance_grammar(params, tokenizer=None) validate_guidance_grammar(params, tokenizer=None)
elif backend == "outlines": elif backend == "outlines":
# outlines backend # outlines backend
validate_structured_output_request_outlines(params) validate_structured_output_request_outlines(params)
elif backend == "lm-format-enforcer": elif backend == "lm-format-enforcer":
# lm format enforcer backend # lm format enforcer backend
if isinstance(self.tokenizer, MistralTokenizer):
raise ValueError(
"Mistral tokenizer is not supported for the 'lm-format-enforcer' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_structured_output_request_lm_format_enforcer(params) validate_structured_output_request_lm_format_enforcer(params)
else: else:
# NOTE: backend must be "auto" here, because we have # NOTE: backend must be "auto" here, because we have
...@@ -320,9 +333,15 @@ class Processor: ...@@ -320,9 +333,15 @@ class Processor:
except ValueError: except ValueError:
# The request either failed validation # The request either failed validation
# or includes some jsonschema feature(s) that # or includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance. # are not supported in xgrammar.
validate_guidance_grammar(params, tokenizer=None) if isinstance(self.tokenizer, MistralTokenizer):
params.structured_outputs._backend = "guidance" # Fall back to outlines if the tokenizer is Mistral
validate_structured_output_request_outlines(params)
params.structured_outputs._backend = "outlines"
else:
# Fall back to guidance by default.
validate_guidance_grammar(params, tokenizer=None)
params.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically # Remember that this backend was set automatically
params.structured_outputs._backend_was_auto = True params.structured_outputs._backend_was_auto = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment