from argparse import Namespace
from dataclasses import dataclass
import os
import pathlib

import pytest
from fastapi.testclient import TestClient

from vllm.entrypoints.openai.api_server import *

# Resolve examples/template_chatml.jinja relative to this test file:
# two directories up from the directory containing this module.
_this_dir = pathlib.Path(os.path.dirname(os.path.abspath(__file__)))
chatml_jinja_path = _this_dir.parent.parent / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()

# Expected prompt when opt-125m renders with its own built-in template;
# it is identical whether or not a generation prompt is requested.
_OPT_DEFAULT_PROMPT = "Hello</s>Hi there!</s>What is the capital of</s>"

# Expected chatml rendering with add_generation_prompt=True: the prompt
# ends with an opened assistant turn.
_CHATML_PROMPT_WITH_GEN = """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
"""

# Expected chatml rendering with add_generation_prompt=False: the final
# user turn is left unterminated and no assistant turn is opened.
_CHATML_PROMPT_NO_GEN = """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of"""

# Each entry: (model, chat template path or None, add_generation_prompt,
# expected rendered prompt). Consumed by test_get_gen_prompt below.
MODEL_TEMPLATE_GENERATON_OUTPUT = [
    ("facebook/opt-125m", None, True, _OPT_DEFAULT_PROMPT),
    ("facebook/opt-125m", None, False, _OPT_DEFAULT_PROMPT),
    ("facebook/opt-125m", chatml_jinja_path, True, _CHATML_PROMPT_WITH_GEN),
    ("facebook/opt-125m", chatml_jinja_path, False, _CHATML_PROMPT_NO_GEN),
]

# Three-turn conversation shared by every chat-template test below.
TEST_MESSAGES = [
    {"role": role, "content": content}
    for role, content in (
        ("user", "Hello"),
        ("assistant", "Hi there!"),
        ("user", "What is the capital of"),
    )
]
client = TestClient(app)


@dataclass
class MockTokenizer:
    # Minimal stand-in for a tokenizer: the tests only need an object
    # whose ``chat_template`` attribute load_chat_template() can assign.
    # Deliberately left without a type annotation so the dataclass
    # machinery treats it as a plain class attribute (not an __init__
    # field) and MockTokenizer() keeps its zero-argument constructor.
    chat_template = None


def test_load_chat_template():
    """An existing template path is read and its full contents are
    stored on ``tokenizer.chat_template``."""
    args = Namespace(chat_template=chatml_jinja_path)
    tokenizer = MockTokenizer()

    load_chat_template(args, tokenizer)
    template_content = tokenizer.chat_template

    assert template_content is not None
    # Hard coded value for template_chatml.jinja
    assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""


def test_no_load_chat_template():
    """A template path that does not exist on disk is kept verbatim as
    the chat template string instead of being read from a file."""
    bogus_template = "../../examples/does_not_exist"
    args = Namespace(chat_template=bogus_template)
    tokenizer = MockTokenizer()

    load_chat_template(args, tokenizer=tokenizer)

    assert tokenizer.chat_template is not None
    # The nonexistent path itself becomes the template text.
    assert tokenizer.chat_template == """../../examples/does_not_exist"""


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model,template,add_generation_prompt,expected_output",
    MODEL_TEMPLATE_GENERATON_OUTPUT)
async def test_get_gen_prompt(model, template, add_generation_prompt,
                              expected_output):
    """Rendering TEST_MESSAGES through the (optionally overridden) chat
    template must reproduce the expected prompt exactly."""
    tokenizer = get_tokenizer(tokenizer_name=model)

    # Install the template under test (None keeps the model's default).
    load_chat_template(Namespace(chat_template=template), tokenizer)

    # Mock request carrying the conversation and generation flag.
    request = ChatCompletionRequest(
        model=model,
        messages=TEST_MESSAGES,
        add_generation_prompt=add_generation_prompt)

    rendered = tokenizer.apply_chat_template(
        conversation=request.messages,
        tokenize=False,
        add_generation_prompt=request.add_generation_prompt)

    assert rendered == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}")


def test_health_endpoint():
    """The /health probe of the API server should answer 200 OK."""
    assert client.get("/health").status_code == 200