Unverified commit d117a4d1, authored by Cyrus Leung and committed by GitHub
Browse files

[Frontend] Introduce Renderer for processing chat messages (using `ModelConfig`) (#30200)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 421012b6
...@@ -71,6 +71,7 @@ steps: ...@@ -71,6 +71,7 @@ steps:
- tests/test_inputs.py - tests/test_inputs.py
- tests/test_outputs.py - tests/test_outputs.py
- tests/multimodal - tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py - tests/standalone_tests/lazy_imports.py
- tests/tokenizers_ - tests/tokenizers_
- tests/tool_parsers - tests/tool_parsers
...@@ -82,6 +83,7 @@ steps: ...@@ -82,6 +83,7 @@ steps:
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py - pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal - pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_ - pytest -v -s tokenizers_
- pytest -v -s tool_parsers - pytest -v -s tool_parsers
- pytest -v -s transformers_utils - pytest -v -s transformers_utils
......
...@@ -64,6 +64,7 @@ steps: ...@@ -64,6 +64,7 @@ steps:
- tests/test_inputs.py - tests/test_inputs.py
- tests/test_outputs.py - tests/test_outputs.py
- tests/multimodal - tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py - tests/standalone_tests/lazy_imports.py
- tests/tokenizers_ - tests/tokenizers_
- tests/tool_parsers - tests/tool_parsers
...@@ -75,6 +76,7 @@ steps: ...@@ -75,6 +76,7 @@ steps:
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py - pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal - pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_ - pytest -v -s tokenizers_
- pytest -v -s tool_parsers - pytest -v -s tool_parsers
- pytest -v -s transformers_utils - pytest -v -s transformers_utils
......
...@@ -121,6 +121,7 @@ steps: ...@@ -121,6 +121,7 @@ steps:
- tests/test_inputs.py - tests/test_inputs.py
- tests/test_outputs.py - tests/test_outputs.py
- tests/multimodal - tests/multimodal
- tests/renderers
- tests/standalone_tests/lazy_imports.py - tests/standalone_tests/lazy_imports.py
- tests/tokenizers_ - tests/tokenizers_
- tests/tool_parsers - tests/tool_parsers
...@@ -132,6 +133,7 @@ steps: ...@@ -132,6 +133,7 @@ steps:
- pytest -v -s test_inputs.py - pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py - pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal - pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s renderers
- pytest -v -s tokenizers_ - pytest -v -s tokenizers_
- pytest -v -s tool_parsers - pytest -v -s tool_parsers
- pytest -v -s transformers_utils - pytest -v -s transformers_utils
......
...@@ -254,7 +254,8 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso ...@@ -254,7 +254,8 @@ You can add a new `ReasoningParser` similar to [vllm/reasoning/deepseek_r1_reaso
# import the required packages # import the required packages
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
# define a reasoning parser and register it to vllm # define a reasoning parser and register it to vllm
# the name list in register_module can be used # the name list in register_module can be used
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.tokenizers import get_tokenizer
from ...models.registry import HF_EXAMPLE_MODELS
from ...utils import VLLM_PATH
# Location of the ChatML example template, resolved relative to VLLM_PATH.
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
# Module-level guard: importing this file fails if the template is missing.
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
# Each entry is (model, template, add_generation_prompt,
# continue_final_message, expected_output), in the order consumed by the
# parametrize decorator on test_get_gen_prompt.
MODEL_TEMPLATE_GENERATION_OUTPUT = [
# add_generation_prompt=True: the expected prompt ends with an opened
# assistant turn.
(
"facebook/opt-125m",
chatml_jinja_path,
True,
False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
""",
),
# Both flags False: the expected prompt stops right after the last user
# message, with no closing <|im_end|>.
(
"facebook/opt-125m",
chatml_jinja_path,
False,
False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""",
),
# continue_final_message=True: the trailing assistant message is left open
# (no <|im_end|> after it).
(
"facebook/opt-125m",
chatml_jinja_path,
False,
True,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of""",
),
]
# Base conversation used by the prompt-generation test; it ends on a user turn.
TEST_MESSAGES = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "What is the capital of"},
]
# Partial assistant reply appended to TEST_MESSAGES when
# continue_final_message is set.
ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"}
def test_load_chat_template():
    """Loading the ChatML template by path returns the file's exact text."""
    # Hard coded value for template_chatml.jinja
    expected_template = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501

    loaded = load_chat_template(chat_template=chatml_jinja_path)

    assert loaded is not None
    assert loaded == expected_template
def test_no_load_chat_template_filelike():
    """A path-looking template string is rejected with a ValueError."""
    missing_path = "../../examples/does_not_exist"

    with pytest.raises(ValueError, match="looks like a file path"):
        load_chat_template(chat_template=missing_path)
def test_no_load_chat_template_literallike():
    """A literal Jinja template string is returned unchanged."""
    literal_template = "{{ messages }}"

    loaded = load_chat_template(chat_template=literal_template)

    assert loaded == literal_template
@pytest.mark.parametrize(
    "model,template,add_generation_prompt,continue_final_message,expected_output",
    MODEL_TEMPLATE_GENERATION_OUTPUT,
)
def test_get_gen_prompt(
    model, template, add_generation_prompt, continue_final_message, expected_output
):
    """Render TEST_MESSAGES through the given template and compare the prompt.

    The case is skipped when the example model is not available online.
    When ``continue_final_message`` is set, the partial assistant message is
    appended to the conversation before rendering.
    """
    hf_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    hf_info.check_available_online(on_fail="skip")

    # Mirror the registry entry's settings in the model config.
    config_kwargs = dict(
        tokenizer=hf_info.tokenizer or model,
        tokenizer_mode=hf_info.tokenizer_mode,
        trust_remote_code=hf_info.trust_remote_code,
        revision=hf_info.revision,
        hf_overrides=hf_info.hf_overrides,
        skip_tokenizer_init=hf_info.require_embed_inputs,
        enable_prompt_embeds=hf_info.require_embed_inputs,
        enable_mm_embeds=hf_info.require_embed_inputs,
        enforce_eager=hf_info.enforce_eager,
        dtype=hf_info.dtype,
    )
    model_config = ModelConfig(model, **config_kwargs)

    # Initialize the tokenizer
    tokenizer = get_tokenizer(
        tokenizer_name=model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code,
    )
    template_content = load_chat_template(chat_template=template)

    # Pick the conversation: optionally end on a partial assistant message.
    if continue_final_message:
        conversation = [*TEST_MESSAGES, ASSISTANT_MESSAGE_TO_CONTINUE]
    else:
        conversation = TEST_MESSAGES

    # Build the request object using keyword arguments
    request = ChatCompletionRequest(
        model=model,
        messages=conversation,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
    )

    # A template carried by the request takes precedence over the loaded one.
    effective_template = request.chat_template or template_content
    rendered = apply_hf_chat_template(
        tokenizer=tokenizer,
        conversation=request.messages,
        chat_template=effective_template,
        model_config=model_config,
        tools=None,
        add_generation_prompt=request.add_generation_prompt,
        continue_final_message=request.continue_final_message,
    )

    assert rendered == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}"
    )
...@@ -11,7 +11,7 @@ import pytest_asyncio ...@@ -11,7 +11,7 @@ import pytest_asyncio
from openai import OpenAI from openai import OpenAI
from vllm._aiter_ops import is_aiter_found_and_supported from vllm._aiter_ops import is_aiter_found_and_supported
from vllm.config.multimodal import MultiModalConfig from vllm.config import MultiModalConfig
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
...@@ -23,8 +23,13 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -23,8 +23,13 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.inputs import TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.renderers.hf import HfRenderer
from vllm.renderers.mistral import MistralRenderer
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.tool_parsers import ToolParserManager from vllm.tool_parsers import ToolParserManager
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
...@@ -103,15 +108,16 @@ def gptoss_server(default_server_args: list[str]): ...@@ -103,15 +108,16 @@ def gptoss_server(default_server_args: list[str]):
@pytest.fixture(scope="class") @pytest.fixture(scope="class")
def gptoss_speculative_server(default_server_args: list[str]): def gptoss_speculative_server(default_server_args: list[str]):
attention_backend = (
"TRITON_ATTN"
if not is_aiter_found_and_supported()
else "ROCM_AITER_UNIFIED_ATTN"
)
server_args = default_server_args + [ server_args = default_server_args + [
"--speculative-config", "--speculative-config",
f'{{"model": "{GPT_OSS_SPECULATOR_NAME}", ' f'{{"model": "{GPT_OSS_SPECULATOR_NAME}", '
f'"method": "eagle3", "num_speculative_tokens": 3}}', f'"method": "eagle3", "num_speculative_tokens": 3}}',
f"--attention-backend={ f"--attention-backend={attention_backend}",
'TRITON_ATTN'
if not is_aiter_found_and_supported()
else 'ROCM_AITER_UNIFIED_ATTN'
}",
] ]
# gpt-oss requires AITER unified attention on ROCm # gpt-oss requires AITER unified attention on ROCm
# TODO: Remove after fixing TRITON_ATTN issue on ROCm # TODO: Remove after fixing TRITON_ATTN issue on ROCm
...@@ -520,12 +526,21 @@ class MockModelConfig: ...@@ -520,12 +526,21 @@ class MockModelConfig:
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False skip_tokenizer_init: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
def _build_renderer(model_config: MockModelConfig):
    """Create an ``HfRenderer`` for ``model_config``.

    The tokenizer name and keyword arguments come from
    ``tokenizer_args_from_config``; the name is passed along inside
    ``tokenizer_kwargs`` under the ``"tokenizer_name"`` key.
    """
    # Only the second and fourth items of the 4-tuple are needed here.
    _, resolved_name, _, extra_kwargs = tokenizer_args_from_config(model_config)

    tokenizer_kwargs = {**extra_kwargs}
    tokenizer_kwargs["tokenizer_name"] = resolved_name

    return HfRenderer(model_config, tokenizer_kwargs=tokenizer_kwargs)
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels( models = OpenAIServingModels(
engine_client=engine, engine_client=engine,
...@@ -561,6 +576,7 @@ class MockEngine: ...@@ -561,6 +576,7 @@ class MockEngine:
model_config: MockModelConfig = field(default_factory=MockModelConfig) model_config: MockModelConfig = field(default_factory=MockModelConfig)
input_processor: MagicMock = field(default_factory=MagicMock) input_processor: MagicMock = field(default_factory=MagicMock)
io_processor: MagicMock = field(default_factory=MagicMock) io_processor: MagicMock = field(default_factory=MagicMock)
renderer: MagicMock = field(default_factory=MagicMock)
async def _async_serving_chat_init(): async def _async_serving_chat_init():
...@@ -586,11 +602,11 @@ def test_async_serving_chat_init(): ...@@ -586,11 +602,11 @@ def test_async_serving_chat_init():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_serving_chat_returns_correct_model_name(): async def test_serving_chat_returns_correct_model_name():
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
messages = [{"role": "user", "content": "what is 1+1?"}] messages = [{"role": "user", "content": "what is 1+1?"}]
...@@ -616,11 +632,11 @@ async def test_serving_chat_returns_correct_model_name(): ...@@ -616,11 +632,11 @@ async def test_serving_chat_returns_correct_model_name():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_serving_chat_should_set_correct_max_tokens(): async def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -649,11 +665,11 @@ async def test_serving_chat_should_set_correct_max_tokens(): ...@@ -649,11 +665,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
# Reinitialize the engine with new settings # Reinitialize the engine with new settings
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat # Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -694,11 +710,11 @@ async def test_serving_chat_should_set_correct_max_tokens(): ...@@ -694,11 +710,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
# Reinitialize the engine with new settings # Reinitialize the engine with new settings
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat # Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -732,42 +748,32 @@ async def test_serving_chat_should_set_correct_max_tokens(): ...@@ -732,42 +748,32 @@ async def test_serving_chat_should_set_correct_max_tokens():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_module): async def test_serving_chat_mistral_token_ids_prompt_is_validated():
"""Regression test: when the Mistral tokenizer path returns token IDs """Regression test: when the Mistral tokenizer path returns token IDs
directly, we must still apply input length + max_tokens validation. directly, we must still apply input length + max_tokens validation.
""" """
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
class DummyMistralTokenizer: mock_tokenizer = MagicMock(spec=MistralTokenizer)
def decode(self, token_ids): mock_renderer = MistralRenderer(mock_engine.model_config, tokenizer_kwargs={})
# Only used for logging/validation error messages. mock_renderer._tokenizer = mock_tokenizer
return "dummy"
dummy_tokenizer = DummyMistralTokenizer()
mock_engine.get_tokenizer.return_value = dummy_tokenizer
# Patch the OpenAI engine serving module to treat our dummy tokenizer
# as a MistralTokenizer. This forces the code path where chat template
# rendering can return a list[int] (token IDs).
import vllm.entrypoints.openai.engine.serving as engine_serving
monkeypatch_module.setattr(
engine_serving, "MistralTokenizer", DummyMistralTokenizer
)
serving_chat = _build_serving_chat(mock_engine)
# Force the Mistral chat template renderer to return token IDs. # Force the Mistral chat template renderer to return token IDs.
# Choose a prompt length that is < max_model_len, but large enough that # Choose a prompt length that is < max_model_len, but large enough that
# adding max_tokens should exceed the model context window. # adding max_tokens should exceed the model context window.
serving_chat._apply_mistral_chat_template_async = AsyncMock( mock_renderer.render_messages_async = AsyncMock(
return_value=list(range(95)) return_value=(
[],
TokensPrompt(prompt_token_ids=list(range(95))),
)
) )
mock_engine.renderer = mock_renderer
serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest( req = ChatCompletionRequest(
model=MODEL_NAME, model=MODEL_NAME,
...@@ -781,39 +787,33 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_mo ...@@ -781,39 +787,33 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_mo
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected( async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
monkeypatch_module,
):
"""Regression test: MistralTokenizer token-id prompts must still enforce """Regression test: MistralTokenizer token-id prompts must still enforce
the max context length for the input itself (token_num >= max_model_len). the max context length for the input itself (token_num >= max_model_len).
""" """
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
class DummyMistralTokenizer: mock_tokenizer = MagicMock(spec=MistralTokenizer)
def decode(self, token_ids): mock_renderer = MistralRenderer(mock_engine.model_config, tokenizer_kwargs={})
return "dummy" mock_renderer._tokenizer = mock_tokenizer
dummy_tokenizer = DummyMistralTokenizer()
mock_engine.get_tokenizer.return_value = dummy_tokenizer
import vllm.entrypoints.openai.engine.serving as engine_serving
monkeypatch_module.setattr(
engine_serving, "MistralTokenizer", DummyMistralTokenizer
)
serving_chat = _build_serving_chat(mock_engine)
# prompt_token_ids length == max_model_len should be rejected for # prompt_token_ids length == max_model_len should be rejected for
# completion-like requests (ChatCompletionRequest). # completion-like requests (ChatCompletionRequest).
serving_chat._apply_mistral_chat_template_async = AsyncMock( mock_renderer.render_messages_async = AsyncMock(
return_value=list(range(mock_engine.model_config.max_model_len)) return_value=(
[],
TokensPrompt(
prompt_token_ids=list(range(mock_engine.model_config.max_model_len))
),
)
) )
mock_engine.renderer = mock_renderer
serving_chat = _build_serving_chat(mock_engine)
req = ChatCompletionRequest( req = ChatCompletionRequest(
model=MODEL_NAME, model=MODEL_NAME,
...@@ -835,11 +835,11 @@ async def test_serving_chat_could_load_correct_generation_config(): ...@@ -835,11 +835,11 @@ async def test_serving_chat_could_load_correct_generation_config():
} }
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Initialize the serving chat # Initialize the serving chat
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -881,11 +881,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): ...@@ -881,11 +881,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
mock_model_config.hf_config.model_type = model_type mock_model_config.hf_config.model_type = model_type
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = mock_model_config mock_engine.model_config = mock_model_config
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine) serving_chat = _build_serving_chat(mock_engine)
...@@ -914,11 +914,11 @@ async def test_serving_chat_data_parallel_rank_extraction(): ...@@ -914,11 +914,11 @@ async def test_serving_chat_data_parallel_rank_extraction():
"""Test that data_parallel_rank is properly extracted from header and """Test that data_parallel_rank is properly extracted from header and
passed to engine.""" passed to engine."""
mock_engine = MagicMock(spec=AsyncLLM) mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False mock_engine.errored = False
mock_engine.model_config = MockModelConfig() mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock() mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock() mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
# Mock the generate method to return an async generator # Mock the generate method to return an async generator
async def mock_generate(*args, **kwargs): async def mock_generate(*args, **kwargs):
......
...@@ -35,6 +35,7 @@ async def _async_serving_models_init() -> OpenAIServingModels: ...@@ -35,6 +35,7 @@ async def _async_serving_models_init() -> OpenAIServingModels:
mock_engine_client.model_config = mock_model_config mock_engine_client.model_config = mock_model_config
mock_engine_client.input_processor = MagicMock() mock_engine_client.input_processor = MagicMock()
mock_engine_client.io_processor = MagicMock() mock_engine_client.io_processor = MagicMock()
mock_engine_client.renderer = MagicMock()
serving_models = OpenAIServingModels( serving_models = OpenAIServingModels(
engine_client=mock_engine_client, engine_client=mock_engine_client,
......
...@@ -131,6 +131,7 @@ class TestInitializeToolSessions: ...@@ -131,6 +131,7 @@ class TestInitializeToolSessions:
engine_client.input_processor = MagicMock() engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock() engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock() models = MagicMock()
...@@ -217,6 +218,7 @@ class TestValidateGeneratorInput: ...@@ -217,6 +218,7 @@ class TestValidateGeneratorInput:
engine_client.input_processor = MagicMock() engine_client.input_processor = MagicMock()
engine_client.io_processor = MagicMock() engine_client.io_processor = MagicMock()
engine_client.renderer = MagicMock()
models = MagicMock() models = MagicMock()
......
...@@ -212,7 +212,7 @@ class TestGetScorePrompt: ...@@ -212,7 +212,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template, return_value=mock_model_no_score_template,
), ),
patch( patch(
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template", "vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
return_value="test querytest doc", return_value="test querytest doc",
), ),
): ):
...@@ -245,7 +245,7 @@ class TestGetScorePrompt: ...@@ -245,7 +245,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template, return_value=mock_model_no_score_template,
), ),
patch( patch(
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template", "vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
side_effect=ChatTemplateResolutionError("No template"), side_effect=ChatTemplateResolutionError("No template"),
), ),
): ):
...@@ -296,7 +296,7 @@ class TestGetScorePrompt: ...@@ -296,7 +296,7 @@ class TestGetScorePrompt:
return_value=mock_model_no_score_template, return_value=mock_model_no_score_template,
), ),
patch( patch(
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template", "vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
side_effect=ChatTemplateResolutionError("No template"), side_effect=ChatTemplateResolutionError("No template"),
), ),
): ):
...@@ -331,7 +331,7 @@ class TestGetScorePrompt: ...@@ -331,7 +331,7 @@ class TestGetScorePrompt:
return_value=mock_model_with_score_template, return_value=mock_model_with_score_template,
), ),
patch( patch(
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template", "vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
side_effect=ChatTemplateResolutionError("No template"), side_effect=ChatTemplateResolutionError("No template"),
), ),
): ):
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.renderers.hf import (
_get_hf_base_chat_template_params,
_try_extract_ast,
resolve_chat_template,
resolve_chat_template_content_format,
resolve_chat_template_kwargs,
safe_apply_chat_template,
)
from vllm.tokenizers import get_tokenizer
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
# Directory holding the repository's example files, relative to VLLM_PATH.
EXAMPLES_DIR = VLLM_PATH / "examples"
# Location of the ChatML example template.
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
# Module-level guard: importing this file fails if the template is missing.
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
# NOTE(review): entries look like (model, template, add_generation_prompt,
# continue_final_message, expected_output) - confirm against the test that
# parametrizes over this list.
MODEL_TEMPLATE_GENERATION_OUTPUT = [
# Expected prompt ends with an opened assistant turn.
(
"facebook/opt-125m",
chatml_jinja_path,
True,
False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
""",
),
# Expected prompt stops right after the last user message, with no closing
# <|im_end|>.
(
"facebook/opt-125m",
chatml_jinja_path,
False,
False,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""",
),
# Expected prompt ends with an unterminated assistant message.
(
"facebook/opt-125m",
chatml_jinja_path,
False,
True,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of""",
),
]
# Base conversation; it ends on a user turn.
TEST_MESSAGES = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "What is the capital of"},
]
# Partial assistant reply; presumably appended to TEST_MESSAGES for the
# continue_final_message cases - verify against the consuming test.
ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"}
def test_load_chat_template():
    """Loading the ChatML template by path returns the file's exact text."""
    # Hard coded value for template_chatml.jinja
    expected_template = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501

    loaded = load_chat_template(chat_template=chatml_jinja_path)

    assert loaded is not None
    assert loaded == expected_template
def test_no_load_chat_template_filelike():
    """A path-looking template string is rejected with a ValueError."""
    missing_path = "../../examples/does_not_exist"

    with pytest.raises(ValueError, match="looks like a file path"):
        load_chat_template(chat_template=missing_path)
def test_no_load_chat_template_literallike():
    """A literal Jinja template string is returned unchanged."""
    literal_template = "{{ messages }}"

    loaded = load_chat_template(chat_template=literal_template)

    assert loaded == literal_template
@pytest.mark.parametrize(
    "model",
    [
        "Qwen/Qwen2-VL-2B-Instruct",  # chat_template is of type str
        "NousResearch/Hermes-3-Llama-3.1-8B",  # chat_template is of type dict
    ],
)
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_chat_template(sample_json_schema, model, use_tools):
    """Resolving the tokenizer's own chat template yields a ``str``.

    Runs with and without a dummy tool definition; the case is skipped when
    the example model is not available online.
    """
    hf_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    hf_info.check_available_online(on_fail="skip")

    # Mirror the registry entry's settings in the model config.
    config_kwargs = dict(
        tokenizer=hf_info.tokenizer or model,
        tokenizer_mode=hf_info.tokenizer_mode,
        revision=hf_info.revision,
        trust_remote_code=hf_info.trust_remote_code,
        hf_overrides=hf_info.hf_overrides,
        skip_tokenizer_init=hf_info.require_embed_inputs,
        enable_prompt_embeds=hf_info.require_embed_inputs,
        enable_mm_embeds=hf_info.require_embed_inputs,
        enforce_eager=hf_info.enforce_eager,
        dtype=hf_info.dtype,
    )
    model_config = ModelConfig(model, **config_kwargs)

    # Build the tokenizer
    tokenizer = get_tokenizer(
        model,
        trust_remote_code=model_config.trust_remote_code,
    )

    # Only supply a tool list when the parametrization asks for one.
    tools = None
    if use_tools:
        dummy_function = {
            "name": "dummy_function_name",
            "description": "This is a dummy function",
            "parameters": sample_json_schema,
        }
        tools = [{"type": "function", "function": dummy_function}]

    # No explicit template is given, so the tokenizer's template is detected.
    resolved = resolve_chat_template(
        tokenizer,
        chat_template=None,
        tools=tools,
        model_config=model_config,
    )

    assert isinstance(resolved, str)
@pytest.mark.parametrize(
    "model, expected_kwargs",
    [
        (
            "Qwen/Qwen2-VL-2B-Instruct",
            {
                "add_vision_id",
                "add_generation_prompt",
                "continue_final_message",
                "tools",
            },
        ),
        (
            "Qwen/Qwen3-8B",
            {
                "enable_thinking",
                "add_generation_prompt",
                "continue_final_message",
                "tools",
            },
        ),
    ],
)
def test_resolve_chat_template_kwargs(sample_json_schema, model, expected_kwargs):
    """Check that request kwargs are filtered down to the expected set.

    The request carries a mix of unused, forbidden and model-specific kwargs.
    The test asserts that:

    * ``resolve_chat_template_kwargs`` raises ``ValueError`` by default for
      this mix, and
    * with ``raise_on_unexpected=False`` the resolved keys equal
      ``expected_kwargs`` for the parametrized model, and
    * for a tokenizer whose ``apply_chat_template`` only declares
      ``**kwargs``, the standard HF parameters are kept while an unknown
      parameter is dropped.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")

    tools = [
        {
            "type": "function",
            "function": {
                "name": "dummy_function_name",
                "description": "This is a dummy function",
                "parameters": sample_json_schema,
            },
        }
    ]

    chat_template_kwargs = {
        # both unused
        "unsed_kwargs_1": 123,
        "unsed_kwargs_2": "abc",
        # should not appear
        "chat_template": "{% Hello world! %}",
        "tokenize": True,
        # used by tokenizer
        "continue_final_message": True,
        "tools": tools,
        # both used by Qwen2-VL and Qwen3
        "add_generation_prompt": True,
        # only used by Qwen2-VL
        "add_vision_id": True,
        # only used by Qwen3
        "enable_thinking": True,
    }

    model_config = ModelConfig(
        model,
        tokenizer=model_info.tokenizer or model,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
        skip_tokenizer_init=model_info.require_embed_inputs,
        enable_prompt_embeds=model_info.require_embed_inputs,
        enable_mm_embeds=model_info.require_embed_inputs,
        enforce_eager=model_info.enforce_eager,
        dtype=model_info.dtype,
    )

    # Build the tokenizer from the name stored in the config, so that the
    # tokenizer under test is the one the config was created with.
    tokenizer = get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code,
    )

    # Test detecting the tokenizer's chat_template
    chat_template = resolve_chat_template(
        tokenizer,
        chat_template=None,
        tools=tools,
        model_config=model_config,
    )

    with pytest.raises(
        ValueError, match="Found unexpected chat template kwargs from request"
    ):
        # should raise error if `chat_template_kwargs` contains
        # `chat_template` or `tokenize`
        resolve_chat_template_kwargs(
            tokenizer,
            chat_template=chat_template,
            chat_template_kwargs=chat_template_kwargs,
        )

    resolved_chat_template_kwargs = resolve_chat_template_kwargs(
        tokenizer,
        chat_template=chat_template,
        chat_template_kwargs=chat_template_kwargs,
        raise_on_unexpected=False,
    )
    assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs

    # Additional test: Verify HF base parameters work with **kwargs tokenizers
    # This validates the fix for tokenizers like Kimi K2 that use **kwargs
    # to receive standard HuggingFace parameters instead of declaring them
    # explicitly
    hf_base_params = _get_hf_base_chat_template_params()

    # Verify common HF parameters are in the base class
    assert {"add_generation_prompt", "tools", "continue_final_message"}.issubset(
        hf_base_params
    ), f"Expected HF base params not found in {hf_base_params}"

    # Test with a mock tokenizer that uses **kwargs (like Kimi K2)
    class MockTokenizerWithKwargs:
        def apply_chat_template(self, conversation, **kwargs):
            return "mocked_output"

    mock_tokenizer = MockTokenizerWithKwargs()
    mock_kwargs = {
        "add_generation_prompt": True,
        "tools": tools,
        "continue_final_message": False,
        "unknown_param": "should_be_filtered",
    }
    resolved_mock = resolve_chat_template_kwargs(
        mock_tokenizer, chat_template, mock_kwargs, raise_on_unexpected=False
    )

    # HF base params should pass through even with **kwargs tokenizer
    assert "add_generation_prompt" in resolved_mock
    assert "tools" in resolved_mock
    assert "continue_final_message" in resolved_mock
    # Unknown params should be filtered out
    assert "unknown_param" not in resolved_mock
# NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json`
@pytest.mark.parametrize(
    ("model", "expected_format"),
    [
        ("microsoft/Phi-3.5-vision-instruct", "string"),
        ("Qwen/Qwen2-VL-2B-Instruct", "openai"),
        ("Qwen/Qwen2.5-VL-3B-Instruct", "openai"),
        ("fixie-ai/ultravox-v0_5-llama-3_2-1b", "string"),
        ("Qwen/Qwen2-Audio-7B-Instruct", "openai"),
        ("meta-llama/Llama-Guard-3-1B", "openai"),
    ],
)
def test_resolve_content_format_hf_defined(model, expected_format):
    """Check the auto-detected content format of each model's own template.

    No chat template is supplied: the template is first resolved from the
    tokenizer (and must be a ``str``), then the content format is resolved in
    ``"auto"`` mode and compared against ``expected_format``.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")

    model_config = ModelConfig(
        model,
        tokenizer=model_info.tokenizer or model,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
        skip_tokenizer_init=model_info.require_embed_inputs,
        enable_prompt_embeds=model_info.require_embed_inputs,
        enable_mm_embeds=model_info.require_embed_inputs,
        enforce_eager=model_info.enforce_eager,
        dtype=model_info.dtype,
    )

    # Build the tokenizer from the name stored in the config, so that the
    # tokenizer under test is the one the config was created with.
    tokenizer = get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code,
    )

    # Test detecting the tokenizer's chat_template
    chat_template = resolve_chat_template(
        tokenizer,
        chat_template=None,
        tools=None,
        model_config=model_config,
    )
    assert isinstance(chat_template, str)

    # Emit the template text and its extracted AST (visible with `pytest -s`)
    print("[TEXT]")
    print(chat_template)
    print("[AST]")
    print(_try_extract_ast(chat_template))

    resolved_format = resolve_chat_template_content_format(
        None,  # Test detecting the tokenizer's chat_template
        None,
        "auto",
        tokenizer,
        model_config=model_config,
    )

    assert resolved_format == expected_format
@pytest.mark.parametrize(
    ("model", "expected_format"),
    [
        ("Salesforce/blip2-opt-2.7b", "string"),
        ("facebook/chameleon-7b", "string"),
        ("deepseek-ai/deepseek-vl2-tiny", "string"),
        ("adept/fuyu-8b", "string"),
        ("google/paligemma-3b-mix-224", "string"),
        ("Qwen/Qwen-VL", "string"),
        ("Qwen/Qwen-VL-Chat", "string"),
    ],
)
def test_resolve_content_format_fallbacks(model, expected_format):
    """Check the auto-detected content format for the parametrized models.

    With no chat template supplied, the resolved template must be a ``str``
    and the content format resolved in ``"auto"`` mode must equal
    ``expected_format``.
    """
    hf_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    hf_info.check_available_online(on_fail="skip")

    embeds_only = hf_info.require_embed_inputs
    model_config = ModelConfig(
        model,
        tokenizer=hf_info.tokenizer or model,
        tokenizer_mode=hf_info.tokenizer_mode,
        revision=hf_info.revision,
        trust_remote_code=hf_info.trust_remote_code,
        hf_overrides=hf_info.hf_overrides,
        skip_tokenizer_init=embeds_only,
        enable_prompt_embeds=embeds_only,
        enable_mm_embeds=embeds_only,
        enforce_eager=hf_info.enforce_eager,
        dtype=hf_info.dtype,
    )

    tokenizer = get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code,
    )

    # No template is passed in, so it has to come from the tokenizer side
    detected_template = resolve_chat_template(
        tokenizer,
        chat_template=None,
        tools=None,
        model_config=model_config,
    )
    assert isinstance(detected_template, str)

    # Emit the template text and its extracted AST (visible with `pytest -s`)
    print("[TEXT]")
    print(detected_template)
    print("[AST]")
    print(_try_extract_ast(detected_template))

    detected_format = resolve_chat_template_content_format(
        None,  # no explicit chat template: detect the tokenizer's one
        None,
        "auto",
        tokenizer,
        model_config=model_config,
    )
    assert detected_format == expected_format
