Unverified Commit d117a4d1 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Introduce Renderer for processing chat messages (using `ModelConfig`) (#30200)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 421012b6
......@@ -71,6 +71,7 @@ steps:
- tests/test_inputs.py
- tests/test_outputs.py
- tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/tool_parsers
......@@ -82,6 +83,7 @@ steps:
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_
- pytest -v -s tool_parsers
- pytest -v -s transformers_utils
......
......@@ -64,6 +64,7 @@ steps:
- tests/test_inputs.py
- tests/test_outputs.py
- tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/tool_parsers
......@@ -75,6 +76,7 @@ steps:
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_
- pytest -v -s tool_parsers
- pytest -v -s transformers_utils
......
......@@ -121,6 +121,7 @@ steps:
- tests/test_inputs.py
- tests/test_outputs.py
- tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/tool_parsers
......@@ -132,6 +133,7 @@ steps:
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_
- pytest -v -s tool_parsers
- pytest -v -s transformers_utils
......
......@@ -254,7 +254,8 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso
# import the required packages
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
# define a reasoning parser and register it to vllm
# the name list in register_module can be used
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.tokenizers import get_tokenizer
from ...models.registry import HF_EXAMPLE_MODELS
from ...utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATION_OUTPUT = [
(
"facebook/opt-125m",
chatml_jinja_path,
True,
False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
""",
),
(
"facebook/opt-125m",
chatml_jinja_path,
False,
False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""",
),
(
"facebook/opt-125m",
chatml_jinja_path,
False,
True,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of""",
),
]
TEST_MESSAGES = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "What is the capital of"},
]
ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"}
def test_load_chat_template():
# Testing chatml template
template_content = load_chat_template(chat_template=chatml_jinja_path)
# Test assertions
assert template_content is not None
# Hard coded value for template_chatml.jinja
assert (
template_content
== """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
)
def test_no_load_chat_template_filelike():
# Testing chatml template
template = "../../examples/does_not_exist"
with pytest.raises(ValueError, match="looks like a file path"):
load_chat_template(chat_template=template)
def test_no_load_chat_template_literallike():
# Testing chatml template
template = "{{ messages }}"
template_content = load_chat_template(chat_template=template)
assert template_content == template
@pytest.mark.parametrize(
"model,template,add_generation_prompt,continue_final_message,expected_output",
MODEL_TEMPLATE_GENERATION_OUTPUT,
)
def test_get_gen_prompt(
model, template, add_generation_prompt, continue_final_message, expected_output
):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
# Initialize the tokenizer
tokenizer = get_tokenizer(
tokenizer_name=model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
template_content = load_chat_template(chat_template=template)
# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
model=model,
messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
if continue_final_message
else TEST_MESSAGES,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
)
# Call the function and get the result
result = apply_hf_chat_template(
tokenizer=tokenizer,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
model_config=model_config,
tools=None,
add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message,
)
# Test assertion
assert result == expected_output, (
f"The generated prompt does not match the expected output for "
f"model {model} and template {template}"
)
......@@ -11,7 +11,7 @@ import pytest_asyncio
from openai import OpenAI
from vllm._aiter_ops import is_aiter_found_and_supported
from vllm.config.multimodal import MultiModalConfig
from vllm.config import MultiModalConfig
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
......@@ -23,8 +23,13 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.inputs import TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer
from vllm.renderers.mistral import MistralRenderer
from vllm.tokenizers import get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.tool_parsers import ToolParserManager
from vllm.v1.engine.async_llm import AsyncLLM
......@@ -103,15 +108,16 @@ def gptoss_server(default_server_args: list[str]):
@pytest.fixture(scope="class")
def gptoss_speculative_server(default_server_args: list[str]):
attention_backend = (
"TRITON_ATTN"
if not is_aiter_found_and_supported()
else "ROCM_AITER_UNIFIED_ATTN"
)
server_args = default_server_args + [
"--speculative-config",
f'{{"model": "{GPT_OSS_SPECULATOR_NAME}", '
f'"method": "eagle3", "num_speculative_tokens": 3}}',
f"--attention-backend={
'TRITON_ATTN'
if not is_aiter_found_and_supported()
else 'ROCM_AITER_UNIFIED_ATTN'
}",
f"--attention-backend={attention_backend}",
]
# gpt-oss requires AITER unified attention on ROCm
# TODO: Remove after fixing TRITON_ATTN issue on ROCm
......@@ -520,12 +526,21 @@ class MockModelConfig:
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
skip_tokenizer_init: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels(
engine_client=engine,
......@@ -561,6 +576,7 @@ class MockEngine:
model_config: MockModelConfig = field(default_factory=MockModelConfig)
input_processor: MagicMock = field(default_factory=MagicMock)
io_processor: MagicMock = field(default_factory=MagicMock)
renderer: MagicMock = field(default_factory=MagicMock)
async def _async_serving_chat_init():
......@@ -586,11 +602,11 @@ def test_async_serving_chat_init():
@pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
messages = [{"role": "user", "content": "what is 1+1?"}]
......@@ -616,11 +632,11 @@ async def test_serving_chat_returns_correct_model_name():
@pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
......@@ -649,11 +665,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine)
......@@ -694,11 +710,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
# Reinitialize the engine with new settings
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine)
......@@ -732,42 +748,32 @@ async def test_serving_chat_should_set_correct_max_tokens():
@pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_module):
async def test_serving_chat_mistral_token_ids_prompt_is_validated():
"""Regression test: when the Mistral tokenizer path returns token IDs
directly, we must still apply input length + max_tokens validation.
"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
class DummyMistralTokenizer:
def decode(self, token_ids):
# Only used for logging/validation error messages.
return "dummy"
dummy_tokenizer = DummyMistralTokenizer()
mock_engine.get_tokenizer.return_value = dummy_tokenizer
# Patch the OpenAI engine serving module to treat our dummy tokenizer
# as a MistralTokenizer. This forces the code path where chat template
# rendering can return a list[int] (token IDs).
import vllm.entrypoints.openai.engine.serving as engine_serving
monkeypatch_module.setattr(
engine_serving, "MistralTokenizer", DummyMistralTokenizer
)
serving_chat = _build_serving_chat(mock_engine)
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(mock_engine.model_config, tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer
# Force the Mistral chat template renderer to return token IDs.
# Choose a prompt length that is < max_model_len, but large enough that
# adding max_tokens should exceed the model context window.
serving_chat._apply_mistral_chat_template_async = AsyncMock(
return_value=list(range(95))
mock_renderer.render_messages_async = AsyncMock(
return_value=(
[],
TokensPrompt(prompt_token_ids=list(range(95))),
)
)
mock_engine.renderer = mock_renderer
serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest(
model=MODEL_NAME,
......@@ -781,39 +787,33 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_mo
@pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(
monkeypatch_module,
):
async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
"""Regression test: MistralTokenizer token-id prompts must still enforce
the max context length for the input itself (token_num >= max_model_len).
"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
class DummyMistralTokenizer:
def decode(self, token_ids):
return "dummy"
dummy_tokenizer = DummyMistralTokenizer()
mock_engine.get_tokenizer.return_value = dummy_tokenizer
import vllm.entrypoints.openai.engine.serving as engine_serving
monkeypatch_module.setattr(
engine_serving, "MistralTokenizer", DummyMistralTokenizer
)
serving_chat = _build_serving_chat(mock_engine)
mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer(mock_engine.model_config, tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer
# prompt_token_ids length == max_model_len should be rejected for
# completion-like requests (ChatCompletionRequest).
serving_chat._apply_mistral_chat_template_async = AsyncMock(
return_value=list(range(mock_engine.model_config.max_model_len))
mock_renderer.render_messages_async = AsyncMock(
return_value=(
[],
TokensPrompt(
prompt_token_ids=list(range(mock_engine.model_config.max_model_len))
),
)
)
mock_engine.renderer = mock_renderer
serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest(
model=MODEL_NAME,
......@@ -835,11 +835,11 @@ async def test_serving_chat_could_load_correct_generation_config():
}
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine)
......@@ -881,11 +881,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_model_config.hf_config.model_type = model_type
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
......@@ -914,11 +914,11 @@ async def test_serving_chat_data_parallel_rank_extraction():
"""Test that data_parallel_rank is properly extracted from header and
passed to engine."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Mock the generate method to return an async generator
async def mock_generate(*args, **kwargs):
......
......@@ -35,6 +35,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
mock_engine_client.model_config = mock_model_config
mock_engine_client.input_processor = MagicMock()
mock_engine_client.io_processor = MagicMock()
mock_engine_client.renderer = MagicMock()
serving_models = OpenAIServingModels(
engine_client=mock_engine_client,
......
......@@ -131,6 +131,7 @@ class TestInitializeToolSessions:
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()
......@@ -217,6 +218,7 @@ class TestValidateGeneratorInput:
engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock()
......
......@@ -212,7 +212,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template,
),
patch(
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
"vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
return_value="test querytest doc",
),
):
......@@ -245,7 +245,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template,
),
patch(
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
"vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
side_effect=ChatTemplateResolutionError("No template"),
),
):
......@@ -296,7 +296,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template,
),
patch(
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
"vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
side_effect=ChatTemplateResolutionError("No template"),
),
):
......@@ -331,7 +331,7 @@ class TestGetScorePrompt:
return_value=mock_model_with_score_template,
),
patch(
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
"vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
side_effect=ChatTemplateResolutionError("No template"),
),
):
......
......@@ -7,21 +7,14 @@ from typing import Literal
import pytest
import torch
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
_try_extract_ast,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format,
resolve_chat_template_kwargs,
resolve_hf_chat_template,
parse_chat_messages_async,
)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (
......@@ -29,24 +22,11 @@ from vllm.multimodal.utils import (
encode_image_url,
encode_video_url,
)
from vllm.tokenizers import get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.serial_utils import tensor2base64
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
......@@ -469,7 +449,7 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
image_url,
):
image_uuid = str(hash(image_url))
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -490,7 +470,7 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
assert conversation == [
{"role": "user", "content": "<|image_1|>\nWhat's in the image?"}
]
_assert_mm_data_is_image_input(await mm_future, 1)
_assert_mm_data_is_image_input(mm_data, 1)
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid])
......@@ -500,7 +480,7 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
image_url,
):
image_uuid = str(hash(image_url))
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -521,7 +501,7 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
assert conversation == [
{"role": "user", "content": "<|image_1|>\nWhat's in the image?"}
]
_assert_mm_data_is_image_input(await mm_future, 1, skipped_image_indices=[0])
_assert_mm_data_is_image_input(mm_data, 1, skipped_image_indices=[0])
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[image_uuid])
......@@ -533,7 +513,7 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
image_uuid1 = "my_uuid_1"
image_uuid2 = "my_uuid_2"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -562,7 +542,7 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
"content": "<|image_1|>\n<|image_2|>\nWhat's in these images?",
}
]
_assert_mm_data_is_image_input(await mm_future, 2)
_assert_mm_data_is_image_input(mm_data, 2)
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
......@@ -574,7 +554,7 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
image_uuid1 = "my_uuid_1"
image_uuid2 = "my_uuid_2"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -603,7 +583,7 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
"content": "<|image_1|>\n<|image_2|>\nWhat's in these images?",
}
]
_assert_mm_data_is_image_input(await mm_future, 2, skipped_image_indices=[0, 1])
_assert_mm_data_is_image_input(mm_data, 2, skipped_image_indices=[0, 1])
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid1, image_uuid2])
......@@ -614,7 +594,7 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
):
image_uuid2 = "my_uuid_2"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -642,7 +622,7 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
"content": "<|image_1|>\n<|image_2|>\nWhat's in these images?",
}
]
_assert_mm_data_is_image_input(await mm_future, 2)
_assert_mm_data_is_image_input(mm_data, 2)
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, image_uuid2])
......@@ -689,7 +669,7 @@ async def test_parse_chat_messages_single_image_async(
phi3v_model_config,
image_url,
):
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -706,7 +686,7 @@ async def test_parse_chat_messages_single_image_async(
assert conversation == [
{"role": "user", "content": "<|image_1|>\nWhat's in the image?"}
]
_assert_mm_data_is_image_input(await mm_future, 1)
_assert_mm_data_is_image_input(mm_data, 1)
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[None])
......@@ -890,7 +870,7 @@ async def test_parse_chat_messages_audio_embeds_async(
# Encode it as base64
base64_audio_embedding = tensor2base64(audio_embedding)
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -908,7 +888,6 @@ async def test_parse_chat_messages_audio_embeds_async(
)
# Should have audio embedding in mm_data (single tensor, not a list)
mm_data = await mm_future
assert mm_data is not None
assert "audio" in mm_data
assert isinstance(mm_data["audio"], torch.Tensor)
......@@ -1050,7 +1029,7 @@ async def test_parse_chat_messages_multiple_image_embeds_async(
base64_image_embedding_1 = tensor2base64(image_embedding_1)
base64_image_embedding_2 = tensor2base64(image_embedding_2)
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -1080,7 +1059,6 @@ async def test_parse_chat_messages_multiple_image_embeds_async(
]
# Await the future and verify mm_data
mm_data = await mm_future
assert mm_data is not None
assert "image" in mm_data
assert isinstance(mm_data["image"], list)
......@@ -1101,7 +1079,7 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
phi3v_model_config_image_embeds,
):
uuid = "abcd"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -1121,7 +1099,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
"content": "<|image_1|>\nWhat's in this image?",
}
]
mm_data = await mm_future
assert mm_data is not None
assert "image" in mm_data
assert isinstance(mm_data["image"], list)
......@@ -1228,7 +1205,7 @@ async def test_parse_chat_messages_multiple_images_async(
phi3v_model_config,
image_url,
):
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -1252,7 +1229,7 @@ async def test_parse_chat_messages_multiple_images_async(
"content": "<|image_1|>\n<|image_2|>\nWhat's in these images?",
}
]
_assert_mm_data_is_image_input(await mm_future, 2)
_assert_mm_data_is_image_input(mm_data, 2)
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
......@@ -1582,7 +1559,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
phi3v_model_config_mm_interleaved,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -1609,7 +1586,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
"Do they have differences?",
}
]
_assert_mm_data_is_image_input(await mm_data, 2)
_assert_mm_data_is_image_input(mm_data, 2)
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
......@@ -1619,7 +1596,7 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
image_url,
):
image_uuid = str(hash(image_url))
conversation, mm_data, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -1654,7 +1631,7 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
"Do they have differences?",
}
]
_assert_mm_data_is_image_input(await mm_data, 2)
_assert_mm_data_is_image_input(mm_data, 2)
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[image_uuid, image_uuid])
......@@ -2030,377 +2007,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
)
@pytest.mark.parametrize(
"model",
[
QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str
HERMES_MODEL_ID, # tokenizer.chat_template is of type dict
],
)
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
tools = (
[
{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}
]
if use_tools
else None
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
assert isinstance(chat_template, str)
@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id",
"add_generation_prompt",
"continue_final_message",
"tools",
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking",
"add_generation_prompt",
"continue_final_message",
"tools",
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model, expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
tools = [
{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}
]
chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
"tokenize": True,
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
with pytest.raises(
ValueError, match="Found unexpected chat template kwargs from request"
):
# should raise error if `chat_template_kwargs` contains
# `chat_template` or `tokenize`
resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
raise_on_unexpected=False,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
# Additional test: Verify HF base parameters work with **kwargs tokenizers
# This validates the fix for tokenizers like Kimi K2 that use **kwargs
# to receive standard HuggingFace parameters instead of declaring them explicitly
from vllm.entrypoints.chat_utils import _get_hf_base_chat_template_params
hf_base_params = _get_hf_base_chat_template_params()
# Verify common HF parameters are in the base class
assert {"add_generation_prompt", "tools", "continue_final_message"}.issubset(
hf_base_params
), f"Expected HF base params not found in {hf_base_params}"
# Test with a mock tokenizer that uses **kwargs (like Kimi K2)
class MockTokenizerWithKwargs:
def apply_chat_template(self, conversation, **kwargs):
return "mocked_output"
mock_tokenizer = MockTokenizerWithKwargs()
mock_kwargs = {
"add_generation_prompt": True,
"tools": tools,
"continue_final_message": False,
"unknown_param": "should_be_filtered",
}
resolved_mock = resolve_chat_template_kwargs(
mock_tokenizer, chat_template, mock_kwargs, raise_on_unexpected=False
)
# HF base params should pass through even with **kwargs tokenizer
assert "add_generation_prompt" in resolved_mock
assert "tools" in resolved_mock
assert "continue_final_message" in resolved_mock
# Unknown params should be filtered out
assert "unknown_param" not in resolved_mock
# NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json`
@pytest.mark.parametrize(
("model", "expected_format"),
[
(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(QWEN25VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(QWEN2AUDIO_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai"),
],
)
def test_resolve_content_format_hf_defined(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
model_config=model_config,
)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
model_config=model_config,
)
assert resolved_format == expected_format
@pytest.mark.parametrize(
("model", "expected_format"),
[
("Salesforce/blip2-opt-2.7b", "string"),
("facebook/chameleon-7b", "string"),
("deepseek-ai/deepseek-vl2-tiny", "string"),
("adept/fuyu-8b", "string"),
("google/paligemma-3b-mix-224", "string"),
("Qwen/Qwen-VL", "string"),
("Qwen/Qwen-VL-Chat", "string"),
],
)
def test_resolve_content_format_fallbacks(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
tokenizer = get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
model_config=model_config,
)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
model_config=model_config,
)
assert resolved_format == expected_format
@pytest.mark.parametrize(
("template_path", "expected_format"),
[
("template_alpaca.jinja", "string"),
("template_baichuan.jinja", "string"),
("template_chatglm.jinja", "string"),
("template_chatglm2.jinja", "string"),
("template_chatml.jinja", "string"),
("template_dse_qwen2_vl.jinja", "openai"),
("template_falcon_180b.jinja", "string"),
("template_falcon.jinja", "string"),
("template_inkbot.jinja", "string"),
("template_teleflm.jinja", "string"),
("template_vlm2vec_phi3v.jinja", "openai"),
("template_vlm2vec_qwen2vl.jinja", "openai"),
("tool_chat_template_granite_20b_fc.jinja", "string"),
("tool_chat_template_hermes.jinja", "string"),
("tool_chat_template_internlm2_tool.jinja", "string"),
("tool_chat_template_llama3.1_json.jinja", "openai"),
("tool_chat_template_llama3.2_json.jinja", "openai"),
("tool_chat_template_mistral_parallel.jinja", "string"),
("tool_chat_template_mistral.jinja", "string"),
],
)
def test_resolve_content_format_examples(template_path, expected_format):
model_config = ModelConfig(
PHI3V_MODEL_ID, # Dummy
tokenizer=PHI3V_MODEL_ID, # Dummy
trust_remote_code=True,
)
dummy_tokenizer = get_tokenizer(
PHI3V_MODEL_ID, # Dummy
trust_remote_code=model_config.trust_remote_code,
)
dummy_tokenizer.chat_template = None
chat_template = load_chat_template(EXAMPLES_DIR / template_path)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
chat_template,
None,
"auto",
dummy_tokenizer,
model_config=model_config,
)
assert resolved_format == expected_format
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config):
messages = [
{
......@@ -2462,56 +2068,6 @@ def test_parse_chat_messages_include_thinking_chunk(mistral_model_config):
assert conversation_with_thinking == expected_conversation
def test_apply_mistral_chat_template_thinking_chunk():
messages = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."},
{
"type": "thinking",
"closed": True,
"thinking": "Only return the answer when you are confident.",
},
],
},
{"role": "user", "content": "What is 2+2?"},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think about it."},
{"type": "thinking", "closed": True, "thinking": "2+2 = 4"},
{
"type": "text",
"text": "The answer is 4.",
},
],
},
{"role": "user", "content": "Thanks, what is 3+3?"},
]
mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Magistral-Small-2509"
)
tokens_ids = apply_mistral_chat_template(
mistral_tokenizer, messages, chat_template=None, tools=None
)
string_tokens = mistral_tokenizer.mistral.decode(
tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
expected_tokens = (
r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
r"[INST]What is 2+2?[/INST]"
r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
r"[INST]Thanks, what is 3+3?[/INST]"
)
assert string_tokens == expected_tokens
def test_parse_chat_messages_single_empty_audio_with_uuid(
qwen2_audio_model_config,
):
......@@ -2550,7 +2106,7 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
qwen2_audio_model_config,
):
audio_uuid = "abcd"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
[
{
"role": "user",
......@@ -2575,5 +2131,5 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
"audio say?",
}
]
_assert_mm_data_inputs(await mm_future, {"audio": 1})
_assert_mm_data_inputs(mm_data, {"audio": 1})
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[audio_uuid])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.renderers.hf import (
_get_hf_base_chat_template_params,
_try_extract_ast,
resolve_chat_template,
resolve_chat_template_content_format,
resolve_chat_template_kwargs,
safe_apply_chat_template,
)
from vllm.tokenizers import get_tokenizer
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
EXAMPLES_DIR = VLLM_PATH / "examples"
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATION_OUTPUT = [
(
"facebook/opt-125m",
chatml_jinja_path,
True,
False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
""",
),
(
"facebook/opt-125m",
chatml_jinja_path,
False,
False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""",
),
(
"facebook/opt-125m",
chatml_jinja_path,
False,
True,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of""",
),
]
TEST_MESSAGES = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "What is the capital of"},
]
ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"}
def test_load_chat_template():
# Testing chatml template
template_content = load_chat_template(chat_template=chatml_jinja_path)
# Test assertions
assert template_content is not None
# Hard coded value for template_chatml.jinja
assert (
template_content
== """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
)
def test_no_load_chat_template_filelike():
# Testing chatml template
template = "../../examples/does_not_exist"
with pytest.raises(ValueError, match="looks like a file path"):
load_chat_template(chat_template=template)
def test_no_load_chat_template_literallike():
# Testing chatml template
template = "{{ messages }}"
template_content = load_chat_template(chat_template=template)
assert template_content == template
@pytest.mark.parametrize(
"model",
[
"Qwen/Qwen2-VL-2B-Instruct", # chat_template is of type str
"NousResearch/Hermes-3-Llama-3.1-8B", # chat_template is of type dict
],
)
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
tools = (
[
{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}
]
if use_tools
else None
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
assert isinstance(chat_template, str)
@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
"Qwen/Qwen2-VL-2B-Instruct",
{
"add_vision_id",
"add_generation_prompt",
"continue_final_message",
"tools",
},
),
(
"Qwen/Qwen3-8B",
{
"enable_thinking",
"add_generation_prompt",
"continue_final_message",
"tools",
},
),
],
)
def test_resolve_chat_template_kwargs(sample_json_schema, model, expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
tools = [
{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}
]
chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
"tokenize": True,
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
with pytest.raises(
ValueError, match="Found unexpected chat template kwargs from request"
):
# should raise error if `chat_template_kwargs` contains
# `chat_template` or `tokenize`
resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
raise_on_unexpected=False,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
# Additional test: Verify HF base parameters work with **kwargs tokenizers
# This validates the fix for tokenizers like Kimi K2 that use **kwargs
# to receive standard HuggingFace parameters instead of declaring them explicitly
hf_base_params = _get_hf_base_chat_template_params()
# Verify common HF parameters are in the base class
assert {"add_generation_prompt", "tools", "continue_final_message"}.issubset(
hf_base_params
), f"Expected HF base params not found in {hf_base_params}"
# Test with a mock tokenizer that uses **kwargs (like Kimi K2)
class MockTokenizerWithKwargs:
def apply_chat_template(self, conversation, **kwargs):
return "mocked_output"
mock_tokenizer = MockTokenizerWithKwargs()
mock_kwargs = {
"add_generation_prompt": True,
"tools": tools,
"continue_final_message": False,
"unknown_param": "should_be_filtered",
}
resolved_mock = resolve_chat_template_kwargs(
mock_tokenizer, chat_template, mock_kwargs, raise_on_unexpected=False
)
# HF base params should pass through even with **kwargs tokenizer
assert "add_generation_prompt" in resolved_mock
assert "tools" in resolved_mock
assert "continue_final_message" in resolved_mock
# Unknown params should be filtered out
assert "unknown_param" not in resolved_mock
# NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json`
@pytest.mark.parametrize(
("model", "expected_format"),
[
("microsoft/Phi-3.5-vision-instruct", "string"),
("Qwen/Qwen2-VL-2B-Instruct", "openai"),
("Qwen/Qwen2.5-VL-3B-Instruct", "openai"),
("fixie-ai/ultravox-v0_5-llama-3_2-1b", "string"),
("Qwen/Qwen2-Audio-7B-Instruct", "openai"),
("meta-llama/Llama-Guard-3-1B", "openai"),
],
)
def test_resolve_content_format_hf_defined(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_chat_template(
tokenizer,
chat_template=None,
tools=None,
model_config=model_config,
)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
model_config=model_config,
)
assert resolved_format == expected_format
@pytest.mark.parametrize(
("model", "expected_format"),
[
("Salesforce/blip2-opt-2.7b", "string"),
("facebook/chameleon-7b", "string"),
("deepseek-ai/deepseek-vl2-tiny", "string"),
("adept/fuyu-8b", "string"),
("google/paligemma-3b-mix-224", "string"),
("Qwen/Qwen-VL", "string"),
("Qwen/Qwen-VL-Chat", "string"),
],
)
def test_resolve_content_format_fallbacks(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
tokenizer = get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_chat_template(
tokenizer,
chat_template=None,
tools=None,
model_config=model_config,
)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
model_config=model_config,
)
assert resolved_format == expected_format
@pytest.mark.parametrize(
("template_path", "expected_format"),
[
("template_alpaca.jinja", "string"),
("template_baichuan.jinja", "string"),
("template_chatglm.jinja", "string"),
("template_chatglm2.jinja", "string"),
("template_chatml.jinja", "string"),
("template_dse_qwen2_vl.jinja", "openai"),
("template_falcon_180b.jinja", "string"),
("template_falcon.jinja", "string"),
("template_inkbot.jinja", "string"),
("template_teleflm.jinja", "string"),
("template_vlm2vec_phi3v.jinja", "openai"),
("template_vlm2vec_qwen2vl.jinja", "openai"),
("tool_chat_template_granite_20b_fc.jinja", "string"),
("tool_chat_template_hermes.jinja", "string"),
("tool_chat_template_internlm2_tool.jinja", "string"),
("tool_chat_template_llama3.1_json.jinja", "openai"),
("tool_chat_template_llama3.2_json.jinja", "openai"),
("tool_chat_template_mistral_parallel.jinja", "string"),
("tool_chat_template_mistral.jinja", "string"),
],
)
def test_resolve_content_format_examples(template_path, expected_format):
model = "Qwen/Qwen2-VL-2B-Instruct" # Dummy
model_config = ModelConfig(
model,
tokenizer=model,
trust_remote_code=True,
)
dummy_tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
dummy_tokenizer.chat_template = None
chat_template = load_chat_template(EXAMPLES_DIR / template_path)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
chat_template,
None,
"auto",
dummy_tokenizer,
model_config=model_config,
)
assert resolved_format == expected_format
@pytest.mark.parametrize(
"model,template,add_generation_prompt,continue_final_message,expected_output",
MODEL_TEMPLATE_GENERATION_OUTPUT,
)
def test_get_gen_prompt(
model, template, add_generation_prompt, continue_final_message, expected_output
):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
revision=model_info.revision,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.require_embed_inputs,
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
)
# Initialize the tokenizer
tokenizer = get_tokenizer(
tokenizer_name=model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
template_content = load_chat_template(chat_template=template)
# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
model=model,
messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
if continue_final_message
else TEST_MESSAGES,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
)
# Call the function and get the result
result = safe_apply_chat_template(
model_config,
tokenizer,
mock_request.messages,
tools=None,
chat_template=mock_request.chat_template or template_content,
add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message,
tokenize=False,
)
# Test assertion
assert result == expected_output, (
f"The generated prompt does not match the expected output for "
f"model {model} and template {template}"
)
......@@ -6,38 +6,15 @@ import time
from unittest.mock import Mock
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer
@pytest.fixture()
def serving() -> OpenAIServing:
"""Create a minimal OpenAIServing instance for testing."""
# Create minimal mocks
engine_client = Mock()
model_config = Mock(spec=ModelConfig)
model_config.max_model_len = 32768
models = Mock(spec=OpenAIServingModels)
models.model_config = model_config
models.input_processor = Mock()
models.io_processor = Mock()
serving = OpenAIServing(
engine_client=engine_client,
models=models,
request_logger=None,
)
return serving
@pytest.mark.asyncio
async def test_async_mistral_tokenizer_does_not_block_event_loop(
serving: OpenAIServing,
):
async def test_async_mistral_tokenizer_does_not_block_event_loop():
expected_tokens = [1, 2, 3]
# Mock the blocking version to sleep
......@@ -46,11 +23,11 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(
return expected_tokens
mock_tokenizer = Mock(spec=MistralTokenizer)
mock_tokenizer.apply_chat_template.side_effect = mocked_apply_chat_template
mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer
task = serving._apply_mistral_chat_template_async(
tokenizer=mock_tokenizer, messages=[], chat_template=None, tools=[]
)
task = mock_renderer.render_messages_async([])
# Ensure the event loop is not blocked
blocked_count = 0
......@@ -66,6 +43,58 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(
await asyncio.sleep(0.1)
# Ensure task completes
tokens = await task
assert tokens == expected_tokens, "Mocked blocking tokenizer was not called"
_, prompt = await task
assert prompt["prompt_token_ids"] == expected_tokens, (
"Mocked blocking tokenizer was not called"
)
assert blocked_count == 0, "Event loop blocked during tokenization"
def test_apply_mistral_chat_template_thinking_chunk():
messages = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."},
{
"type": "thinking",
"closed": True,
"thinking": "Only return the answer when you are confident.",
},
],
},
{"role": "user", "content": "What is 2+2?"},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me think about it."},
{"type": "thinking", "closed": True, "thinking": "2+2 = 4"},
{
"type": "text",
"text": "The answer is 4.",
},
],
},
{"role": "user", "content": "Thanks, what is 3+3?"},
]
mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Magistral-Small-2509"
)
tokens_ids = safe_apply_chat_template(
mistral_tokenizer, messages, chat_template=None, tools=None
)
string_tokens = mistral_tokenizer.mistral.decode(
tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP
)
expected_tokens = (
r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
r"[INST]What is 2+2?[/INST]"
r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
r"[INST]Thanks, what is 3+3?[/INST]"
)
assert string_tokens == expected_tokens
......@@ -7,7 +7,6 @@ from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_raw_prompts
from vllm.inputs.preprocess import InputPreprocessor
from vllm.tokenizers import cached_tokenizer_from_config
pytestmark = pytest.mark.cpu_test
......@@ -115,10 +114,10 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
)
def test_preprocessor_always_mm_code_path(model_id, prompt):
model_config = ModelConfig(model=model_id)
tokenizer = cached_tokenizer_from_config(model_config)
input_preprocessor = InputPreprocessor(model_config, tokenizer)
input_preprocessor = InputPreprocessor(model_config)
# HF processor adds sep token
tokenizer = input_preprocessor.get_tokenizer()
sep_token_id = tokenizer.vocab[tokenizer.sep_token]
processed_inputs = input_preprocessor.preprocess(prompt)
......
......@@ -224,7 +224,7 @@ def test_skip_tokenizer_initialization(model: str):
)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
with pytest.raises(ValueError, match="cannot pass text prompts when"):
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
llm.generate("abc", sampling_params)
outputs = llm.generate(
......
......@@ -5,7 +5,13 @@ import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
MultiModalConfig,
VllmConfig,
)
from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.input_processor import InputProcessor
......@@ -44,27 +50,22 @@ def _mock_input_processor(
monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True)
model_config = ModelConfig(
tokenizer="dummy",
skip_tokenizer_init=True,
max_model_len=128,
mm_processor_cache_gb=mm_cache_gb,
generation_config="vllm",
tokenizer="dummy",
)
model_config.runner_type = "generate"
model_config.multimodal_config = MultiModalConfig(mm_processor_cache_gb=mm_cache_gb)
# Minimal multimodal_config to satisfy references in
# Processor.process_inputs.
class _MockMMConfig:
def __init__(self, gb: float):
self.mm_processor_cache_gb = gb
model_config.multimodal_config = _MockMMConfig(mm_cache_gb) # type: ignore[attr-defined]
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
device_config=DeviceConfig(device="cpu"),
)
return InputProcessor(vllm_config, tokenizer=None)
return InputProcessor(vllm_config)
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
......
......@@ -35,6 +35,7 @@ FILES = [
"vllm/multimodal",
"vllm/platforms",
"vllm/plugins",
"vllm/renderers",
"vllm/tokenizers",
"vllm/transformers_utils",
"vllm/triton_utils",
......
......@@ -11,9 +11,9 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
......@@ -26,6 +26,10 @@ class EngineClient(ABC):
input_processor: InputProcessor
io_processor: IOProcessor | None
@property
@abstractmethod
def renderer(self) -> RendererLike: ...
@property
@abstractmethod
def is_running(self) -> bool: ...
......@@ -88,11 +92,6 @@ class EngineClient(ABC):
"""
...
@abstractmethod
async def get_tokenizer(self) -> TokenizerLike:
"""Get the tokenizer"""
...
@abstractmethod
async def is_tracing_enabled(self) -> bool: ...
......
......@@ -2,22 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam,
......@@ -39,7 +32,6 @@ from openai.types.responses import ResponseInputImageParam
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypedDict
......@@ -50,24 +42,35 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import random_uuid
from vllm.utils.collection_utils import is_list_of
from vllm.utils.func_utils import supports_kw
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import torch
from vllm.tokenizers.mistral import MistralTokenizer
else:
torch = LazyLoader("torch", globals(), "torch")
logger = init_logger(__name__)
def __getattr__(name: str):
if name == "resolve_hf_chat_template":
from vllm.renderers.hf import resolve_chat_template
warnings.warn(
"`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
"`vllm.renderers.hf.resolve_chat_template`. "
"The old name will be removed in v0.16.",
DeprecationWarning,
stacklevel=2,
)
return resolve_chat_template
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
class ChatTemplateResolutionError(ValueError):
"""Raised when chat template resolution fails.
......@@ -320,325 +323,8 @@ class ConversationMessage(TypedDict, total=False):
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
# Used internally
_ChatTemplateContentFormat = Literal["string", "openai"]
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: str | None = None,
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
# Global variable that is implicitly defined at the root
yield root, varname
# Iterative BFS
related_varnames = deque([varname])
while related_varnames:
related_varname = related_varnames.popleft()
for assign_ast in root.find_all(jinja2.nodes.Assign):
lhs = assign_ast.target
rhs = assign_ast.node
if _is_var_or_elems_access(rhs, related_varname):
assert isinstance(lhs, jinja2.nodes.Name)
yield assign_ast, lhs.name
# Avoid infinite looping for self-assignment
if lhs.name != related_varname:
related_varnames.append(lhs.name)
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
messages_varnames = [
varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
]
# Search for {%- for message in messages -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in messages_varnames:
if _is_var_or_elems_access(loop_iter, varname):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
message_varnames = [
varname for _, varname in _iter_nodes_assign_messages_item(root)
]
# Search for {%- for content in message['content'] -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in message_varnames:
if _is_var_or_elems_access(loop_iter, varname, "content"):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception:
logger.exception("Error when compiling Jinja template")
return None
@lru_cache(maxsize=32)
def _detect_content_format(
chat_template: str,
*,
default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return default
try:
next(_iter_nodes_assign_content_item(jinja_ast))
except StopIteration:
return "string"
except Exception:
logger.exception("Error when parsing AST of Jinja template")
return default
else:
return "openai"
def resolve_mistral_chat_template(
chat_template: str | None,
**kwargs: Any,
) -> str | None:
if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
raise ValueError(
"'chat_template' or 'chat_template_kwargs' cannot be overridden "
"for mistral tokenizer."
)
return None
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.
This is needed because `lru_cache` does not cache when an exception happens.
"""
def _try_get_processor_chat_template(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
model_config: ModelConfig,
) -> str | None:
cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
if cache_key in _PROCESSOR_CHAT_TEMPLATES:
return _PROCESSOR_CHAT_TEMPLATES[cache_key]
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
),
trust_remote_code=model_config.trust_remote_code,
)
if (
isinstance(processor, ProcessorMixin)
and hasattr(processor, "chat_template")
and (chat_template := processor.chat_template) is not None
):
_PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
return chat_template
except Exception:
logger.debug(
"Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
_PROCESSOR_CHAT_TEMPLATES[cache_key] = None
return None
def resolve_hf_chat_template(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
chat_template: str | None,
tools: list[dict[str, Any]] | None,
*,
model_config: ModelConfig,
) -> str | None:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
chat_template = _try_get_processor_chat_template(tokenizer, model_config)
if chat_template is not None:
return chat_template
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug(
"Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=model_config.tokenizer,
)
if path is not None:
logger.info_once(
"Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.",
tokenizer.name_or_path,
)
chat_template = load_chat_template(path)
else:
logger.debug_once(
"There is no chat template fallback for %s", tokenizer.name_or_path
)
return chat_template
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: TokenizerLike | None,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
else:
hf_chat_template = None
jinja_text = (
hf_chat_template
if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True)
)
detected_format = (
"string"
if jinja_text is None
else _detect_content_format(jinja_text, default="string")
)
return detected_format
@lru_cache
def _log_chat_template_content_format(
chat_template: str | None,
given_format: ChatTemplateContentFormatOption,
detected_format: ChatTemplateContentFormatOption,
):
logger.info(
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.",
detected_format,
)
if given_format != "auto" and given_format != detected_format:
logger.warning(
"You specified `--chat-template-content-format %s` "
"which is different from the detected format '%s'. "
"If our automatic detection is incorrect, please consider "
"opening a GitHub issue so that we can improve it: "
"https://github.com/vllm-project/vllm/issues/new/choose",
given_format,
detected_format,
)
def resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: TokenizerLike | None,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
if given_format != "auto":
return given_format
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
tokenizer,
model_config=model_config,
)
_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)
return detected_format
# After resolving "auto"
ChatTemplateContentFormat = Literal["string", "openai"]
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
......@@ -1593,7 +1279,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
content_format: _ChatTemplateContentFormat,
content_format: ChatTemplateContentFormat,
interleave_strings: bool,
) -> list[ConversationMessage]:
role = message["role"]
......@@ -1669,7 +1355,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
content_format: _ChatTemplateContentFormat,
content_format: ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
MultiModalDataDict | None,
......@@ -1697,13 +1383,13 @@ def parse_chat_messages(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
def parse_chat_messages_futures(
async def parse_chat_messages_async(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
content_format: _ChatTemplateContentFormat,
content_format: ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
Awaitable[MultiModalDataDict | None],
MultiModalDataDict | None,
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
......@@ -1725,174 +1411,7 @@ def parse_chat_messages_futures(
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def _resolve_chat_template_kwargs(
chat_template: str,
):
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
return template_vars
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
# Get standard parameters from HuggingFace's base tokenizer class.
# This dynamically extracts parameters from PreTrainedTokenizer's
# apply_chat_template method, ensuring compatibility with tokenizers
# that use **kwargs to receive standard parameters.
# Read signature from HF's base class - the single source of truth
base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)
# Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
return frozenset(
p.name
for p in base_sig.parameters.values()
if p.kind
not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
)
def resolve_chat_template_kwargs(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
chat_template: str,
chat_template_kwargs: dict[str, Any],
raise_on_unexpected: bool = True,
) -> dict[str, Any]:
# We exclude chat_template from kwargs here, because
# chat template has been already resolved at this stage
unexpected_vars = {"chat_template", "tokenize"}
if raise_on_unexpected and (
unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
):
raise ValueError(
"Found unexpected chat template kwargs from request: "
f"{unexpected_in_kwargs}"
)
fn_kw = {
k
for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
template_vars = _cached_resolve_chat_template_kwargs(chat_template)
# Allow standard HF parameters even if tokenizer uses **kwargs to receive them
hf_base_params = _get_hf_base_chat_template_params()
accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
def apply_hf_chat_template(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
conversation: list[ConversationMessage],
chat_template: str | None,
tools: list[dict[str, Any]] | None,
*,
model_config: ModelConfig,
**kwargs: Any,
) -> str:
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
if hf_chat_template is None:
raise ChatTemplateResolutionError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
try:
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=False,
**resolved_kwargs,
)
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template"
)
raise ValueError(str(e)) from e
def apply_mistral_chat_template(
tokenizer: "MistralTokenizer",
messages: list[ChatCompletionMessageParam],
chat_template: str | None,
tools: list[dict[str, Any]] | None,
**kwargs: Any,
) -> list[int]:
from mistral_common.exceptions import MistralCommonException
# The return value of resolve_mistral_chat_template is always None,
# and we won't use it.
resolve_mistral_chat_template(
chat_template=chat_template,
**kwargs,
)
try:
return tokenizer.apply_chat_template(
messages=messages,
tools=tools,
**kwargs,
)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# properly caught in the preprocessing_input step
except (AssertionError, MistralCommonException) as e:
raise ValueError(str(e)) from e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat template"
)
raise ValueError(str(e)) from e
return conversation, await mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
......
......@@ -37,10 +37,6 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format,
)
from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam,
......@@ -786,7 +782,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[TokensPrompt]:
) -> list[TextPrompt | TokensPrompt]:
"""
Generate prompt for a chat conversation. The pre-processed
prompt can then be used as input for the other LLM methods.
......@@ -807,63 +803,27 @@ class LLM:
# messages is list[...]
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
tokenizer = self.get_tokenizer()
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
chat_template,
tools,
chat_template_content_format,
tokenizer,
model_config=model_config,
)
renderer = self.llm_engine.renderer
_chat_template_kwargs: dict[str, Any] = dict(
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
_chat_template_kwargs.update(chat_template_kwargs or {})
chat_template_kwargs = {
"chat_template": chat_template,
"add_generation_prompt": add_generation_prompt,
"continue_final_message": continue_final_message,
"tools": tools,
**(chat_template_kwargs or {}),
}
prompts: list[TokensPrompt] = []
prompts = list[TextPrompt | TokensPrompt]()
for msgs in list_of_messages:
# NOTE: _parse_chat_message_content_parts() currently doesn't
# NOTE: renderer.render_messages() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation, mm_data, mm_uuids = parse_chat_messages(
_, prompt = renderer.render_messages(
msgs,
model_config,
content_format=resolved_content_format,
chat_template_content_format=chat_template_content_format,
**chat_template_kwargs,
)
if isinstance(tokenizer, MistralTokenizer):
prompt_token_ids = apply_mistral_chat_template(
tokenizer,
messages=msgs,
**_chat_template_kwargs,
)
else:
prompt_str = apply_hf_chat_template(
tokenizer=tokenizer,
conversation=conversation,
model_config=model_config,
**_chat_template_kwargs,
)
# Special tokens are already included in chat templates so
# should not be added by the tokenizer in this case.
prompt_token_ids = tokenizer.encode(
prompt_str, add_special_tokens=False
)
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment