# test_utils.py — tests for text_generation.utils (stopping criteria and weight-file helpers).
import pytest

from huggingface_hub.utils import RevisionNotFoundError

from text_generation.utils import (
    weight_hub_files,
    download_weights,
    weight_files,
    StopSequenceCriteria,
    StoppingCriteria,
    LocalEntryNotFoundError,
    FinishReason,
)


def test_stop_sequence_criteria():
    """StopSequenceCriteria matches the exact stop sequence only.

    Prefixes of the sequence ("/", "/test") and text with extra characters
    after it ("/test; ") must not match; the exact text "/test;" must.
    """
    criteria = StopSequenceCriteria("/test;")

    assert not criteria("/")
    assert not criteria("/test")
    assert criteria("/test;")
    assert not criteria("/test; ")


def test_stopping_criteria():
    """StoppingCriteria reports FINISH_REASON_STOP_SEQUENCE when the stop sequence completes.

    The second call passes only ";", so the "/test;" match presumably relies on
    the criteria combining it with the "/test" text from the first call.
    """
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(65827, "/test") == (False, None)
    assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)


def test_stopping_criteria_eos():
    """StoppingCriteria stops with FINISH_REASON_EOS_TOKEN on the configured EOS token.

    The criteria is built with 0 as its first argument. Token 1 does not stop
    generation; token 0 stops it with the EOS finish reason.
    """
    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
    assert criteria(1, "") == (False, None)
    assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)


def test_stopping_criteria_max():
    """StoppingCriteria stops with FINISH_REASON_LENGTH after max_new_tokens calls.

    With max_new_tokens=5, the first four calls do not stop generation and the
    fifth stops it with the length finish reason.
    """
    max_new_tokens = 5
    criteria = StoppingCriteria(
        0, [StopSequenceCriteria("/test;")], max_new_tokens=max_new_tokens
    )
    # Every call feeds token 1 with empty text. Token 1 is not the id 0 the
    # criteria was built with, and empty text cannot complete "/test;", so only
    # the length limit should be able to fire.
    for _ in range(max_new_tokens - 1):
        assert criteria(1, "") == (False, None)
    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)


def test_weight_hub_files():
    """weight_hub_files lists exactly one file, model.safetensors, for bigscience/bloom-560m."""
    filenames = weight_hub_files("bigscience/bloom-560m")
    assert filenames == ["model.safetensors"]


def test_weight_hub_files_llm():
    """weight_hub_files lists all 72 safetensors shards of bigscience/bloom, in ascending order."""
    found = weight_hub_files("bigscience/bloom")
    expected = []
    for shard_index in range(1, 73):
        # Shard numbers are zero-padded to five digits, e.g. model_00001-of-00072.
        expected.append(f"model_{shard_index:05d}-of-00072.safetensors")
    assert found == expected


def test_weight_hub_files_empty():
    """An extension that matches no files makes weight_hub_files return an empty list."""
    filenames = weight_hub_files("bigscience/bloom", extension=".errors")
    assert filenames == []


def test_download_weights():
    """download_weights and weight_files return the same result for bigscience/bloom-560m."""
    model_id = "bigscience/bloom-560m"
    downloaded = download_weights(model_id)
    # Resolve only after downloading; presumably weight_files inspects the
    # local cache — confirm in text_generation.utils.
    resolved = weight_files(model_id)
    assert downloaded == resolved


def test_weight_files_error():
    """weight_files raises the expected errors for lookups it cannot satisfy.

    - A non-existent revision ("error") raises RevisionNotFoundError.
    - "bert-base-uncased" raises LocalEntryNotFoundError. This is presumably
      because its weights were never downloaded locally; NOTE(review): confirm
      this does not depend on the developer's local cache contents.
    """
    with pytest.raises(RevisionNotFoundError):
        weight_files("bigscience/bloom-560m", revision="error")

    with pytest.raises(LocalEntryNotFoundError):
        weight_files("bert-base-uncased")