Unverified commit 56bdf85e, authored by Neil Schemenauer, committed by GitHub
Browse files

[Feature] Avoid eager import of the "mistral_common" package. (#40043)


Signed-off-by: Neil Schemenauer <nas@arctrix.com>
parent eba73068
...@@ -73,13 +73,9 @@ from vllm.reasoning import ReasoningParser ...@@ -73,13 +73,9 @@ from vllm.reasoning import ReasoningParser
from vllm.renderers import ChatParams from vllm.renderers import ChatParams
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolCall,
MistralToolParser,
)
from vllm.tool_parsers.utils import partial_json_loads from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
...@@ -143,10 +139,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -143,10 +139,12 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools=enable_auto_tools, enable_auto_tools=enable_auto_tools,
model_name=self.model_config.model, model_name=self.model_config.model,
) )
_is_mistral_tool_parser = self.tool_parser is not None and issubclass( if (
self.tool_parser, MistralToolParser is_mistral_tool_parser(self.tool_parser)
) and self.reasoning_parser_cls is not None
if _is_mistral_tool_parser and self.reasoning_parser_cls is not None: ):
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
MistralToolParser.model_can_reason = True MistralToolParser.model_can_reason = True
self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
...@@ -823,6 +821,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -823,6 +821,10 @@ class OpenAIServingChat(OpenAIServing):
harmony_tools_streamed[i] |= tools_streamed_flag harmony_tools_streamed[i] |= tools_streamed_flag
# Mistral grammar path: combined reasoning + tool streaming # Mistral grammar path: combined reasoning + tool streaming
elif is_mistral_grammar_path: elif is_mistral_grammar_path:
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolParser,
)
assert tool_parser is not None assert tool_parser is not None
assert isinstance(tool_parser, MistralToolParser) assert isinstance(tool_parser, MistralToolParser)
assert reasoning_end_arr is not None assert reasoning_end_arr is not None
...@@ -904,6 +906,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -904,6 +906,10 @@ class OpenAIServingChat(OpenAIServing):
else: else:
# Generate ID based on tokenizer type # Generate ID based on tokenizer type
if is_mistral_tokenizer(tokenizer): if is_mistral_tokenizer(tokenizer):
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolCall,
)
tool_call_id = MistralToolCall.generate_random_id() tool_call_id = MistralToolCall.generate_random_id()
else: else:
tool_call_id = make_tool_call_id( tool_call_id = make_tool_call_id(
...@@ -1275,8 +1281,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1275,8 +1281,6 @@ class OpenAIServingChat(OpenAIServing):
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
reasoning_parser: ReasoningParser | None = None, reasoning_parser: ReasoningParser | None = None,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
final_res: RequestOutput | None = None final_res: RequestOutput | None = None
...@@ -1393,12 +1397,17 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1393,12 +1397,17 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools=self.enable_auto_tools, enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
) )
tool_call_class = ( if is_mistral_tokenizer(tokenizer):
MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
)
tool_call_class: type[ToolCall] = MistralToolCall
else:
tool_call_class = ToolCall
use_mistral_tool_parser = request._grammar_from_tool_parser use_mistral_tool_parser = request._grammar_from_tool_parser
if use_mistral_tool_parser: if use_mistral_tool_parser:
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
tool_call_items = MistralToolParser.build_non_streaming_tool_calls( tool_call_items = MistralToolParser.build_non_streaming_tool_calls(
tool_calls tool_calls
) )
...@@ -1436,7 +1445,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1436,7 +1445,7 @@ class OpenAIServingChat(OpenAIServing):
# Generate ID using the correct format (kimi_k2 or random), # Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve # but leave it to the class if it's Mistral to preserve
# 9-char IDs # 9-char IDs
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
tool_call_class_items.append(tool_call_class(function=tc)) tool_call_class_items.append(tool_call_class(function=tc))
else: else:
generated_id = make_tool_call_id( generated_id = make_tool_call_id(
...@@ -1469,7 +1478,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1469,7 +1478,7 @@ class OpenAIServingChat(OpenAIServing):
# Generate ID using the correct format (kimi_k2 or random), # Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve # but leave it to the class if it's Mistral to preserve
# 9-char IDs # 9-char IDs
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
tool_call_class_items.append( tool_call_class_items.append(
tool_call_class(function=tool_call) tool_call_class(function=tool_call)
) )
...@@ -1519,7 +1528,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1519,7 +1528,7 @@ class OpenAIServingChat(OpenAIServing):
# Generate ID using the correct format (kimi_k2 or random), # Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve # but leave it to the class if it's Mistral to preserve
# 9-char IDs # 9-char IDs
if isinstance(tokenizer, MistralTokenizer): if is_mistral_tokenizer(tokenizer):
tool_call_items.append(tool_call_class(function=tc)) tool_call_items.append(tool_call_class(function=tc))
else: else:
generated_id = make_tool_call_id( generated_id = make_tool_call_id(
......
...@@ -65,7 +65,6 @@ from vllm.renderers.inputs.preprocess import ( ...@@ -65,7 +65,6 @@ from vllm.renderers.inputs.preprocess import (
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tracing import ( from vllm.tracing import (
contains_trace_headers, contains_trace_headers,
extract_trace_headers, extract_trace_headers,
...@@ -73,6 +72,7 @@ from vllm.tracing import ( ...@@ -73,6 +72,7 @@ from vllm.tracing import (
) )
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.async_utils import collect_from_async_generator from vllm.utils.async_utils import collect_from_async_generator
from vllm.utils.mistral import is_mistral_tool_parser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -615,8 +615,7 @@ class OpenAIServing: ...@@ -615,8 +615,7 @@ class OpenAIServing:
# let the parser handle the output. # let the parser handle the output.
use_mistral_tool_parser = ( use_mistral_tool_parser = (
isinstance(request, ChatCompletionRequest) isinstance(request, ChatCompletionRequest)
and tool_parser_cls is not None and is_mistral_tool_parser(tool_parser_cls)
and issubclass(tool_parser_cls, MistralToolParser)
and request._grammar_from_tool_parser and request._grammar_from_tool_parser
) )
......
...@@ -55,9 +55,8 @@ from vllm.renderers.inputs.preprocess import ( ...@@ -55,9 +55,8 @@ from vllm.renderers.inputs.preprocess import (
prompt_to_seq, prompt_to_seq,
) )
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser
from vllm.utils.mistral import mt as _mt from vllm.utils.mistral import mt as _mt
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -582,7 +581,7 @@ class OpenAIServingRender: ...@@ -582,7 +581,7 @@ class OpenAIServingRender:
tool_choice = getattr(request, "tool_choice", "none") tool_choice = getattr(request, "tool_choice", "none")
tokenizer = renderer.get_tokenizer() tokenizer = renderer.get_tokenizer()
is_mistral_grammar_eligible = ( is_mistral_grammar_eligible = (
issubclass(tool_parser, MistralToolParser) is_mistral_tool_parser(tool_parser)
and is_mistral_tokenizer(tokenizer) and is_mistral_tokenizer(tokenizer)
and tokenizer.supports_grammar and tokenizer.supports_grammar
) )
......
...@@ -118,6 +118,8 @@ class MistralToolParser(ToolParser): ...@@ -118,6 +118,8 @@ class MistralToolParser(ToolParser):
set. set.
""" """
IS_MISTRAL_TOOL_PARSER = True # used by vllm.utils.mistral
# Used to generate correct grammar in `adjust_request` # Used to generate correct grammar in `adjust_request`
model_can_reason: bool = False model_can_reason: bool = False
......
...@@ -12,8 +12,10 @@ from vllm.utils.import_utils import LazyLoader ...@@ -12,8 +12,10 @@ from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING: if TYPE_CHECKING:
# if type checking, eagerly import the module # if type checking, eagerly import the module
import vllm.tokenizers.mistral as mt import vllm.tokenizers.mistral as mt
import vllm.tool_parsers.mistral_tool_parser as mtp
else: else:
mt = LazyLoader("mt", globals(), "vllm.tokenizers.mistral") mt = LazyLoader("mt", globals(), "vllm.tokenizers.mistral")
mtp = LazyLoader("mtp", globals(), "vllm.tool_parsers.mistral_tool_parser")
def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralTokenizer]: def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralTokenizer]:
...@@ -26,3 +28,16 @@ def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralToken ...@@ -26,3 +28,16 @@ def is_mistral_tokenizer(obj: TokenizerLike | None) -> TypeGuard[mt.MistralToken
getattr(cls, "IS_MISTRAL_TOKENIZER", False) getattr(cls, "IS_MISTRAL_TOKENIZER", False)
and isinstance(obj, mt.MistralTokenizer) and isinstance(obj, mt.MistralTokenizer)
) )
def is_mistral_tool_parser(cls: type | None) -> bool:
"""Return true if *cls* is (a subclass of) MistralToolParser.
Uses a class attribute check so that importing
``vllm.tool_parsers.mistral_tool_parser`` — and transitively
``mistral_common`` — is not required.
"""
return bool(
getattr(cls, "IS_MISTRAL_TOOL_PARSER", False)
and issubclass(cls, mtp.MistralToolParser) # type: ignore[arg-type]
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment