"examples/vscode:/vscode.git/clone" did not exist on "29283e89762a3d572c504e5ea317351696b553a6"
Commit 04c2b269 authored by Russell Bryant's avatar Russell Bryant Committed by simon-mo
Browse files

Add filtering for chat template kwargs (#25794)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: default avatarsimon-mo <simon.mo@hey.com>
parent ee10d7e6
...@@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, ...@@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages, parse_chat_messages,
parse_chat_messages_futures, parse_chat_messages_futures,
resolve_chat_template_content_format, resolve_chat_template_content_format,
resolve_chat_template_kwargs,
resolve_hf_chat_template) resolve_hf_chat_template)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
...@@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct" ...@@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B" QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
...@@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ...@@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
assert isinstance(chat_template, str) assert isinstance(chat_template, str)
@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id", "add_generation_prompt",
"continue_final_message", "tools"
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking", "add_generation_prompt",
"continue_final_message", "tools"
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
tools = ([{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}])
chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
# NOTE: Qwen2-Audio default chat template is specially defined inside # NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json` # processor class instead of using `tokenizer_config.json`
# yapf: disable # yapf: disable
......
...@@ -11,7 +11,12 @@ from pathlib import Path ...@@ -11,7 +11,12 @@ from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast) cast)
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
...@@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import ( ...@@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
# yapf: enable # yapf: enable
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid, supports_kw
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1548,6 +1553,46 @@ def parse_chat_messages_futures( ...@@ -1548,6 +1553,46 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def resolve_chat_template_kwargs(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: str,
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
fn_kw = {
k for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
# We exclude chat_template from kwargs here, because
# chat template has been already resolved at this stage
unexpected_vars = {"chat_template"}
accept_vars = (fn_kw | template_vars) - unexpected_vars
return {
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
}
def apply_hf_chat_template( def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
...@@ -1573,12 +1618,17 @@ def apply_hf_chat_template( ...@@ -1573,12 +1618,17 @@ def apply_hf_chat_template(
) )
try: try:
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type] conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template, chat_template=hf_chat_template,
tokenize=tokenize, tokenize=tokenize,
**kwargs, **resolved_kwargs,
) )
# External library exceptions can sometimes occur despite the framework's # External library exceptions can sometimes occur despite the framework's
......
...@@ -1716,6 +1716,7 @@ async def init_app_state( ...@@ -1716,6 +1716,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args. exclude_tools_when_tool_choice_none=args.
......
...@@ -103,9 +103,13 @@ class FrontendArgs: ...@@ -103,9 +103,13 @@ class FrontendArgs:
chat_template_content_format: ChatTemplateContentFormatOption = "auto" chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template. """The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"` * "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to OpenAI * "openai" will render the content as a list of dictionaries, similar to
schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the ones from tokenizer."""
response_role: str = "assistant" response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`.""" """The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: Optional[str] = None ssl_keyfile: Optional[str] = None
......
...@@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
chat_template: Optional[str], chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "", reasoning_parser: str = "",
enable_auto_tools: bool = False, enable_auto_tools: bool = False,
...@@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
self.response_role = response_role self.response_role = response_role
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
# set up tool use # set up tool use
...@@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing): ...@@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
if not self.use_harmony: if not self.use_harmony:
# Common case. # Common case.
request_chat_template = request.chat_template
chat_template_kwargs = request.chat_template_kwargs
if not self.trust_request_chat_template and (
request_chat_template is not None or
(chat_template_kwargs and
chat_template_kwargs.get("chat_template") is not None)):
return self.create_error_response(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template.")
( (
conversation, conversation,
request_prompts, request_prompts,
...@@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
request, request,
tokenizer, tokenizer,
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, chat_template=request_chat_template or self.chat_template,
chat_template_content_format=self. chat_template_content_format=self.
chat_template_content_format, chat_template_content_format,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment