Unverified Commit cbcdf2c6 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix chat template loading (#15143)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 038de04d
...@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt, ...@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result # Call the function and get the result
result = apply_hf_chat_template( result = apply_hf_chat_template(
tokenizer, tokenizer,
trust_remote_code=True,
conversation=mock_request.messages, conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content, chat_template=mock_request.chat_template or template_content,
tools=None,
add_generation_prompt=mock_request.add_generation_prompt, add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message, continue_final_message=mock_request.continue_final_message,
) )
......
...@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI, ...@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=6299, total_tokens=6309) completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
...@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded( ...@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=6299, total_tokens=6309) completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
message = choice.message message = choice.message
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
......
...@@ -4,10 +4,13 @@ import warnings ...@@ -4,10 +4,13 @@ import warnings
from typing import Optional from typing import Optional
import pytest import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
_try_extract_ast, load_chat_template,
parse_chat_messages, parse_chat_messages,
parse_chat_messages_futures, parse_chat_messages_futures,
resolve_chat_template_content_format) resolve_chat_template_content_format)
...@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples" ...@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b" ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
...@@ -703,25 +708,70 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ...@@ -703,25 +708,70 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
vllm_result = apply_hf_chat_template( vllm_result = apply_hf_chat_template(
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation, conversation=conversation,
chat_template=None, chat_template=None,
tools=None,
add_generation_prompt=True, add_generation_prompt=True,
) )
assert hf_result == vllm_result assert hf_result == vllm_result
@pytest.mark.parametrize(
"model",
[
QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str
HERMES_MODEL_ID, # tokenizer.chat_template is of type dict
])
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer
tools = [{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema
}
}] if use_tools else None
# Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model", "expected_format"), ("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"), [(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"), (QWEN2VL_MODEL_ID, "openai"),
(QWEN25VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"), (ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"), (MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")], (LLAMA_GUARD_MODEL_ID, "openai")],
) )
# yapf: enable # yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format): def test_resolve_content_format_hf_defined(model, expected_format):
if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
"4.49.0"):
pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
tokenizer_group = TokenizerGroup( tokenizer_group = TokenizerGroup(
model, model,
enable_lora=False, enable_lora=False,
...@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
) )
tokenizer = tokenizer_group.tokenizer tokenizer = tokenizer_group.tokenizer
chat_template = tokenizer.chat_template # Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=True,
)
assert isinstance(chat_template, str) assert isinstance(chat_template, str)
print("[TEXT]") print("[TEXT]")
...@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format): ...@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
resolved_format = resolve_chat_template_content_format( resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template None, # Test detecting the tokenizer's chat_template
None,
"auto", "auto",
tokenizer, tokenizer,
trust_remote_code=True,
) )
assert resolved_format == expected_format assert resolved_format == expected_format
...@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format): ...@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
resolved_format = resolve_chat_template_content_format( resolved_format = resolve_chat_template_content_format(
chat_template, chat_template,
None,
"auto", "auto",
dummy_tokenizer, dummy_tokenizer,
trust_remote_code=True,
) )
assert resolved_format == expected_format assert resolved_format == expected_format
...@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]], ...@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
# universal args for all models go here. also good if you need to test locally # universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something. # and change type or KV cache quantization or something.
ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"] ARGS: list[str] = [
"--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
"256"
]
CONFIGS: dict[str, ServerConfig] = { CONFIGS: dict[str, ServerConfig] = {
"hermes": { "hermes": {
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import codecs
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict, deque from collections import defaultdict, deque
...@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import ( ...@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio) InputAudio)
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions # pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin)
from typing_extensions import Required, TypeAlias, TypedDict from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -306,24 +306,63 @@ def _detect_content_format( ...@@ -306,24 +306,63 @@ def _detect_content_format(
return "openai" return "openai"
def _resolve_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
trust_remote_code=trust_remote_code,
)
if isinstance(processor, ProcessorMixin) and \
processor.chat_template is not None:
return processor.chat_template
except Exception:
logger.debug("Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path, exc_info=True)
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)
return None
def _resolve_chat_template_content_format( def _resolve_chat_template_content_format(
chat_template: Optional[str], chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption, given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
*,
trust_remote_code: bool,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
tokenizer_chat_template = tokenizer.chat_template hf_chat_template = _resolve_hf_chat_template(
else: tokenizer,
tokenizer_chat_template = None chat_template=chat_template,
trust_remote_code=trust_remote_code,
jinja_text: Optional[str] tools=tools,
if isinstance(tokenizer_chat_template, str) and chat_template is None: )
jinja_text = tokenizer_chat_template
elif (isinstance(tokenizer_chat_template, dict)
and chat_template in tokenizer_chat_template):
jinja_text = tokenizer_chat_template[chat_template]
else: else:
jinja_text = load_chat_template(chat_template, is_literal=True) hf_chat_template = None
jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True))
detected_format = ("string" if jinja_text is None else detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string")) _detect_content_format(jinja_text, default="string"))
...@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format( ...@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format(
@lru_cache @lru_cache
def resolve_chat_template_content_format( def _log_chat_template_content_format(
chat_template: Optional[str], chat_template: Optional[str],
given_format: ChatTemplateContentFormatOption, given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer, detected_format: ChatTemplateContentFormatOption,
) -> _ChatTemplateContentFormat: ):
detected_format = _resolve_chat_template_content_format(
chat_template,
given_format,
tokenizer,
)
logger.info( logger.info(
"Detected the chat template content format to be '%s'. " "Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.", "You can set `--chat-template-content-format` to override this.",
...@@ -360,6 +393,29 @@ def resolve_chat_template_content_format( ...@@ -360,6 +393,29 @@ def resolve_chat_template_content_format(
detected_format, detected_format,
) )
def resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool = False,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
given_format,
tokenizer,
trust_remote_code=trust_remote_code,
)
_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)
return detected_format return detected_format
...@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): ...@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
f"{type(chat_template)} is not a valid chat template type") f"{type(chat_template)} is not a valid chat template type")
def load_chat_template( def _load_chat_template(
chat_template: Optional[Union[Path, str]], chat_template: Optional[Union[Path, str]],
*, *,
is_literal: bool = False, is_literal: bool = False,
...@@ -724,7 +780,7 @@ def load_chat_template( ...@@ -724,7 +780,7 @@ def load_chat_template(
raise TypeError("chat_template is expected to be read directly " raise TypeError("chat_template is expected to be read directly "
"from its value") "from its value")
return codecs.decode(chat_template, "unicode_escape") return chat_template
try: try:
with open(chat_template) as f: with open(chat_template) as f:
...@@ -742,7 +798,18 @@ def load_chat_template( ...@@ -742,7 +798,18 @@ def load_chat_template(
# If opening a file fails, set chat template to be args to # If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly # ensure we decode so our escape are interpreted correctly
return load_chat_template(chat_template, is_literal=True) return _load_chat_template(chat_template, is_literal=True)
_cached_load_chat_template = lru_cache(_load_chat_template)
def load_chat_template(
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
) -> Optional[str]:
return _cached_load_chat_template(chat_template, is_literal=is_literal)
# TODO: Let user specify how to insert multimodal tokens into prompt # TODO: Let user specify how to insert multimodal tokens into prompt
...@@ -1067,23 +1134,20 @@ def apply_hf_chat_template( ...@@ -1067,23 +1134,20 @@ def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
chat_template: Optional[str], chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*, *,
trust_remote_code: bool = False,
tokenize: bool = False, # Different from HF's default tokenize: bool = False, # Different from HF's default
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
if chat_template is None: hf_chat_template = _resolve_hf_chat_template(
chat_template = tokenizer.chat_template tokenizer,
chat_template=chat_template,
# FIXME: Temporary workaround for tools=tools,
# https://huggingface.co/mistral-community/pixtral-12b/discussions/31 trust_remote_code=trust_remote_code,
if chat_template is None: )
try:
processor = cached_get_processor(tokenizer.name_or_path)
chat_template = processor.chat_template
except Exception:
pass
if chat_template is None: if hf_chat_template is None:
raise ValueError( raise ValueError(
"As of transformers v4.44, default chat template is no longer " "As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer " "allowed, so you must provide a chat template if the tokenizer "
...@@ -1091,7 +1155,8 @@ def apply_hf_chat_template( ...@@ -1091,7 +1155,8 @@ def apply_hf_chat_template(
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type] conversation=conversation, # type: ignore[arg-type]
chat_template=chat_template, tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize, tokenize=tokenize,
**kwargs, **kwargs,
) )
...@@ -1100,7 +1165,8 @@ def apply_hf_chat_template( ...@@ -1100,7 +1165,8 @@ def apply_hf_chat_template(
def apply_mistral_chat_template( def apply_mistral_chat_template(
tokenizer: MistralTokenizer, tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
chat_template: Optional[str] = None, chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
**kwargs: Any, **kwargs: Any,
) -> list[int]: ) -> list[int]:
if chat_template is not None: if chat_template is not None:
...@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template( ...@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template(
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
messages=messages, messages=messages,
tools=tools,
**kwargs, **kwargs,
) )
...@@ -690,8 +690,10 @@ class LLM: ...@@ -690,8 +690,10 @@ class LLM:
model_config = self.llm_engine.get_model_config() model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
tools,
chat_template_content_format, chat_template_content_format,
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
) )
prompts: list[Union[TokensPrompt, TextPrompt]] = [] prompts: list[Union[TokensPrompt, TextPrompt]] = []
...@@ -713,18 +715,19 @@ class LLM: ...@@ -713,18 +715,19 @@ class LLM:
tokenizer, tokenizer,
messages=msgs, messages=msgs,
chat_template=chat_template, chat_template=chat_template,
tools=tools,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message, continue_final_message=continue_final_message,
tools=tools,
) )
else: else:
prompt_data = apply_hf_chat_template( prompt_data = apply_hf_chat_template(
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation, conversation=conversation,
chat_template=chat_template, chat_template=chat_template,
tools=tools,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message, continue_final_message=continue_final_message,
tools=tools,
) )
prompt: Union[TokensPrompt, TextPrompt] prompt: Union[TokensPrompt, TextPrompt]
......
...@@ -379,14 +379,18 @@ class OpenAIServing: ...@@ -379,14 +379,18 @@ class OpenAIServing:
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
list[TokensPrompt]]: list[TokensPrompt]]:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
tool_dicts,
chat_template_content_format, chat_template_content_format,
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
) )
conversation, mm_data_future = parse_chat_messages_futures( conversation, mm_data_future = parse_chat_messages_futures(
messages, messages,
self.model_config, model_config,
tokenizer, tokenizer,
content_format=resolved_content_format, content_format=resolved_content_format,
) )
...@@ -410,6 +414,7 @@ class OpenAIServing: ...@@ -410,6 +414,7 @@ class OpenAIServing:
else: else:
request_prompt = apply_hf_chat_template( request_prompt = apply_hf_chat_template(
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation, conversation=conversation,
**_chat_template_kwargs, **_chat_template_kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment