Unverified Commit 54e2f83d authored by Neil Schemenauer's avatar Neil Schemenauer Committed by GitHub
Browse files

[Feature] Lazy import for the "mistral" tokenizer module. (#34651)


Signed-off-by: default avatarNeil Schemenauer <nas@arctrix.com>
parent e631f8e7
...@@ -23,7 +23,7 @@ from vllm.multimodal.cache import MultiModalProcessorOnlyCache ...@@ -23,7 +23,7 @@ from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.mistral import is_mistral_tokenizer
from ....multimodal.utils import random_audio, random_image, random_video from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import ( from ...registry import (
...@@ -183,7 +183,7 @@ def get_text_token_prompts( ...@@ -183,7 +183,7 @@ def get_text_token_prompts(
text_prompt: str | None text_prompt: str | None
token_prompt: list[int] token_prompt: list[int]
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
# ChatCompletionRequest only supports ImageChunk natively; # ChatCompletionRequest only supports ImageChunk natively;
# for other modalities (e.g. audio), fall back to the model's # for other modalities (e.g. audio), fall back to the model's
# own dummy inputs builder which knows the right placeholders. # own dummy inputs builder which knows the right placeholders.
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.mistral import is_mistral_tokenizer
class StreamingReasoningReconstructor: class StreamingReasoningReconstructor:
...@@ -59,7 +59,7 @@ def run_reasoning_extraction_mistral( ...@@ -59,7 +59,7 @@ def run_reasoning_extraction_mistral(
request: ChatCompletionRequest | None = None, request: ChatCompletionRequest | None = None,
streaming: bool = False, streaming: bool = False,
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type( assert is_mistral_tokenizer(reasoning_parser.model_tokenizer), type(
reasoning_parser.model_tokenizer reasoning_parser.model_tokenizer
) )
if streaming: if streaming:
...@@ -130,7 +130,7 @@ def run_reasoning_extraction_streaming_mistral( ...@@ -130,7 +130,7 @@ def run_reasoning_extraction_streaming_mistral(
model_deltas: list[int], model_deltas: list[int],
request: ChatCompletionRequest | None = None, request: ChatCompletionRequest | None = None,
) -> StreamingReasoningReconstructor: ) -> StreamingReasoningReconstructor:
assert isinstance(reasoning_parser.model_tokenizer, MistralTokenizer), type( assert is_mistral_tokenizer(reasoning_parser.model_tokenizer), type(
reasoning_parser.model_tokenizer reasoning_parser.model_tokenizer
) )
request = request or ChatCompletionRequest(messages=[], model="test-model") request = request or ChatCompletionRequest(messages=[], model="test-model")
......
...@@ -83,9 +83,9 @@ from vllm.renderers.inputs.preprocess import ( ...@@ -83,9 +83,9 @@ from vllm.renderers.inputs.preprocess import (
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter from vllm.utils.counter import Counter
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.tqdm_utils import maybe_tqdm from vllm.utils.tqdm_utils import maybe_tqdm
from vllm.v1.engine.llm_engine import LLMEngine from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor from vllm.v1.sample.logits_processor import LogitsProcessor
...@@ -891,7 +891,7 @@ class LLM: ...@@ -891,7 +891,7 @@ class LLM:
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message, continue_final_message=continue_final_message,
tools=tools, tools=tools,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer), tokenize=is_mistral_tokenizer(renderer.tokenizer),
), ),
), ),
) )
...@@ -1458,7 +1458,7 @@ class LLM: ...@@ -1458,7 +1458,7 @@ class LLM:
model_config = self.model_config model_config = self.model_config
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
raise ValueError("Score API is not supported for Mistral tokenizer") raise ValueError("Score API is not supported for Mistral tokenizer")
if len(data_1) == 1: if len(data_1) == 1:
......
...@@ -75,16 +75,12 @@ from vllm.parser import ParserManager ...@@ -75,16 +75,12 @@ from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
MistralTokenizer,
maybe_serialize_tool_calls,
truncate_tool_call_ids,
validate_request_params,
)
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.tool_parsers.utils import partial_json_loads from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -244,18 +240,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -244,18 +240,18 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = self.tool_parser tool_parser = self.tool_parser
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type] _mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
truncate_tool_call_ids(request) # type: ignore[arg-type] _mt.truncate_tool_call_ids(request) # type: ignore[arg-type]
validate_request_params(request) _mt.validate_request_params(request)
# Check if tool parsing is unavailable (common condition) # Check if tool parsing is unavailable (common condition)
tool_parsing_unavailable = ( tool_parsing_unavailable = (
tool_parser is None tool_parser is None
and not isinstance(tokenizer, MistralTokenizer) and not is_mistral_tokenizer(tokenizer)
and not self.use_harmony and not self.use_harmony
) )
...@@ -639,8 +635,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -639,8 +635,6 @@ class OpenAIServingChat(OpenAIServing):
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None, reasoning_parser: ReasoningParser | None = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
...@@ -955,7 +949,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -955,7 +949,7 @@ class OpenAIServingChat(OpenAIServing):
) )
else: else:
# Generate ID based on tokenizer type # Generate ID based on tokenizer type
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
tool_call_id = MistralToolCall.generate_random_id() tool_call_id = MistralToolCall.generate_random_id()
else: else:
tool_call_id = make_tool_call_id( tool_call_id = make_tool_call_id(
...@@ -1516,7 +1510,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1516,7 +1510,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
) )
tool_call_class = ( tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall
) )
if self.use_harmony: if self.use_harmony:
# Harmony models already have parsed content and tool_calls # Harmony models already have parsed content and tool_calls
...@@ -1951,7 +1945,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1951,7 +1945,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) # type: ignore[arg-type] _mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
# Add system message. # Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default # NOTE: In Chat Completion API, browsing is enabled by default
......
...@@ -128,6 +128,7 @@ from vllm.utils.async_utils import ( ...@@ -128,6 +128,7 @@ from vllm.utils.async_utils import (
collect_from_async_generator, collect_from_async_generator,
merge_async_iterators, merge_async_iterators,
) )
from vllm.utils.mistral import is_mistral_tokenizer
class GenerationError(Exception): class GenerationError(Exception):
...@@ -976,15 +977,13 @@ class OpenAIServing: ...@@ -976,15 +977,13 @@ class OpenAIServing:
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer renderer = self.renderer
default_template_kwargs = merge_kwargs( default_template_kwargs = merge_kwargs(
default_template_kwargs, default_template_kwargs,
dict( dict(
tools=tool_dicts, tools=tool_dicts,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer), tokenize=is_mistral_tokenizer(renderer.tokenizer),
), ),
) )
......
...@@ -41,8 +41,8 @@ from vllm.logger import init_logger ...@@ -41,8 +41,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async, merge_async_iterators from vllm.utils.async_utils import make_async, merge_async_iterators
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -348,7 +348,7 @@ class ServingScores(OpenAIServing): ...@@ -348,7 +348,7 @@ class ServingScores(OpenAIServing):
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse: ) -> list[PoolingRequestOutput] | ErrorResponse:
tokenizer = self.renderer.get_tokenizer() tokenizer = self.renderer.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
raise ValueError("MistralTokenizer not supported for cross-encoding") raise ValueError("MistralTokenizer not supported for cross-encoding")
model_config = self.model_config model_config = self.model_config
......
...@@ -26,6 +26,7 @@ from vllm.tokenizers import TokenizerLike ...@@ -26,6 +26,7 @@ from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves from vllm.utils.jsontree import JSONTree, json_map_leaves
from vllm.utils.mistral import is_mistral_tokenizer
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
...@@ -260,10 +261,8 @@ class InputProcessingContext: ...@@ -260,10 +261,8 @@ class InputProcessingContext:
typ = ProcessorMixin typ = ProcessorMixin
from vllm.tokenizers.mistral import MistralTokenizer
tokenizer = self.tokenizer tokenizer = self.tokenizer
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
tokenizer = tokenizer.transformers_tokenizer tokenizer = tokenizer.transformers_tokenizer
merged_kwargs = self.get_merged_mm_kwargs(kwargs) merged_kwargs = self.get_merged_mm_kwargs(kwargs)
......
...@@ -16,6 +16,7 @@ from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig ...@@ -16,6 +16,7 @@ from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.v1.serial_utils import PydanticMsgspecMixin from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -731,7 +732,6 @@ class SamplingParams( ...@@ -731,7 +732,6 @@ class SamplingParams(
): ):
raise ValueError("structured_outputs.grammar cannot be an empty string") raise ValueError("structured_outputs.grammar cannot be an empty string")
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.v1.structured_output.backend_guidance import ( from vllm.v1.structured_output.backend_guidance import (
has_guidance_unsupported_json_features, has_guidance_unsupported_json_features,
validate_guidance_grammar, validate_guidance_grammar,
...@@ -752,7 +752,7 @@ class SamplingParams( ...@@ -752,7 +752,7 @@ class SamplingParams(
# allows <|special_token|> and similar, see # allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars. # Without tokenizer these are disallowed in grammars.
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
raise ValueError( raise ValueError(
"Mistral tokenizer is not supported for the 'guidance' " "Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] " "structured output backend. Please use ['xgrammar', 'outlines'] "
...@@ -764,7 +764,7 @@ class SamplingParams( ...@@ -764,7 +764,7 @@ class SamplingParams(
validate_structured_output_request_outlines(self) validate_structured_output_request_outlines(self)
elif backend == "lm-format-enforcer": elif backend == "lm-format-enforcer":
# lm format enforcer backend # lm format enforcer backend
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
raise ValueError( raise ValueError(
"Mistral tokenizer is not supported for the 'lm-format-enforcer' " "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
"structured output backend. Please use ['xgrammar', 'outlines'] " "structured output backend. Please use ['xgrammar', 'outlines'] "
...@@ -796,7 +796,7 @@ class SamplingParams( ...@@ -796,7 +796,7 @@ class SamplingParams(
schema = so_params.json schema = so_params.json
skip_guidance = has_guidance_unsupported_json_features(schema) skip_guidance = has_guidance_unsupported_json_features(schema)
if isinstance(tokenizer, MistralTokenizer) or skip_guidance: if is_mistral_tokenizer(tokenizer) or skip_guidance:
# Fall back to outlines if the tokenizer is Mistral # Fall back to outlines if the tokenizer is Mistral
# or if schema contains features unsupported by guidance # or if schema contains features unsupported by guidance
validate_structured_output_request_outlines(self) validate_structured_output_request_outlines(self)
......
...@@ -210,6 +210,8 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: ...@@ -210,6 +210,8 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
class MistralTokenizer(TokenizerLike): class MistralTokenizer(TokenizerLike):
IS_MISTRAL_TOKENIZER = True # used by vllm.utils.mistral
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
......
...@@ -22,10 +22,10 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,10 +22,10 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -34,7 +34,7 @@ class Hermes2ProToolParser(ToolParser): ...@@ -34,7 +34,7 @@ class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model") logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = tokenizer.tokenizer self.model_tokenizer = tokenizer.tokenizer
......
...@@ -22,9 +22,9 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,9 +22,9 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.utils import extract_intermediate_diff from vllm.tool_parsers.utils import extract_intermediate_diff
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,7 +33,7 @@ class JambaToolParser(ToolParser): ...@@ -33,7 +33,7 @@ class JambaToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer): if is_mistral_tokenizer(self.model_tokenizer):
raise ValueError( raise ValueError(
"Detected a MistralTokenizer tokenizer when using a Jamba model" "Detected a MistralTokenizer tokenizer when using a Jamba model"
) )
......
...@@ -25,10 +25,10 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -25,10 +25,10 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -66,9 +66,7 @@ class MistralToolCall(ToolCall): ...@@ -66,9 +66,7 @@ class MistralToolCall(ToolCall):
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool: def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
return not ( return not (is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11)
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
class MistralToolParser(ToolParser): class MistralToolParser(ToolParser):
...@@ -83,7 +81,7 @@ class MistralToolParser(ToolParser): ...@@ -83,7 +81,7 @@ class MistralToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer): if not is_mistral_tokenizer(self.model_tokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral model...") logger.info("Non-Mistral tokenizer detected when using a Mistral model...")
# initialize properties used for state when parsing tool calls in # initialize properties used for state when parsing tool calls in
...@@ -115,7 +113,7 @@ class MistralToolParser(ToolParser): ...@@ -115,7 +113,7 @@ class MistralToolParser(ToolParser):
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if ( if (
not isinstance(self.model_tokenizer, MistralTokenizer) not is_mistral_tokenizer(self.model_tokenizer)
and request.tools and request.tools
and request.tool_choice != "none" and request.tool_choice != "none"
): ):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Provides lazy import of the vllm.tokenizers.mistral module."""
from __future__ import annotations
from typing import TYPE_CHECKING, TypeGuard
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
# if type checking, eagerly import the module
import vllm.tokenizers.mistral as mt
else:
mt = LazyLoader("mt", globals(), "vllm.tokenizers.mistral")
def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralTokenizer]:
"""Return true if the tokenizer is a MistralTokenizer instance."""
cls = type(obj)
# Check for special class attribute, this avoids importing the class to
# do an isinstance() check. If the attribute is True, do an isinstance
# check to be sure we have the correct type.
return bool(
getattr(cls, "IS_MISTRAL_TOKENIZER", False)
and isinstance(obj, mt.MistralTokenizer)
)
...@@ -10,8 +10,8 @@ import torch ...@@ -10,8 +10,8 @@ import torch
import vllm.envs import vllm.envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.v1.structured_output.backend_types import ( from vllm.v1.structured_output.backend_types import (
StructuredOutputBackend, StructuredOutputBackend,
StructuredOutputGrammar, StructuredOutputGrammar,
...@@ -38,7 +38,7 @@ class XgrammarBackend(StructuredOutputBackend): ...@@ -38,7 +38,7 @@ class XgrammarBackend(StructuredOutputBackend):
self.vllm_config.structured_outputs_config.disable_any_whitespace self.vllm_config.structured_outputs_config.disable_any_whitespace
) )
if isinstance(self.tokenizer, MistralTokenizer): if is_mistral_tokenizer(self.tokenizer):
# NOTE: ideally, xgrammar should handle this accordingly. # NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
stop_token_ids = [self.tokenizer.eos_token_id] stop_token_ids = [self.tokenizer.eos_token_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment