# test_chat_template.py
1
import pytest
2
import os
3

4
5
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                         load_chat_template)
6
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
7
from vllm.transformers_utils.tokenizer import get_tokenizer
8

9
from ...utils import VLLM_PATH, models_path_prefix
10
11

# Location of the ChatML Jinja template, resolved relative to VLLM_PATH.
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
# Fail at import time, with a readable message, if the template file that
# every test in this module relies on is missing.
assert chatml_jinja_path.exists(), (
    f"ChatML template not found at {chatml_jinja_path}")

14
15
# Define models, templates, and their corresponding expected outputs.
# Each entry is a tuple of
#   (model path, chat template path, add_generation_prompt, expected prompt)
# and is fed to the parametrized prompt-generation test in this module.
# NOTE: the constant name is misspelled ("GENERATON"); it is kept as-is
# because the parametrize decorator refers to it by this name.
MODEL_TEMPLATE_GENERATON_OUTPUT = [
    # add_generation_prompt=True: the last user turn is closed with
    # <|im_end|> and an opening assistant header is appended.
    (os.path.join(models_path_prefix, "facebook/opt-125m"), chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
"""),
    # add_generation_prompt=False: the last user turn is left open, with no
    # trailing <|im_end|> and no assistant header.
    (os.path.join(models_path_prefix, "facebook/opt-125m"), chatml_jinja_path, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""")
]

# Three-turn conversation (user, assistant, user) passed as the `messages`
# of the request built in the prompt-generation test.
TEST_MESSAGES = [
    dict(role=speaker, content=text)
    for speaker, text in (
        ("user", "Hello"),
        ("assistant", "Hi there!"),
        ("user", "What is the capital of"),
    )
]


48
def test_load_chat_template():
    """Check that loading the ChatML template file yields its exact text."""
    # Load the template from the file at chatml_jinja_path
    template_content = load_chat_template(chat_template=chatml_jinja_path)

    # Test assertions
    assert template_content is not None
    # Hard coded value for template_chatml.jinja
    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""  # noqa: E501
57
58


59
def test_no_load_chat_template_filelike():
    """Check that a path-like template string is rejected with ValueError.

    The test passes the string ``"../../examples/does_not_exist"`` and
    expects the error message to match "looks like a file path".
    """
    # A string shaped like a relative file path, named "does_not_exist"
    template = "../../examples/does_not_exist"

    with pytest.raises(ValueError, match="looks like a file path"):
        load_chat_template(chat_template=template)
65
66


67
def test_no_load_chat_template_literallike():
    """Check that a literal Jinja template string is returned unchanged."""
    # An inline Jinja expression rather than a file path
    template = "{{ messages }}"

    template_content = load_chat_template(chat_template=template)

    # The loader is expected to hand the literal back as-is
    assert template_content == template
74
75
76
77
78


@pytest.mark.parametrize(
    "model,template,add_generation_prompt,expected_output",
    MODEL_TEMPLATE_GENERATON_OUTPUT)
def test_get_gen_prompt(model, template, add_generation_prompt,
                        expected_output):
    """Check the prompt rendered from TEST_MESSAGES for each parametrized case.

    Args:
        model: Tokenizer name/path passed to ``get_tokenizer`` and used as
            the request's ``model`` field.
        template: Chat template passed to ``load_chat_template``.
        add_generation_prompt: Forwarded through the request to
            ``apply_hf_chat_template``.
        expected_output: Exact prompt string the rendering must equal.
    """
    # Initialize the tokenizer
    tokenizer = get_tokenizer(tokenizer_name=model)
    template_content = load_chat_template(chat_template=template)

    # Create a mock request object using keyword arguments
    mock_request = ChatCompletionRequest(
        model=model,
        messages=TEST_MESSAGES,
        add_generation_prompt=add_generation_prompt)

    # Call the function and get the result; a chat template set on the
    # request takes precedence over the one loaded from `template`.
    result = apply_hf_chat_template(
        tokenizer,
        conversation=mock_request.messages,
        chat_template=mock_request.chat_template or template_content,
        add_generation_prompt=mock_request.add_generation_prompt,
    )

    # Test assertion
    assert result == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}")