import pytest
import requests
import json
from aiohttp import ClientSession

from text_generation.types import Completion, ChatCompletionChunk


@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
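    # Start a server for this model once per test module; tests share the handle.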
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
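    # Wait for the server to report healthy (timeout presumably in seconds) before
    # handing out the client.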
    await flash_llama_completion_handle.health(300)
    return flash_llama_completion_handle.client


# NOTE: since `v1/completions` is a deprecated interface/endpoint we do not provide a convenience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
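#
# For reference, a minimal raw call looks like the sketch below (it mirrors the
# tests in this file; `client` stands in for the fixture's client object):
#
#     requests.post(
#         f"{client.base_url}/v1/completions",
#         json={"model": "tgi", "prompt": "What is Deep Learning?", "max_tokens": 10},
#         headers=client.headers,
#     )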


@pytest.mark.release
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
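    """A single string prompt should produce exactly one greedy (temperature=0) choice."""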
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "What is Deep Learning?",
            "max_tokens": 10,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 1
    assert (
        response["choices"][0]["text"]
        == " A Beginner’s Guide\nDeep learning is a subset"
    )
    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
    flash_llama_completion, response_snapshot
):
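    """With `stream_options.include_usage`, exactly one streamed chunk should carry
    usage stats; a second request without it should stream no usage at all."""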
    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
    request = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        "max_tokens": 10,
        "temperature": 0.0,
        "stream_options": {"include_usage": True},
        "stream": True,
    }
    string = ""
    chunks = []
    had_usage = False
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # remove "data:"
                chunk = chunk.decode().split("\n\n")
                # remove "data:" if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # remove completion marking chunk
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(ChatCompletionChunk(**c))
                    assert "choices" in c
                    if len(c["choices"]) == 1:
                        index = c["choices"][0]["index"]
                        assert index == 0
                        string += c["choices"][0]["delta"]["content"]

                        has_usage = c["usage"] is not None
                        assert not had_usage
                        if has_usage:
                            had_usage = True
                    else:
                        raise RuntimeError("Expected different payload")
    assert had_usage
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )
    assert chunks == response_snapshot

    request = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        "max_tokens": 10,
        "temperature": 0.0,
        "stream": True,
    }
    string = ""
    chunks = []
    had_usage = False
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # remove "data:"
                chunk = chunk.decode().split("\n\n")
                # remove "data:" if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # remove completion marking chunk
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(ChatCompletionChunk(**c))
                    assert "choices" in c
                    if len(c["choices"]) == 1:
                        index = c["choices"][0]["index"]
                        assert index == 0
                        string += c["choices"][0]["delta"]["content"]

                        has_usage = c["usage"] is not None
                        assert not had_usage
                        if has_usage:
                            had_usage = True
                    else:
                        raise RuntimeError("Expected different payload")
    assert not had_usage
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )


@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
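    """A list of four prompts should come back as four choices, one per prompt index."""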
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": [
                "What is Deep Learning?",
                "Is water wet?",
                "What is the capital of France?",
                "def mai",
            ],
            "max_tokens": 10,
            "seed": 0,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 4

    all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
    all_indexes.sort()
    all_indices, all_strings = zip(*all_indexes)
    assert list(all_indices) == [0, 1, 2, 3]
    assert list(all_strings) == [
        " A Beginner’s Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n    """\n',
    ]

    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
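    """Streaming a four-prompt batch interleaves chunks across prompts; reassemble
    the text per choice index and compare against the non-streaming results."""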
    request = {
        "model": "tgi",
        "prompt": [
            "What is Deep Learning?",
            "Is water wet?",
            "What is the capital of France?",
            "def mai",
        ],
        "max_tokens": 10,
        "seed": 0,
        "temperature": 0.0,
        "stream": True,
    }

    url = f"{flash_llama_completion.base_url}/v1/completions"

    chunks = []
    strings = [""] * 4
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # remove "data:"
                chunk = chunk.decode().split("\n\n")
                # remove "data:" if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # remove completion marking chunk
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(Completion(**c))
                    assert "choices" in c
                    index = c["choices"][0]["index"]
                    # four prompts were sent, so indices must be in [0, 4)
                    assert 0 <= index < 4
                    strings[index] += c["choices"][0]["text"]

    assert response.status == 200
    assert list(strings) == [
        " A Beginner’s Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n    """\n',
    ]
    assert chunks == response_snapshot