# Tests for text_generation.utils: stop-sequence / stopping criteria and weight-file helpers.
import pytest

from text_generation.utils import (
    weight_hub_files,
    download_weights,
    weight_files,
    StopSequenceCriteria,
    StoppingCriteria,
    LocalEntryNotFoundError,
)


def test_stop_sequence_criteria():
    """Feeding [1, 2, 3] in order advances the match index and fires on the last token."""
    stop_criteria = StopSequenceCriteria([1, 2, 3])

    # Each step: (token fed, whether the criteria should fire, index expected afterwards).
    steps = [(1, False, 1), (2, False, 2), (3, True, 3)]
    for token, should_stop, expected_idx in steps:
        assert bool(stop_criteria(token)) is should_stop
        assert stop_criteria.current_token_idx == expected_idx


def test_stop_sequence_criteria_reset():
    """A token that breaks a partial match sends current_token_idx back to 0."""
    stop_criteria = StopSequenceCriteria([1, 2, 3])

    for token, expected_idx in ((1, 1), (2, 2)):
        assert not stop_criteria(token)
        assert stop_criteria.current_token_idx == expected_idx

    # 4 is not the next expected token (3), so the partial progress is discarded.
    assert not stop_criteria(4)
    assert stop_criteria.current_token_idx == 0


def test_stop_sequence_criteria_empty():
    """Constructing a StopSequenceCriteria from an empty token list raises ValueError."""
    no_tokens = []
    with pytest.raises(ValueError):
        StopSequenceCriteria(no_tokens)


def test_stopping_criteria():
    """StoppingCriteria reports (True, "stop_sequence") once the ids end with [1, 2, 3]."""
    stopping = StoppingCriteria(
        [StopSequenceCriteria([1, 2, 3])],
        max_new_tokens=5,
    )

    generated = [1, 2, 3]
    # Every strict prefix of the stop sequence leaves generation running.
    for end in range(1, len(generated)):
        assert stopping(generated[:end]) == (False, None)
    assert stopping(generated) == (True, "stop_sequence")


def test_stopping_criteria_max():
    """With max_new_tokens=5, the fifth call reports (True, "length")."""
    stopping = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)

    # Only 1s are fed, so the [1, 2, 3] stop sequence can never complete;
    # the first four calls must keep going.
    for count in range(1, 5):
        assert stopping([1] * count) == (False, None)
    assert stopping([1] * 5) == (True, "length")


def test_weight_hub_files():
    """weight_hub_files lists exactly ["model.safetensors"] for bigscience/bloom-560m."""
    # NOTE(review): presumably requires network access to the model hub — confirm.
    expected = ["model.safetensors"]
    assert weight_hub_files("bigscience/bloom-560m") == expected


def test_weight_hub_files_llm():
    """For bigscience/bloom, weight_hub_files lists all 72 safetensors shards in order."""
    shard_count = 72
    expected = [
        f"model_{shard:05d}-of-{shard_count:05d}.safetensors"
        for shard in range(1, shard_count + 1)
    ]
    assert weight_hub_files("bigscience/bloom") == expected


def test_weight_hub_files_empty():
    """Passing ".errors" as the second argument yields an empty list for bigscience/bloom."""
    # NOTE(review): the second positional argument looks like a file-extension filter — confirm.
    assert weight_hub_files("bigscience/bloom", ".errors") == []


def test_download_weights():
    """download_weights returns the same files that weight_files resolves afterwards."""
    model_id = "bigscience/bloom-560m"
    # Download first, so that the subsequent weight_files lookup can find the files.
    downloaded = download_weights(model_id)
    assert downloaded == weight_files(model_id)


def test_weight_files_error():
    """weight_files raises LocalEntryNotFoundError for bert-base-uncased."""
    # NOTE(review): presumably because no matching weight files exist locally for this model — confirm.
    model_id = "bert-base-uncased"
    with pytest.raises(LocalEntryNotFoundError):
        weight_files(model_id)