Unverified Commit c021814f authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat(frontend): allowing passing vllm chat processor specific flags in frontend (#7896)

parent 942070c2
...@@ -114,6 +114,17 @@ def parse_args() -> tuple[FrontendConfig, Optional[Namespace], Optional[Namespac ...@@ -114,6 +114,17 @@ def parse_args() -> tuple[FrontendConfig, Optional[Namespace], Optional[Namespac
vllm_flags = None vllm_flags = None
sglang_flags = None sglang_flags = None
# --trust-remote-code is only meaningful with --dyn-chat-processor vllm.
# Warn and strip it when a different (or no) chat processor is active so
# it does not propagate as an unknown-argument error below.
if "--trust-remote-code" in unknown and config.chat_processor != "vllm":
logger.warning(
"--trust-remote-code has no effect without '--dyn-chat-processor vllm'. "
"It is only supported by the vLLM chat processor. "
"Pass '--dyn-chat-processor vllm' to enable trust_remote_code."
)
unknown = [arg for arg in unknown if arg != "--trust-remote-code"]
# parse extra vllm flags using vllm native parser. # parse extra vllm flags using vllm native parser.
if config.chat_processor == "vllm": if config.chat_processor == "vllm":
try: try:
......
...@@ -80,6 +80,7 @@ def _prepare_request( ...@@ -80,6 +80,7 @@ def _prepare_request(
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
exclude_tools_when_tool_choice_none: bool = True, exclude_tools_when_tool_choice_none: bool = True,
enable_auto_tool_choice: bool = False,
) -> tuple[ChatCompletionRequest, ToolParser | None, dict[str, Any], Any, ChatParams]: ) -> tuple[ChatCompletionRequest, ToolParser | None, dict[str, Any], Any, ChatParams]:
"""Validate request and build arguments for template rendering. """Validate request and build arguments for template rendering.
...@@ -103,7 +104,11 @@ def _prepare_request( ...@@ -103,7 +104,11 @@ def _prepare_request(
request_for_sampling = ChatCompletionRequest.model_validate(request) request_for_sampling = ChatCompletionRequest.model_validate(request)
tool_parser: ToolParser | None = None tool_parser: ToolParser | None = None
if tool_parser_class and request_for_sampling.tools: # With enable_auto_tool_choice the model may emit tool calls even when the
# client did not supply an explicit `tools` list, so we activate the parser
# whenever the tool_parser_class is available.
has_tools = bool(request_for_sampling.tools)
if tool_parser_class and (has_tools or enable_auto_tool_choice):
if request_for_sampling.tool_choice != "none": if request_for_sampling.tool_choice != "none":
tool_parser = tool_parser_class(tokenizer) tool_parser = tool_parser_class(tokenizer)
request_for_sampling = tool_parser.adjust_request(request_for_sampling) request_for_sampling = tool_parser.adjust_request(request_for_sampling)
...@@ -163,6 +168,7 @@ async def preprocess_chat_request( ...@@ -163,6 +168,7 @@ async def preprocess_chat_request(
renderer, renderer,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
exclude_tools_when_tool_choice_none: bool = True, exclude_tools_when_tool_choice_none: bool = True,
enable_auto_tool_choice: bool = False,
) -> PreprocessResult: ) -> PreprocessResult:
( (
request_for_sampling, request_for_sampling,
...@@ -175,6 +181,7 @@ async def preprocess_chat_request( ...@@ -175,6 +181,7 @@ async def preprocess_chat_request(
tokenizer=tokenizer, tokenizer=tokenizer,
tool_parser_class=tool_parser_class, tool_parser_class=tool_parser_class,
exclude_tools_when_tool_choice_none=exclude_tools_when_tool_choice_none, exclude_tools_when_tool_choice_none=exclude_tools_when_tool_choice_none,
enable_auto_tool_choice=enable_auto_tool_choice,
) )
_, engine_prompt = await renderer.render_messages_async(messages, chat_params) _, engine_prompt = await renderer.render_messages_async(messages, chat_params)
......
...@@ -77,6 +77,7 @@ class VllmProcessor: ...@@ -77,6 +77,7 @@ class VllmProcessor:
output_processor: OutputProcessor, output_processor: OutputProcessor,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
reasoning_parser_class: type[ReasoningParser] | None, reasoning_parser_class: type[ReasoningParser] | None,
enable_auto_tool_choice: bool = False,
): ):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.input_processor = input_processor self.input_processor = input_processor
...@@ -86,6 +87,7 @@ class VllmProcessor: ...@@ -86,6 +87,7 @@ class VllmProcessor:
self.tool_parser_class = tool_parser_class self.tool_parser_class = tool_parser_class
self.reasoning_parser_class = reasoning_parser_class self.reasoning_parser_class = reasoning_parser_class
self.exclude_tools_when_tool_choice_none = True self.exclude_tools_when_tool_choice_none = True
self.enable_auto_tool_choice = enable_auto_tool_choice
def _get_eos_token_ids(self) -> list[int]: def _get_eos_token_ids(self) -> list[int]:
"""Return EOS token ids using tokenizer metadata. """Return EOS token ids using tokenizer metadata.
...@@ -144,6 +146,7 @@ class VllmProcessor: ...@@ -144,6 +146,7 @@ class VllmProcessor:
renderer=self.input_processor.renderer, renderer=self.input_processor.renderer,
tool_parser_class=self.tool_parser_class, tool_parser_class=self.tool_parser_class,
exclude_tools_when_tool_choice_none=self.exclude_tools_when_tool_choice_none, exclude_tools_when_tool_choice_none=self.exclude_tools_when_tool_choice_none,
enable_auto_tool_choice=self.enable_auto_tool_choice,
) )
request_for_sampling = pre.request_for_sampling request_for_sampling = pre.request_for_sampling
...@@ -433,11 +436,14 @@ class EngineFactory: ...@@ -433,11 +436,14 @@ class EngineFactory:
tokenizer_mode = getattr(self.flags, "tokenizer_mode", None) or "auto" tokenizer_mode = getattr(self.flags, "tokenizer_mode", None) or "auto"
config_format = getattr(self.flags, "config_format", None) or "auto" config_format = getattr(self.flags, "config_format", None) or "auto"
load_format = getattr(self.flags, "load_format", None) or "dummy" load_format = getattr(self.flags, "load_format", None) or "dummy"
trust_remote_code = getattr(self.flags, "trust_remote_code", False)
enable_auto_tool_choice = getattr(self.flags, "enable_auto_tool_choice", False)
model_config = ModelConfig( model_config = ModelConfig(
model=source_path, model=source_path,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
config_format=config_format, config_format=config_format,
trust_remote_code=trust_remote_code,
) )
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
...@@ -496,6 +502,7 @@ class EngineFactory: ...@@ -496,6 +502,7 @@ class EngineFactory:
output_processor, output_processor,
tool_parser_class, tool_parser_class,
reasoning_parser_class, reasoning_parser_class,
enable_auto_tool_choice=enable_auto_tool_choice,
) )
gen.exclude_tools_when_tool_choice_none = ( gen.exclude_tools_when_tool_choice_none = (
self.config.exclude_tools_when_tool_choice_none self.config.exclude_tools_when_tool_choice_none
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment