import pytest
import requests
import json
from aiohttp import ClientSession

from text_generation.types import (
    Completion,
)


@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
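    # wait (up to 300 seconds) for the launched server to report healthy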
    await flash_llama_completion_handle.health(300)
    return flash_llama_completion_handle.client


# NOTE: since `v1/completions` is a deprecated interface/endpoint we do not provide a convenience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
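# For orientation, the non-streaming responses asserted against below look
# roughly like this sketch; only "choices", each choice's "index" and "text",
# and the snapshot are actually checked, so other fields are assumptions:
#
#     {
#         "choices": [
#             {"index": 0, "text": "..."},
#             ...
#         ]
#     }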


@pytest.mark.release
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "What is Deep Learning?",
            "max_tokens": 10,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 1
    assert (
        response["choices"][0]["text"]
        == " A Beginner’s Guide\nDeep learning is a subset"
    )
    assert response == response_snapshot


@pytest.mark.release
def test_flash_llama_completion_many_prompts(
    flash_llama_completion, response_snapshot
):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": [
                "What is Deep Learning?",
                "Is water wet?",
                "What is the capital of France?",
                "def mai",
            ],
            "max_tokens": 10,
            "seed": 0,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 4

    all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
    all_indexes.sort()
    all_indices, all_strings = zip(*all_indexes)
    assert list(all_indices) == [0, 1, 2, 3]
    assert list(all_strings) == [
        " A Beginner’s Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n    """\n',
    ]

    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
    request = {
        "model": "tgi",
        "prompt": [
            "What is Deep Learning?",
            "Is water wet?",
            "What is the capital of France?",
            "def mai",
        ],
        "max_tokens": 10,
        "seed": 0,
        "temperature": 0.0,
        "stream": True,
    }

    url = f"{flash_llama_completion.base_url}/v1/completions"

    chunks = []
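    # one accumulator per prompt; streamed text is appended by choice index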
    strings = [""] * 4
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # remove "data:"
                chunk = chunk.decode().split("\n\n")
                # remove "data:" if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # drop the "[DONE]" marker that terminates the stream
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(Completion(**c))
                    assert "choices" in c
                    index = c["choices"][0]["index"]
                    assert 0 <= index < 4  # four prompts were sent, so indices are 0..3
                    strings[index] += c["choices"][0]["text"]

    assert response.status == 200
    assert list(strings) == [
        " A Beginner’s Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n    """\n',
    ]
    assert chunks == response_snapshot