@pytest.mark.parametrize(
    ("template_path", "expected_format"),
    [
        ("template_alpaca.jinja", "string"),
        ("template_baichuan.jinja", "string"),
        ("template_chatglm.jinja", "string"),
        ("template_chatglm2.jinja", "string"),
        ("template_chatml.jinja", "string"),
        ("template_dse_qwen2_vl.jinja", "openai"),
        ("template_falcon_180b.jinja", "string"),
        ("template_falcon.jinja", "string"),
        ("template_inkbot.jinja", "string"),
        ("template_teleflm.jinja", "string"),
        ("template_vlm2vec_phi3v.jinja", "openai"),
        ("template_vlm2vec_qwen2vl.jinja", "openai"),
        ("tool_chat_template_granite_20b_fc.jinja", "string"),
        ("tool_chat_template_hermes.jinja", "string"),
        ("tool_chat_template_internlm2_tool.jinja", "string"),
        ("tool_chat_template_llama3.1_json.jinja", "openai"),
        ("tool_chat_template_llama3.2_json.jinja", "openai"),
        ("tool_chat_template_mistral_parallel.jinja", "string"),
        ("tool_chat_template_mistral.jinja", "string"),
    ],
)
def test_resolve_content_format_examples(template_path, expected_format):
    """Check the auto-detected content format of each example template file.

    A fixed model supplies the config and tokenizer; the tokenizer's own
    ``chat_template`` is cleared and the template loaded from ``EXAMPLES_DIR``
    is passed explicitly to the format resolver.
    """
    placeholder_model = "Qwen/Qwen2-VL-2B-Instruct"  # Dummy
    model_config = ModelConfig(
        placeholder_model,
        tokenizer=placeholder_model,
        trust_remote_code=True,
    )

    dummy_tokenizer = get_tokenizer(
        placeholder_model,
        trust_remote_code=model_config.trust_remote_code,
    )
    dummy_tokenizer.chat_template = None

    example_template = load_chat_template(EXAMPLES_DIR / template_path)
    assert isinstance(example_template, str)

    # Emit the template text and its extracted AST (visible with `pytest -s`)
    print("[TEXT]")
    print(example_template)
    print("[AST]")
    print(_try_extract_ast(example_template))

    detected_format = resolve_chat_template_content_format(
        example_template,
        None,
        "auto",
        dummy_tokenizer,
        model_config=model_config,
    )
    assert detected_format == expected_format
