"vscode:/vscode.git/clone" did not exist on "5522d0a36cb6b060cbd37307e0652d6f46d7e5d0"
test_openai_server.py 3.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from argparse import Namespace
from dataclasses import dataclass
from typing import Optional

import pytest
from fastapi.testclient import TestClient

from vllm.entrypoints.openai.api_server import *

# Parametrization table for the prompt-generation test. Each entry is
# (model name, chat-template path or None, add_generation_prompt flag,
#  expected rendered prompt).
MODEL_TEMPLATE_GENERATON_OUTPUT = [
    (
        "facebook/opt-125m",
        None,
        True,
        "Hello</s>Hi there!</s>What is the capital of</s>",
    ),
    (
        "facebook/opt-125m",
        None,
        False,
        "Hello</s>Hi there!</s>What is the capital of</s>",
    ),
    (
        "facebook/opt-125m",
        "../../examples/template_chatml.jinja",
        True,
        # With the generation prompt requested, the last user turn is closed
        # and an open assistant turn is appended.
        "<|im_start|>user\n"
        "Hello<|im_end|>\n"
        "<|im_start|>assistant\n"
        "Hi there!<|im_end|>\n"
        "<|im_start|>user\n"
        "What is the capital of<|im_end|>\n"
        "<|im_start|>assistant\n",
    ),
    (
        "facebook/opt-125m",
        "../../examples/template_chatml.jinja",
        False,
        # Without the generation prompt the last user turn is left open.
        "<|im_start|>user\n"
        "Hello<|im_end|>\n"
        "<|im_start|>assistant\n"
        "Hi there!<|im_end|>\n"
        "<|im_start|>user\n"
        "What is the capital of",
    ),
]

# Three-turn conversation (user, assistant, user) used as chat input by the
# prompt-generation test.
TEST_MESSAGES = [
    {"role": speaker, "content": text}
    for speaker, text in (
        ("user", "Hello"),
        ("assistant", "Hi there!"),
        ("user", "What is the capital of"),
    )
]
# Module-level test client shared by the endpoint tests; `app` is not defined
# in this file and comes in through the wildcard import of the API server
# module above.
client = TestClient(app)


@dataclass
class MockTokenizer:
    """Minimal stand-in for a tokenizer that exposes only ``chat_template``.

    The tests hand an instance to ``load_chat_template`` and afterwards read
    back whatever value was stored on ``chat_template``.
    """

    # Annotated so that ``@dataclass`` registers it as a real field. Without
    # an annotation the decorator finds no fields at all: ``__init__`` accepts
    # no arguments and ``__repr__``/``__eq__`` ignore the template, so two
    # mocks holding different templates would compare equal.
    chat_template: Optional[str] = None


def test_load_chat_template():
    """Check that load_chat_template stores the chatml template text on the tokenizer."""
    # Hard-coded copy of what template_chatml.jinja is expected to contain.
    expected_template = """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""

    fake_tokenizer = MockTokenizer()
    # NOTE(review): relative path -- presumably resolved against the current
    # working directory; confirm which directory the suite is launched from.
    cli_args = Namespace(chat_template="../../examples/template_chatml.jinja")

    # Populate the mock tokenizer from the mocked command-line arguments.
    load_chat_template(cli_args, fake_tokenizer)

    loaded = fake_tokenizer.chat_template
    assert loaded is not None
    assert loaded == expected_template


def test_no_load_chat_template():
    """Check the outcome when the chat_template argument names a missing file.

    The test expects the argument value itself to end up, unchanged, as the
    tokenizer's chat template.
    """
    # A path that is expected not to exist on disk.
    missing_path = "../../examples/does_not_exist"
    fake_tokenizer = MockTokenizer()

    # Same call as in the loading test, but with the tokenizer passed by keyword.
    load_chat_template(Namespace(chat_template=missing_path),
                       tokenizer=fake_tokenizer)

    loaded = fake_tokenizer.chat_template
    assert loaded is not None
    # The stored template must be the literal argument string.
    assert loaded == """../../examples/does_not_exist"""


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model,template,add_generation_prompt,expected_output",
    MODEL_TEMPLATE_GENERATON_OUTPUT)
async def test_get_gen_prompt(model, template, add_generation_prompt,
                              expected_output):
    """Render TEST_MESSAGES through the tokenizer's chat template.

    Runs once per entry of MODEL_TEMPLATE_GENERATON_OUTPUT and compares the
    rendered prompt text with that entry's expected output.
    """
    # Tokenizer for the model under test; load_chat_template is then invoked
    # on it with the parametrized template value (which may be None).
    tok = get_tokenizer(tokenizer_name=model)
    load_chat_template(Namespace(chat_template=template), tok)

    # Build the request from keyword arguments, then read the messages and
    # the flag back from it for rendering.
    request = ChatCompletionRequest(
        model=model,
        messages=TEST_MESSAGES,
        add_generation_prompt=add_generation_prompt,
    )

    # tokenize=False is passed so the result can be compared as text.
    rendered = tok.apply_chat_template(
        conversation=request.messages,
        tokenize=False,
        add_generation_prompt=request.add_generation_prompt,
    )

    assert rendered == expected_output, (
        f"The generated prompt does not match the expected output for "
        f"model {model} and template {template}")


def test_health_endpoint():
    """GET /health through the shared test client must answer with HTTP 200."""
    status = client.get("/health").status_code
    assert status == 200