Unverified Commit 56bdf85e authored by Neil Schemenauer's avatar Neil Schemenauer Committed by GitHub
Browse files

[Feature] Avoid eager import of the "mistral_common" package. (#40043)


Signed-off-by: default avatarNeil Schemenauer <nas@arctrix.com>
parent eba73068
......@@ -73,13 +73,9 @@ from vllm.reasoning import ReasoningParser
from vllm.renderers import ChatParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolCall,
MistralToolParser,
)
from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser
if TYPE_CHECKING:
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
......@@ -143,10 +139,12 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools=enable_auto_tools,
model_name=self.model_config.model,
)
_is_mistral_tool_parser = self.tool_parser is not None and issubclass(
self.tool_parser, MistralToolParser
)
if _is_mistral_tool_parser and self.reasoning_parser_cls is not None:
if (
is_mistral_tool_parser(self.tool_parser)
and self.reasoning_parser_cls is not None
):
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
MistralToolParser.model_can_reason = True
self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
......@@ -823,6 +821,10 @@ class OpenAIServingChat(OpenAIServing):
harmony_tools_streamed[i] |= tools_streamed_flag
# Mistral grammar path: combined reasoning + tool streaming
elif is_mistral_grammar_path:
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolParser,
)
assert tool_parser is not None
assert isinstance(tool_parser, MistralToolParser)
assert reasoning_end_arr is not None
......@@ -904,6 +906,10 @@ class OpenAIServingChat(OpenAIServing):
else:
# Generate ID based on tokenizer type
if is_mistral_tokenizer(tokenizer):
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolCall,
)
tool_call_id = MistralToolCall.generate_random_id()
else:
tool_call_id = make_tool_call_id(
......@@ -1275,8 +1281,6 @@ class OpenAIServingChat(OpenAIServing):
request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None,
) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time())
final_res: RequestOutput | None = None
......@@ -1393,12 +1397,17 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser,
)
tool_call_class = (
MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall
)
if is_mistral_tokenizer(tokenizer):
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
tool_call_class: type[ToolCall] = MistralToolCall
else:
tool_call_class = ToolCall
use_mistral_tool_parser = request._grammar_from_tool_parser
if use_mistral_tool_parser:
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
tool_call_items = MistralToolParser.build_non_streaming_tool_calls(
tool_calls
)
......@@ -1436,7 +1445,7 @@ class OpenAIServingChat(OpenAIServing):
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
tool_call_class_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
......@@ -1469,7 +1478,7 @@ class OpenAIServingChat(OpenAIServing):
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
tool_call_class_items.append(
tool_call_class(function=tool_call)
)
......@@ -1519,7 +1528,7 @@ class OpenAIServingChat(OpenAIServing):
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
if is_mistral_tokenizer(tokenizer):
tool_call_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
......
......@@ -65,7 +65,6 @@ from vllm.renderers.inputs.preprocess import (
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
......@@ -73,6 +72,7 @@ from vllm.tracing import (
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import collect_from_async_generator
from vllm.utils.mistral import is_mistral_tool_parser
logger = init_logger(__name__)
......@@ -615,8 +615,7 @@ class OpenAIServing:
# let the parser handle the output.
use_mistral_tool_parser = (
isinstance(request, ChatCompletionRequest)
and tool_parser_cls is not None
and issubclass(tool_parser_cls, MistralToolParser)
and is_mistral_tool_parser(tool_parser_cls)
and request._grammar_from_tool_parser
)
......
......@@ -55,9 +55,8 @@ from vllm.renderers.inputs.preprocess import (
prompt_to_seq,
)
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser
from vllm.utils.mistral import mt as _mt
logger = init_logger(__name__)
......@@ -582,7 +581,7 @@ class OpenAIServingRender:
tool_choice = getattr(request, "tool_choice", "none")
tokenizer = renderer.get_tokenizer()
is_mistral_grammar_eligible = (
issubclass(tool_parser, MistralToolParser)
is_mistral_tool_parser(tool_parser)
and is_mistral_tokenizer(tokenizer)
and tokenizer.supports_grammar
)
......
......@@ -118,6 +118,8 @@ class MistralToolParser(ToolParser):
set.
"""
IS_MISTRAL_TOOL_PARSER = True # used by vllm.utils.mistral
# Used to generate correct grammar in `adjust_request`
model_can_reason: bool = False
......
......@@ -12,8 +12,10 @@ from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
# if type checking, eagerly import the module
import vllm.tokenizers.mistral as mt
import vllm.tool_parsers.mistral_tool_parser as mtp
else:
mt = LazyLoader("mt", globals(), "vllm.tokenizers.mistral")
mtp = LazyLoader("mtp", globals(), "vllm.tool_parsers.mistral_tool_parser")
def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralTokenizer]:
......@@ -26,3 +28,16 @@ def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralToken
getattr(cls, "IS_MISTRAL_TOKENIZER", False)
and isinstance(obj, mt.MistralTokenizer)
)
def is_mistral_tool_parser(cls: type | None) -> bool:
"""Return true if *cls* is (a subclass of) MistralToolParser.
Uses a class attribute check so that importing
``vllm.tool_parsers.mistral_tool_parser`` — and transitively
``mistral_common`` — is not required.
"""
return bool(
getattr(cls, "IS_MISTRAL_TOOL_PARSER", False)
and issubclass(cls, mtp.MistralToolParser) # type: ignore[arg-type]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment