# test_grammar_llama.py
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")
def non_flash_llama_grammar_handle(launcher):
    """Launch a TinyLlama server with grammar support on and flash attention off.

    Module-scoped: a single ``launcher(...)`` context is shared by every test in
    this module. It is exited after the ``yield`` resumes, when the module's
    tests are done.

    Args:
        launcher: Fixture defined elsewhere (presumably in conftest — confirm)
            that is called and used as a context manager yielding a handle.

    Yields:
        The handle produced by entering ``launcher(...)``.
    """
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        # Grammar-constrained generation must be available for grammar requests.
        disable_grammar_support=False,
        # Exercise the non-flash attention path (hence "non_flash" in the name).
        use_flash_attention=False,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
    """Wait for the launched server's health check, then return its client.

    Args:
        non_flash_llama_grammar_handle: Handle yielded by the launcher-based
            fixture for the non-flash TinyLlama server.

    Returns:
        The handle's ``client`` attribute, which tests use to send requests.
    """
    # NOTE(review): 300 is presumably a timeout in seconds for the health
    # check — confirm against the handle's ``health`` implementation.
    await non_flash_llama_grammar_handle.health(300)
    return non_flash_llama_grammar_handle.client


@pytest.mark.release
# NOTE(review): unconditional skip with no recorded reason — add `reason=...`
# once the motivation is known so the skip is self-explanatory in reports.
@pytest.mark.skip
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
    """Generate under a JSON-schema grammar and check the exact constrained output.

    Sends a short prompt with a fixed seed and a ``GrammarType.Json`` grammar
    describing a "Person" object with four required properties. Then asserts
    three things: the number of generated tokens, the exact generated text, and
    that the full response matches the recorded snapshot.

    Args:
        non_flash_llama_grammar: Client returned by the module-scoped fixture.
        response_snapshot: Snapshot fixture (defined elsewhere) compared with
            the full response object.
    """
    response = await non_flash_llama_grammar.generate(
        "info: david holtz like trees and has two cats. ",
        max_new_tokens=100,
        decoder_input_details=True,
        seed=0,  # fixed seed so the exact-text assertion below is reproducible
        grammar={
            "type": GrammarType.Json,
            # NOTE(review): the "'''" in the descriptions below looks like an
            # escaping artifact. It is left unchanged because the recorded
            # snapshot may depend on the exact schema text — confirm before editing.
            "value": json.dumps(
                {
                    "type": "object",
                    "$id": "https://example.com/person.schema.json",
                    "$schema": "https://json-schema.org/draft/2020-12/schema",
                    "title": "Person",
                    "properties": {
                        "firstName": {
                            "type": "string",
                            "description": "The person'''s first name.",
                        },
                        "lastName": {
                            "type": "string",
                            "description": "The person'''s last name.",
                        },
                        "hobby": {
                            "description": "The person'''s hobby.",
                            "type": "string",
                        },
                        "numCats": {
                            "description": "The number of cats the person has.",
                            "type": "integer",
                            "minimum": 0,
                        },
                    },
                    "required": ["firstName", "lastName", "hobby", "numCats"],
                }
            ),
        },
    )

    assert response.details.generated_tokens == 30
    assert (
        response.generated_text
        == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
    )
    assert response == response_snapshot