@pytest.mark.parametrize(
    "model,template,add_generation_prompt,continue_final_message,expected_output",
    MODEL_TEMPLATE_GENERATION_OUTPUT,
)
def test_get_gen_prompt(
    model, template, add_generation_prompt, continue_final_message, expected_output
):
    """Render the test conversation and compare it with the expected prompt.

    Each case of ``MODEL_TEMPLATE_GENERATION_OUTPUT`` supplies a model, a chat
    template, the two generation flags and the exact prompt text expected
    from ``safe_apply_chat_template`` with ``tokenize=False``.
    """
    hf_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    hf_info.check_available_online(on_fail="skip")

    embeds_only = hf_info.require_embed_inputs
    model_config = ModelConfig(
        model,
        tokenizer=hf_info.tokenizer or model,
        tokenizer_mode=hf_info.tokenizer_mode,
        trust_remote_code=hf_info.trust_remote_code,
        revision=hf_info.revision,
        hf_overrides=hf_info.hf_overrides,
        skip_tokenizer_init=embeds_only,
        enable_prompt_embeds=embeds_only,
        enable_mm_embeds=embeds_only,
        enforce_eager=hf_info.enforce_eager,
        dtype=hf_info.dtype,
    )

    tokenizer = get_tokenizer(
        tokenizer_name=model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code,
    )
    template_content = load_chat_template(chat_template=template)

    # When continuing the final message, the conversation must end with the
    # partial assistant turn
    if continue_final_message:
        conversation = TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
    else:
        conversation = TEST_MESSAGES

    mock_request = ChatCompletionRequest(
        model=model,
        messages=conversation,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
    )

    # A template carried by the request takes precedence over the loaded one
    rendered = safe_apply_chat_template(
        model_config,
        tokenizer,
        mock_request.messages,
        tools=None,
        chat_template=mock_request.chat_template or template_content,
        add_generation_prompt=mock_request.add_generation_prompt,
        continue_final_message=mock_request.continue_final_message,
        tokenize=False,
    )

    assert rendered == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}"
    )
