# test_openapi_server_ray.py
# Tests for a RemoteOpenAIServer started with `--engine-use-ray`, checked
# through the official `openai` client.
import openai  # use the official client for correctness check
import pytest

4
from ..utils import VLLM_PATH, RemoteOpenAIServer
5

6
7
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"

# ChatML chat template from the repository's examples directory; the server
# fixture passes it to the server via `--chat-template`.
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
# Fail at import time, before any server is started, if the template is missing.
assert chatml_jinja_path.exists()
10
11


12
@pytest.fixture(scope="module")
def server():
    """Start one ``RemoteOpenAIServer`` per module and yield it.

    The server is created for ``MODEL_NAME`` with ``--engine-use-ray`` and
    the ChatML chat template. It is used as a context manager, so the
    ``with`` block is exited when pytest finalizes the fixture.

    Yields:
        The object bound by ``RemoteOpenAIServer(...)``'s ``with`` statement.
    """
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--enforce-eager",
        "--engine-use-ray",
        "--chat-template",
        str(chatml_jinja_path),
    ]

    # Allow `--engine-use-ray`; otherwise launching the server raises an
    # error because it tries to use a deprecated feature.
    env_dict = {"VLLM_ALLOW_ENGINE_USE_RAY": "1"}
    with RemoteOpenAIServer(MODEL_NAME, args,
                            env_dict=env_dict) as remote_server:
        yield remote_server
32
33


34
@pytest.fixture(scope="module")
def client(server):
    """Return the async client obtained from the module-scoped ``server``.

    The tests in this module annotate the result as ``openai.AsyncOpenAI``.
    """
    return server.get_async_client()
37
38
39


@pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI):
    """Check that the model listing reports ``MODEL_NAME``.

    The first listed model must have ``MODEL_NAME`` as its id, and every
    listed model must have ``MODEL_NAME`` as its root.
    """
    models = await client.models.list()
    models = models.data
    served_model = models[0]
    assert served_model.id == MODEL_NAME
    assert all(model.root == MODEL_NAME for model in models)


@pytest.mark.asyncio
async def test_single_completion(client: openai.AsyncOpenAI):
    """Request one completion from a text prompt, then one from token IDs.

    Both requests use ``max_tokens=5`` and ``temperature=0.0``. The text
    request is checked for a single choice, a ``"length"`` finish reason and
    the exact token usage; both requests must return at least 5 characters.
    """
    completion = await client.completions.create(model=MODEL_NAME,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert len(completion.choices) == 1
    assert len(completion.choices[0].text) >= 5
    # Generation is capped by max_tokens, so it must stop on "length".
    assert completion.choices[0].finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 5
70
71
72


@pytest.mark.asyncio
async def test_single_chat_session(client: openai.AsyncOpenAI):
    """Run a two-turn chat session against the server.

    The first turn requests 10 tokens with logprobs and checks the single
    choice, its ``"length"`` finish reason, the exact token usage and the
    assistant message. The reply is appended to the history, a follow-up
    user message is added, and a second completion is requested.
    """
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": "what is 1+1?"
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert chat_completion.id is not None
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=55, total_tokens=65)

    message = choice.message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    # Feed the assistant's reply back into the history for the next turn.
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    # NOTE(review): `len(...) >= 0` is always true, so this effectively only
    # checks that the content is not None — confirm whether that is intended.
    assert message.content is not None and len(message.content) >= 0