# test_grammar_llama.py
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")
def flash_llama_grammar_handle(launcher):
    """Yield a module-scoped handle for a launched TinyLlama chat model.

    The model is started through the ``launcher`` fixture (defined outside
    this file) with two shards and ``disable_grammar_support=False``. The
    launcher's context manager stays open until the module's tests are done.
    """
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    launch_options = {"num_shard": 2, "disable_grammar_support": False}

    with launcher(model_id, **launch_options) as server_handle:
        yield server_handle


@pytest.fixture(scope="module")
async def flash_llama_grammar(flash_llama_grammar_handle):
    """Return the handle's client after awaiting its health check.

    ``health(300)`` is awaited first; the argument is presumably a timeout
    in seconds -- confirm against the launcher handle implementation.
    """
    server_handle = flash_llama_grammar_handle
    await server_handle.health(300)
    return server_handle.client


@pytest.mark.asyncio
async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
    """Generate 10 tokens without a grammar and compare to the snapshot."""
    request_options = {"max_new_tokens": 10, "decoder_input_details": True}

    response = await flash_llama_grammar.generate("Test request", **request_options)

    # The token budget must be fully used before comparing the whole response.
    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
    """Constrain generation with a dotted-quad regex grammar and check the output."""
    # Four dot-separated octets, each limited to 0-255.
    ipv4_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
    }

    response = await flash_llama_grammar.generate(
        "Whats Googles DNS",
        max_new_tokens=10,
        decoder_input_details=True,
        seed=0,
        grammar=ipv4_grammar,
    )

    assert response.details.generated_tokens == 10
    assert response.generated_text == "42.1.1.101"
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
    """Constrain generation with a JSON-schema grammar and check the output.

    The schema describes a "Person" object with four required properties
    (firstName, lastName, hobby, numCats). The test asserts the generated
    token count, the exact generated JSON text, and the full response
    snapshot.
    """
    # Key order is kept as-is because json.dumps serialises in insertion
    # order and the serialised schema is what gets sent with the request.
    person_schema = {
        "type": "object",
        "$id": "https://example.com/person.schema.json",
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "title": "Person",
        "properties": {
            "firstName": {
                "type": "string",
                "description": "The person'''s first name.",
            },
            "lastName": {
                "type": "string",
                "description": "The person'''s last name.",
            },
            "hobby": {
                "description": "The person'''s hobby.",
                "type": "string",
            },
            "numCats": {
                "description": "The number of cats the person has.",
                "type": "integer",
                "minimum": 0,
            },
        },
        "required": ["firstName", "lastName", "hobby", "numCats"],
    }

    response = await flash_llama_grammar.generate(
        "info: david holtz like trees and has two cats. ",
        max_new_tokens=100,
        decoder_input_details=True,
        seed=0,
        grammar={
            "type": GrammarType.Json,  # "json"
            "value": json.dumps(person_schema),
        },
    )

    assert response.details.generated_tokens == 30
    assert (
        response.generated_text
        == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
    )
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_flash_llama_grammar_load(
    flash_llama_grammar, generate_load, response_snapshot
):
    """Send four identical regex-constrained requests via ``generate_load``.

    Every response must produce the same expected e-mail address, and the
    list of responses must match the stored snapshot.
    """
    email_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+",  # email regex
    }

    responses = await generate_load(
        flash_llama_grammar,
        "name: david. email:  ",
        max_new_tokens=10,
        n=4,
        stop_sequences=[".com"],
        seed=0,
        grammar=email_grammar,
    )

    assert len(responses) == 4

    expected = "123456@gmail.com"

    for single_response in responses:
        assert single_response.generated_text == expected

    # All responses must also agree with the first one.
    first_text = responses[0].generated_text
    assert all(item.generated_text == first_text for item in responses)

    assert responses == response_snapshot


# this is the same as the above test, but only fires off a single request
# this is only to ensure that the parallel and single inference produce the same result
@pytest.mark.asyncio
async def test_flash_llama_grammar_single_load_instance(
    flash_llama_grammar, generate_load, response_snapshot
):
    """Send one regex-constrained e-mail request and check text and snapshot.

    ``generate_load`` is requested as a fixture but not called here; the
    request goes directly through the client's ``generate`` method.
    """
    email_grammar = {
        "type": GrammarType.Regex,  # "regex"
        "value": "[\\w-]+@([\\w-]+\\.)+[\\w-]+",  # email regex
    }

    response = await flash_llama_grammar.generate(
        "name: david. email:  ",
        max_new_tokens=10,
        stop_sequences=[".com"],
        seed=0,
        grammar=email_grammar,
    )

    # The generated-token count is not asserted in this test.
    assert response.generated_text == "123456@gmail.com"

    assert response == response_snapshot