...@@ -6,38 +6,15 @@ import time ...@@ -6,38 +6,15 @@ import time
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
@pytest.fixture()
def serving() -> OpenAIServing:
    """Build an ``OpenAIServing`` whose collaborators are all mocks."""
    # Model config mock with a fixed context length
    config_mock = Mock(spec=ModelConfig)
    config_mock.max_model_len = 32768

    # Serving-models mock exposing the config and mocked processors
    models_mock = Mock(spec=OpenAIServingModels)
    models_mock.model_config = config_mock
    models_mock.input_processor = Mock()
    models_mock.io_processor = Mock()

    return OpenAIServing(
        engine_client=Mock(),
        models=models_mock,
        request_logger=None,
    )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_mistral_tokenizer_does_not_block_event_loop( async def test_async_mistral_tokenizer_does_not_block_event_loop():
serving: OpenAIServing,
):
expected_tokens = [1, 2, 3] expected_tokens = [1, 2, 3]
# Mock the blocking version to sleep # Mock the blocking version to sleep
...@@ -46,11 +23,11 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop( ...@@ -46,11 +23,11 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(
return expected_tokens return expected_tokens
mock_tokenizer = Mock(spec=MistralTokenizer) mock_tokenizer = Mock(spec=MistralTokenizer)
mock_tokenizer.apply_chat_template.side_effect = mocked_apply_chat_template mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer
task = serving._apply_mistral_chat_template_async( task = mock_renderer.render_messages_async([])
tokenizer=mock_tokenizer, messages=[], chat_template=None, tools=[]
)
# Ensure the event loop is not blocked # Ensure the event loop is not blocked
blocked_count = 0 blocked_count = 0
...@@ -66,6 +43,58 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop( ...@@ -66,6 +43,58 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# Ensure task completes # Ensure task completes
tokens = await task _, prompt = await task
assert tokens == expected_tokens, "Mocked blocking tokenizer was not called" assert prompt["prompt_token_ids"] == expected_tokens, (
"Mocked blocking tokenizer was not called"
)
assert blocked_count == 0, "Event loop blocked during tokenization" assert blocked_count == 0, "Event loop blocked during tokenization"
def test_apply_mistral_chat_template_thinking_chunk():
    """Check that ``thinking`` chunks render as ``[THINK]...[/THINK]`` spans.

    A conversation with thinking chunks in both the system and the assistant
    message is tokenized, decoded with special tokens kept, and compared with
    the exact expected string.
    """
    system_message = {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."},
            {
                "type": "thinking",
                "closed": True,
                "thinking": "Only return the answer when you are confident.",
            },
        ],
    }
    assistant_message = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Let me think about it."},
            {"type": "thinking", "closed": True, "thinking": "2+2 = 4"},
            {
                "type": "text",
                "text": "The answer is 4.",
            },
        ],
    }
    messages = [
        system_message,
        {"role": "user", "content": "What is 2+2?"},
        assistant_message,
        {"role": "user", "content": "Thanks, what is 3+3?"},
    ]

    mistral_tokenizer = MistralTokenizer.from_pretrained(
        "mistralai/Magistral-Small-2509"
    )

    token_ids = safe_apply_chat_template(
        mistral_tokenizer, messages, chat_template=None, tools=None
    )

    # Keep special tokens so the control markers show up in the decoded text
    decoded = mistral_tokenizer.mistral.decode(
        token_ids, special_token_policy=SpecialTokenPolicy.KEEP
    )

    expected = (
        r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
        r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
        r"[INST]What is 2+2?[/INST]"
        r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
        r"[INST]Thanks, what is 3+3?[/INST]"
    )
    assert decoded == expected
