# test_model.py — streaming-decode tests for text_generation_server.models.Model
import pytest
import torch

from transformers import AutoTokenizer

from text_generation_server.models import Model


def get_test_model():
    """Build a minimal concrete ``Model`` for exercising its decoding helpers.

    ``TestModel`` supplies the two members this module does not use
    (``batch_type`` and ``generate_token`` both raise ``NotImplementedError``),
    so the returned instance is only suitable for tokenizer/decoding tests.

    Returns:
        A ``TestModel`` built around the ``huggingface/llama-7b`` tokenizer,
        a 1x1 ``torch.nn.Linear`` placeholder, ``torch.float32`` and the CPU
        device.
    """

    class TestModel(Model):
        def batch_type(self):
            raise NotImplementedError

        def generate_token(self, batch):
            raise NotImplementedError

    # NOTE(review): loading this tokenizer needs access to the Hub repo, which is
    # presumably why the tests that call this helper are marked ``private``.
    tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")

    # Arguments are positional; the parameter names below are assumptions about
    # Model.__init__ (model_id, model, tokenizer, requires_padding, dtype,
    # device) — TODO confirm against the base class.
    model = TestModel(
        "test_model_id",
        torch.nn.Linear(1, 1),  # placeholder module; tests here only decode tokens
        tokenizer,
        False,  # presumably requires_padding
        torch.float32,
        torch.device("cpu"),
    )
    return model


@pytest.mark.private
def test_decode_streaming_english_spaces():
    """Incremental decoding of plain English must reproduce the text, spaces included.

    The ids are handed to ``decode_token`` as ever-growing prefixes, mimicking
    token-by-token streaming. The fragments it returns, concatenated, must equal
    the original sentence.
    """
    model = get_test_model()
    truth = "Hello here, this is a simple test"
    all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]

    # Guard the fixture: the hard-coded ids must be what the tokenizer emits.
    encoding = model.tokenizer(truth, add_special_tokens=False)
    assert all_input_ids == encoding["input_ids"]

    cursor, token_cursor = 0, 0
    fragments = []
    for stop in range(1, len(all_input_ids) + 1):
        fragment, cursor, token_cursor = model.decode_token(
            all_input_ids[:stop], cursor, token_cursor
        )
        fragments.append(fragment)

    assert "".join(fragments) == truth


@pytest.mark.private
def test_decode_streaming_chinese_utf8():
    """Incremental decoding must reassemble multi-byte UTF-8 Chinese text exactly.

    The ids are streamed to ``decode_token`` one growing prefix at a time. The
    joined fragments must equal ``truth`` with no broken or duplicated characters.
    """
    model = get_test_model()
    truth = "我很感谢你的热情"

    # Ids grouped by the character of ``truth`` they appear to produce. The
    # three-id groups look like byte-level pieces, one id per UTF-8 byte.
    # TODO confirm against the tokenizer vocabulary.
    per_char_ids = [
        [30672],  # 我
        [232, 193, 139],  # 很
        [233, 135, 162],  # 感
        [235, 179, 165],  # 谢
        [30919],  # 你
        [30210],  # 的
        [234, 134, 176],  # 热
        [30993],  # 情
    ]
    all_input_ids = [token_id for group in per_char_ids for token_id in group]

    cursor, token_cursor = 0, 0
    fragments = []
    for stop in range(1, len(all_input_ids) + 1):
        fragment, cursor, token_cursor = model.decode_token(
            all_input_ids[:stop], cursor, token_cursor
        )
        fragments.append(fragment)

    assert "".join(fragments) == truth