# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import pytest
5
import os
6

7
from vllm.config import ModelConfig
8
from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template
9
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
10
from vllm.tokenizers import get_tokenizer
11

12
from ...models.registry import HF_EXAMPLE_MODELS
13
from ...utils import VLLM_PATH, models_path_prefix
14
15

# Location of the ChatML Jinja template used by the tests in this module,
# resolved relative to VLLM_PATH.
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
# Checked at import time so a missing template is reported up front with a
# clear message rather than as a failure inside an individual test.
assert chatml_jinja_path.exists(), f"Missing chat template: {chatml_jinja_path}"

18
# Define models, templates, and their corresponding expected outputs.
#
# Each entry is a tuple of:
#   (model, template, add_generation_prompt, continue_final_message,
#    expected_output)
# The expected outputs are exact strings; their line breaks are significant.
MODEL_TEMPLATE_GENERATION_OUTPUT = [
    # add_generation_prompt=True: the last user turn is closed and an empty
    # assistant turn is opened.
    (
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        chatml_jinja_path,
        True,
        False,
        """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
""",
    ),
    # Neither flag set: the last user turn is left open.
    (
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        chatml_jinja_path,
        False,
        False,
        """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""",
    ),
    # continue_final_message=True: the trailing assistant message is left
    # open so that generation continues it.
    (
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        chatml_jinja_path,
        False,
        True,
        """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of""",
    ),
]

# Base conversation rendered by the prompt-generation test; it ends on an
# unfinished user turn.
TEST_MESSAGES = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
    {"role": "user", "content": "What is the capital of"},
]
# Partial assistant turn appended to TEST_MESSAGES when the test exercises
# continue_final_message.
ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"}
68
69


70
def test_load_chat_template():
    """Check that loading the ChatML template file yields its exact text."""
    # Testing chatml template
    template_content = load_chat_template(chat_template=chatml_jinja_path)

    # Test assertions
    assert template_content is not None
    # Hard coded value for template_chatml.jinja; the line break between the
    # two template statements is part of the expected text.
    assert (
        template_content
        == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501
    )
82
83


84
def test_no_load_chat_template_filelike():
    """Check that a path-like string naming a missing file raises ValueError."""
    # A string that resembles a file path but does not point to a real file.
    template = "../../examples/does_not_exist"

    # The loader is expected to raise rather than treat the string as a
    # literal template.
    with pytest.raises(ValueError, match="looks like a file path"):
        load_chat_template(chat_template=template)
90
91


92
def test_no_load_chat_template_literallike():
    """Check that an inline Jinja template string is returned unchanged."""
    # A literal template rather than a file path.
    template = "{{ messages }}"

    template_content = load_chat_template(chat_template=template)

    assert template_content == template
99
100
101


@pytest.mark.parametrize(
    "model,template,add_generation_prompt,continue_final_message,expected_output",
    MODEL_TEMPLATE_GENERATION_OUTPUT,
)
def test_get_gen_prompt(
    model, template, add_generation_prompt, continue_final_message, expected_output
):
    """Render the test conversation with a chat template and compare prompts.

    Args:
        model: Model identifier/path used for the registry lookup, the model
            config and the request.
        template: Chat template handed to ``load_chat_template``.
        add_generation_prompt: Forwarded to the request and the renderer.
        continue_final_message: Forwarded to the request and the renderer;
            when true, ``ASSISTANT_MESSAGE_TO_CONTINUE`` is appended to the
            conversation.
        expected_output: Exact prompt string the renderer must produce.
    """
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    # Ask the registry entry to skip this test (rather than fail) when the
    # model is not available.
    model_info.check_available_online(on_fail="skip")

    model_config = ModelConfig(
        model,
        tokenizer=model_info.tokenizer or model,
        tokenizer_mode=model_info.tokenizer_mode,
        trust_remote_code=model_info.trust_remote_code,
        revision=model_info.revision,
        hf_overrides=model_info.hf_overrides,
        skip_tokenizer_init=model_info.require_embed_inputs,
        enable_prompt_embeds=model_info.require_embed_inputs,
        enable_mm_embeds=model_info.require_embed_inputs,
        enforce_eager=model_info.enforce_eager,
        dtype=model_info.dtype,
    )

    # Initialize the tokenizer
    tokenizer = get_tokenizer(
        tokenizer_name=model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code,
    )
    template_content = load_chat_template(chat_template=template)

    # Continuing the final message requires a trailing assistant turn to
    # continue; otherwise the base conversation is used as-is.
    if continue_final_message:
        messages = TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
    else:
        messages = TEST_MESSAGES

    # Create a mock request object using keyword arguments
    mock_request = ChatCompletionRequest(
        model=model,
        messages=messages,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
    )

    # Call the function and get the result
    result = apply_hf_chat_template(
        tokenizer=tokenizer,
        conversation=mock_request.messages,
        # Prefer a template carried by the request; fall back to the one
        # loaded from the parametrized template argument.
        chat_template=mock_request.chat_template or template_content,
        model_config=model_config,
        tools=None,
        add_generation_prompt=mock_request.add_generation_prompt,
        continue_final_message=mock_request.continue_final_message,
    )

    # Test assertion
    assert result == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}"
    )