...@@ -7,7 +7,6 @@ from vllm.config import ModelConfig ...@@ -7,7 +7,6 @@ from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_raw_prompts from vllm.inputs.parse import parse_raw_prompts
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.tokenizers import cached_tokenizer_from_config
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
...@@ -115,10 +114,10 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): ...@@ -115,10 +114,10 @@ def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
) )
def test_preprocessor_always_mm_code_path(model_id, prompt): def test_preprocessor_always_mm_code_path(model_id, prompt):
model_config = ModelConfig(model=model_id) model_config = ModelConfig(model=model_id)
tokenizer = cached_tokenizer_from_config(model_config) input_preprocessor = InputPreprocessor(model_config)
input_preprocessor = InputPreprocessor(model_config, tokenizer)
# HF processor adds sep token # HF processor adds sep token
tokenizer = input_preprocessor.get_tokenizer()
sep_token_id = tokenizer.vocab[tokenizer.sep_token] sep_token_id = tokenizer.vocab[tokenizer.sep_token]
processed_inputs = input_preprocessor.preprocess(prompt) processed_inputs = input_preprocessor.preprocess(prompt)
......
...@@ -224,7 +224,7 @@ def test_skip_tokenizer_initialization(model: str): ...@@ -224,7 +224,7 @@ def test_skip_tokenizer_initialization(model: str):
) )
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
with pytest.raises(ValueError, match="cannot pass text prompts when"): with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
llm.generate("abc", sampling_params) llm.generate("abc", sampling_params)
outputs = llm.generate( outputs = llm.generate(
......
...@@ -5,7 +5,13 @@ import pytest ...@@ -5,7 +5,13 @@ import pytest
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig from vllm.config import (
CacheConfig,
DeviceConfig,
ModelConfig,
MultiModalConfig,
VllmConfig,
)
from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
...@@ -44,27 +50,22 @@ def _mock_input_processor( ...@@ -44,27 +50,22 @@ def _mock_input_processor(
monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True) monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True)
model_config = ModelConfig( model_config = ModelConfig(
tokenizer="dummy",
skip_tokenizer_init=True, skip_tokenizer_init=True,
max_model_len=128, max_model_len=128,
mm_processor_cache_gb=mm_cache_gb, mm_processor_cache_gb=mm_cache_gb,
generation_config="vllm", generation_config="vllm",
tokenizer="dummy",
) )
model_config.runner_type = "generate"
model_config.multimodal_config = MultiModalConfig(mm_processor_cache_gb=mm_cache_gb)
# Minimal multimodal_config to satisfy references in
# Processor.process_inputs.
class _MockMMConfig:
def __init__(self, gb: float):
self.mm_processor_cache_gb = gb
model_config.multimodal_config = _MockMMConfig(mm_cache_gb) # type: ignore[attr-defined]
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching), cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
device_config=DeviceConfig(device="cpu"), device_config=DeviceConfig(device="cpu"),
) )
return InputProcessor(vllm_config, tokenizer=None) return InputProcessor(vllm_config)
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch): def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
......
...@@ -35,6 +35,7 @@ FILES = [ ...@@ -35,6 +35,7 @@ FILES = [
"vllm/multimodal", "vllm/multimodal",
"vllm/platforms", "vllm/platforms",
"vllm/plugins", "vllm/plugins",
"vllm/renderers",
"vllm/tokenizers", "vllm/tokenizers",
"vllm/transformers_utils", "vllm/transformers_utils",
"vllm/triton_utils", "vllm/triton_utils",
......
...@@ -11,9 +11,9 @@ from vllm.lora.request import LoRARequest ...@@ -11,9 +11,9 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
...@@ -26,6 +26,10 @@ class EngineClient(ABC): ...@@ -26,6 +26,10 @@ class EngineClient(ABC):
input_processor: InputProcessor input_processor: InputProcessor
io_processor: IOProcessor | None io_processor: IOProcessor | None
@property
@abstractmethod
def renderer(self) -> RendererLike: ...
@property @property
@abstractmethod @abstractmethod
def is_running(self) -> bool: ... def is_running(self) -> bool: ...
...@@ -88,11 +92,6 @@ class EngineClient(ABC): ...@@ -88,11 +92,6 @@ class EngineClient(ABC):
""" """
... ...
@abstractmethod
async def get_tokenizer(self) -> TokenizerLike:
"""Get the tokenizer"""
...
@abstractmethod @abstractmethod
async def is_tracing_enabled(self) -> bool: ... async def is_tracing_enabled(self) -> bool: ...
......
This diff is collapsed.
...@@ -37,10 +37,6 @@ from vllm.engine.arg_utils import EngineArgs ...@@ -37,10 +37,6 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format,
) )
from vllm.entrypoints.pooling.score.utils import ( from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam, ScoreContentPartParam,
...@@ -786,7 +782,7 @@ class LLM: ...@@ -786,7 +782,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None, chat_template_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[TokensPrompt]: ) -> list[TextPrompt | TokensPrompt]:
""" """
Generate prompt for a chat conversation. The pre-processed Generate prompt for a chat conversation. The pre-processed
prompt can then be used as input for the other LLM methods. prompt can then be used as input for the other LLM methods.
...@@ -807,63 +803,27 @@ class LLM: ...@@ -807,63 +803,27 @@ class LLM:
# messages is list[...] # messages is list[...]
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)] list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
tokenizer = self.get_tokenizer() renderer = self.llm_engine.renderer
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
chat_template,
tools,
chat_template_content_format,
tokenizer,
model_config=model_config,
)
_chat_template_kwargs: dict[str, Any] = dict( chat_template_kwargs = {
chat_template=chat_template, "chat_template": chat_template,
add_generation_prompt=add_generation_prompt, "add_generation_prompt": add_generation_prompt,
continue_final_message=continue_final_message, "continue_final_message": continue_final_message,
tools=tools, "tools": tools,
) **(chat_template_kwargs or {}),
_chat_template_kwargs.update(chat_template_kwargs or {}) }
prompts: list[TokensPrompt] = [] prompts = list[TextPrompt | TokensPrompt]()
for msgs in list_of_messages: for msgs in list_of_messages:
# NOTE: _parse_chat_message_content_parts() currently doesn't # NOTE: renderer.render_messages() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in # handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it. # the chat message parsing for it.
conversation, mm_data, mm_uuids = parse_chat_messages( _, prompt = renderer.render_messages(
msgs, msgs,
model_config, chat_template_content_format=chat_template_content_format,
content_format=resolved_content_format, **chat_template_kwargs,
) )
if isinstance(tokenizer, MistralTokenizer):
prompt_token_ids = apply_mistral_chat_template(
tokenizer,
messages=msgs,
**_chat_template_kwargs,
)
else:
prompt_str = apply_hf_chat_template(
tokenizer=tokenizer,
conversation=conversation,
model_config=model_config,
**_chat_template_kwargs,
)
# Special tokens are already included in chat templates so
# should not be added by the tokenizer in this case.
prompt_token_ids = tokenizer.encode(
prompt_str, add_special_tokens=False
)
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
if mm_processor_kwargs is not None: if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs prompt["mm_processor_kwargs"] = mm_processor_kwargs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment