# File: test_chat_template.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import pytest
5
import os
6

7
from vllm.config import ModelConfig
8
9
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                         load_chat_template)
10
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
11
from vllm.transformers_utils.tokenizer import get_tokenizer
12

13
from ...models.registry import HF_EXAMPLE_MODELS
14
from ...utils import VLLM_PATH, models_path_prefix
15
16

# Location of the ChatML Jinja template, resolved relative to VLLM_PATH.
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
# Fail at import time with an explicit message if the template file is
# missing, rather than letting each test fail later for a less obvious reason.
assert chatml_jinja_path.exists(), (
    f"ChatML template not found: {chatml_jinja_path}")

19
20
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT = [
    # Each entry is a tuple of
    #   (model, template, add_generation_prompt, continue_final_message,
    #    expected_output)
    # in the same order as the parameter names of test_get_gen_prompt.
    #
    # NOTE: the expected strings are compared verbatim, so the continuation
    # lines of each triple-quoted literal must stay at column 0 and must not
    # contain anything other than the rendered prompt text.

    # add_generation_prompt=True: the last user turn is closed and an empty
    # assistant turn is opened.
    (
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        chatml_jinja_path,
        True,
        False,
        """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
""",
    ),
    # Neither flag set: the last user turn is left without an end marker.
    (
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        chatml_jinja_path,
        False,
        False,
        """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""",
    ),
    # continue_final_message=True: the trailing assistant message is left
    # open so that generation can continue it.
    (
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        chatml_jinja_path,
        False,
        True,
        """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of""",
    ),
]

# Base conversation (user / assistant / user) used to build chat requests.
TEST_MESSAGES = [
    {"role": speaker, "content": text}
    for speaker, text in (
        ("user", "Hello"),
        ("assistant", "Hi there!"),
        ("user", "What is the capital of"),
    )
]
59
60
61
62
# Partial assistant turn appended to the conversation when the final message
# is meant to be continued.
ASSISTANT_MESSAGE_TO_CONTINUE = dict(
    role="assistant",
    content="The capital of",
)
63
64


65
def test_load_chat_template():
    """Load the ChatML template from disk and check its exact contents."""
    template_content = load_chat_template(chat_template=chatml_jinja_path)

    # The loader must return something for an existing template file.
    assert template_content is not None
    # Hard coded value for template_chatml.jinja; the comparison is verbatim,
    # so the literal below must contain only the template text.
    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501
74
75


76
def test_no_load_chat_template_filelike():
    """A path-like string naming a missing file must raise ``ValueError``.

    The error message is expected to match "looks like a file path".
    """
    # Path-shaped argument that is not expected to exist on disk.
    template = "../../examples/does_not_exist"

    with pytest.raises(ValueError, match="looks like a file path"):
        load_chat_template(chat_template=template)
82
83


84
def test_no_load_chat_template_literallike():
    """A literal Jinja template string must be returned unchanged."""
    # Inline template text rather than a file path.
    template = "{{ messages }}"

    template_content = load_chat_template(chat_template=template)

    assert template_content == template
91
92
93


@pytest.mark.parametrize(
    "model,template,add_generation_prompt,continue_final_message,expected_output",
    MODEL_TEMPLATE_GENERATON_OUTPUT)
def test_get_gen_prompt(model, template, add_generation_prompt,
                        continue_final_message, expected_output):
    """Render the test conversation and compare it with the expected prompt.

    Args:
        model: Model name or path; also used as the tokenizer fallback.
        template: Chat template passed to ``load_chat_template``.
        add_generation_prompt: Forwarded to the request and to the template
            rendering call.
        continue_final_message: When true, ASSISTANT_MESSAGE_TO_CONTINUE is
            appended to TEST_MESSAGES and the flag is forwarded as well.
        expected_output: Exact prompt text the rendering must produce.
    """
    # Skip (rather than fail) when the example model is reported unavailable.
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")

    model_config = ModelConfig(
        model,
        tokenizer=model_info.tokenizer or model,
        tokenizer_mode=model_info.tokenizer_mode,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
    )

    # Initialize the tokenizer
    tokenizer = get_tokenizer(
        tokenizer_name=model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code,
    )
    template_content = load_chat_template(chat_template=template)

    # Create a mock request object using keyword arguments
    mock_request = ChatCompletionRequest(
        model=model,
        # Parenthesized to make explicit that the conditional selects between
        # the extended conversation and the base conversation.
        messages=(TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
                  if continue_final_message else TEST_MESSAGES),
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
    )

    # Call the function and get the result; a template set on the request
    # takes precedence over the one loaded above.
    result = apply_hf_chat_template(
        tokenizer=tokenizer,
        conversation=mock_request.messages,
        chat_template=mock_request.chat_template or template_content,
        model_config=model_config,
        tools=None,
        add_generation_prompt=mock_request.add_generation_prompt,
        continue_final_message=mock_request.continue_final_message,
    )

    # Test assertion
    assert result == expected_output, (
        "The generated prompt does not match the expected output for "
        f"model {model} and template {template}")