# test_grammar_llama.py
# Grammar-constrained (JSON schema) generation tests against a TinyLlama server
# launched without flash attention.
import pytest
import json

from text_generation.types import GrammarType


@pytest.fixture(scope="module")
def non_flash_llama_grammar_handle(launcher):
    """Launch a TinyLlama chat server for the grammar tests in this module.

    Grammar support is explicitly enabled (``disable_grammar_support=False``)
    and flash attention is disabled (``use_flash_attention=False``), so the
    tests here exercise the non-flash code path on a single shard.

    Because the fixture is module-scoped and yields from inside the
    ``launcher`` context manager, one server is shared by every test in the
    module. It is torn down when pytest finalizes the fixture.

    Args:
        launcher: Fixture that returns a context manager starting the server.

    Yields:
        The handle produced by ``launcher`` for the running server.
    """
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        disable_grammar_support=False,
        use_flash_attention=False,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
    """Wait for the grammar-enabled server to be healthy, then return its client.

    Args:
        non_flash_llama_grammar_handle: Handle yielded by the module-scoped
            launcher fixture.

    Returns:
        The handle's ``client`` attribute, which the tests use to send
        generation requests.
    """
    # NOTE(review): 300 is presumably a health-check timeout in seconds;
    # confirm against the handle's ``health`` signature.
    await non_flash_llama_grammar_handle.health(300)
    return non_flash_llama_grammar_handle.client


# NOTE(review): skipped unconditionally; no reason is recorded here.
@pytest.mark.skip
@pytest.mark.asyncio
async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
    """Constrain generation with a JSON-schema grammar and check the result.

    Sends a short free-text prompt with ``GrammarType.Json`` and a "Person"
    schema (four required fields). Generation is seeded for determinism. The
    test asserts the exact token count and output text, then compares the
    full response against the stored snapshot.
    """
    response = await non_flash_llama_grammar.generate(
        "info: david holtz like trees and has two cats. ",
        max_new_tokens=100,
        decoder_input_details=True,
        seed=0,
        grammar={
            "type": GrammarType.Json,
            # The schema is serialized to a JSON string for the request.
            # NOTE(review): the descriptions contain a literal ''' (three
            # apostrophes), probably meant as "person's". It is left unchanged
            # because it is part of the request payload.
            "value": json.dumps(
                {
                    "type": "object",
                    "$id": "https://example.com/person.schema.json",
                    "$schema": "https://json-schema.org/draft/2020-12/schema",
                    "title": "Person",
                    "properties": {
                        "firstName": {
                            "type": "string",
                            "description": "The person'''s first name.",
                        },
                        "lastName": {
                            "type": "string",
                            "description": "The person'''s last name.",
                        },
                        "hobby": {
                            "description": "The person'''s hobby.",
                            "type": "string",
                        },
                        "numCats": {
                            "description": "The number of cats the person has.",
                            "type": "integer",
                            "minimum": 0,
                        },
                    },
                    "required": ["firstName", "lastName", "hobby", "numCats"],
                }
            ),
        },
    )

    assert response.details.generated_tokens == 30
    # The expected keys are in alphabetical order, not the schema's
    # declaration order.
    assert (
        response.generated_text
        == '{"firstName":"David","hobby":"Trees","lastName":"Holtz","numCats":2}'
    )
    assert response == response_snapshot