"""End-to-end tests for the OpenAI-compatible server run with ``--engine-use-ray``.

The server is launched once per module and exercised through the official
``openai`` async client.
"""
import openai  # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray

from ..utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"


@pytest.fixture(scope="module")
def ray_ctx():
    """Start a Ray runtime for this module's tests and shut it down afterwards.

    Module-scoped, so every test in this file shares a single Ray session.
    """
    ray.init()
    yield
    # Teardown: runs after the last test in the module has finished.
    ray.shutdown()


@pytest.fixture(scope="module")
def server(ray_ctx):
    """Launch the OpenAI-compatible server once for this module.

    Takes ``ray_ctx`` so that Ray is initialised before the server is started
    with ``--engine-use-ray``.
    """
    # NOTE(review): nothing here stops the server explicitly; presumably it is
    # torn down together with Ray when ``ray_ctx`` calls ray.shutdown() --
    # confirm against RemoteOpenAIServer.
    return RemoteOpenAIServer([
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--max-model-len",
        "2048",
        "--enforce-eager",
        "--engine-use-ray",
    ])


@pytest.fixture(scope="module")
def client(server):
    """Return an async OpenAI client pointed at the module's ``server``."""
    return server.get_async_client()


@pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI):
    """The server lists MODEL_NAME first, and every listed model is rooted at it."""
    models = (await client.models.list()).data
    # Fail with a clear assertion rather than an IndexError on models[0].
    assert models, "server reported no models"
    served_model = models[0]
    assert served_model.id == MODEL_NAME
    assert all(model.root == MODEL_NAME for model in models)


@pytest.mark.asyncio
async def test_single_completion(client: openai.AsyncOpenAI):
    """Greedy text completion, first from a string prompt, then from token IDs."""
    completion = await client.completions.create(model=MODEL_NAME,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert len(completion.choices) == 1
    choice = completion.choices[0]
    assert choice.text is not None and len(choice.text) >= 5
    # The test expects the 5-token cap (not an end-of-sequence token) to be
    # what ends generation.
    assert choice.finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices) == 1
    assert len(completion.choices[0].text) >= 5


@pytest.mark.asyncio
async def test_single_chat_session(client: openai.AsyncOpenAI):
    """One chat turn with logprobs, then a follow-up turn on the same history."""
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role": "user",
        "content": "what is 1+1?"
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert chat_completion.id is not None
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=13, total_tokens=23)
    # logprobs were requested above, so the reply must carry them.
    assert choice.logprobs is not None
    assert choice.logprobs.content

    message = choice.message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    # Feed the assistant reply back so the next request is a true multi-turn
    # conversation.
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    # Only presence is asserted for the follow-up reply; its length is not
    # constrained.
    assert message.content is not None