test_hub.py 3.01 KB
Newer Older
1
2
3
import os
import tempfile

4
5
import pytest

6
7
8
import huggingface_hub.constants

import text_generation_server.utils.hub
9
from text_generation_server.utils.hub import (
10
11
12
13
14
15
16
17
18
    weight_hub_files,
    download_weights,
    weight_files,
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)


19
20
21
22
23
24
25
26
27
28
29
30
31
32
@pytest.fixture()
def offline():
    """Force the hub utilities into offline mode while a test runs.

    Sets ``text_generation_server.utils.hub.HF_HUB_OFFLINE`` to ``True``,
    yields the string ``"offline"``, then puts the previous value back.
    """
    hub = text_generation_server.utils.hub
    saved_flag = hub.HF_HUB_OFFLINE
    hub.HF_HUB_OFFLINE = True
    yield "offline"
    # Teardown: undo the override so later tests see the original setting.
    hub.HF_HUB_OFFLINE = saved_flag


@pytest.fixture()
def fresh_cache():
    """Redirect the Hugging Face cache to an empty temporary directory.

    Three places are pointed at the temporary directory for the duration of
    the test: ``huggingface_hub.constants.HUGGINGFACE_HUB_CACHE``,
    ``text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE`` and the
    ``HUGGINGFACE_HUB_CACHE`` environment variable. Each one is restored to
    its own previous state afterwards; in particular the environment
    variable is removed again if it was not set before the test.
    """
    env_key = "HUGGINGFACE_HUB_CACHE"
    hub = text_generation_server.utils.hub

    with tempfile.TemporaryDirectory() as d:
        previous_constant = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
        # Fall back to the constants value if the hub module has no attribute
        # of its own yet.
        previous_hub_value = getattr(hub, "HUGGINGFACE_HUB_CACHE", previous_constant)
        # ``None`` means "was not set", which must be restored by deletion.
        previous_env = os.environ.get(env_key)

        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
        hub.HUGGINGFACE_HUB_CACHE = d
        os.environ[env_key] = d
        try:
            yield
        finally:
            huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = previous_constant
            hub.HUGGINGFACE_HUB_CACHE = previous_hub_value
            if previous_env is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous_env


@pytest.fixture()
def prefetched():
    """Download the safetensors weights of ``bert-base-uncased`` and yield its id.

    Calls ``huggingface_hub.snapshot_download`` for the ``main`` revision with
    ``local_files_only=False``, restricted to ``*.safetensors`` files, so that
    tests using this fixture can look the model up afterwards.

    Yields:
        The model id string ``"bert-base-uncased"``.
    """
    model_id = "bert-base-uncased"
    huggingface_hub.snapshot_download(
        repo_id=model_id,
        revision="main",
        local_files_only=False,
        repo_type="model",
        allow_patterns=["*.safetensors"],
    )
    yield model_id


def test_weight_hub_files_offline_error(offline, fresh_cache):
    """``weight_hub_files`` must raise when offline and nothing is prefetched."""
    # The fresh cache holds no model, so the offline lookup is expected to fail.
    missing_model = "gpt2"
    with pytest.raises(EntryNotFoundError):
        weight_hub_files(missing_model)


def test_weight_hub_files_offline_ok(prefetched, offline):
    """``weight_hub_files`` should resolve a prefetched model while offline.

    Expects exactly one path to be returned, named ``model.safetensors``,
    with all returned paths sharing a single parent directory.
    """
    # If the model is prefetched then we should be able to get the weight
    # files from local cache.
    filenames = weight_hub_files(prefetched)
    assert len(filenames) == 1

    # Every returned path must sit in the same directory.
    parent_dirs = {os.path.dirname(path) for path in filenames}
    assert len(parent_dirs) == 1

    assert [os.path.basename(path) for path in filenames] == ["model.safetensors"]
71
72


73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def test_weight_hub_files():
    """``weight_hub_files`` should return one file for ``bigscience/bloom-560m``."""
    expected = ["model.safetensors"]
    assert weight_hub_files("bigscience/bloom-560m") == expected


def test_weight_hub_files_llm():
    """``weight_hub_files`` should return all 72 shard names for ``bigscience/bloom``."""
    listed = weight_hub_files("bigscience/bloom")

    # Shards are numbered 1..72 and zero-padded to five digits.
    expected_shards = []
    for shard_index in range(1, 73):
        expected_shards.append(f"model_{shard_index:05d}-of-00072.safetensors")

    assert listed == expected_shards


def test_weight_hub_files_empty():
    """``weight_hub_files`` must raise when asked for the ``.errors`` extension."""
    unmatched_extension = ".errors"
    with pytest.raises(EntryNotFoundError):
        weight_hub_files("bigscience/bloom", extension=unmatched_extension)


def test_download_weights():
    """``download_weights`` should return the same files ``weight_files`` finds."""
    model_id = "bigscience/bloom-560m"
    remote_names = weight_hub_files(model_id)
    downloaded = download_weights(remote_names, model_id)
    # After downloading, the local lookup must match what was just returned.
    assert downloaded == weight_files(model_id)


96
def test_weight_files_revision_error():
    """``weight_files`` must raise ``RevisionNotFoundError`` for revision ``"error"``."""
    with pytest.raises(RevisionNotFoundError):
        weight_files("bigscience/bloom-560m", revision="error")
99
100
101


def test_weight_files_not_cached_error(fresh_cache):
    """``weight_files`` must raise ``LocalEntryNotFoundError`` on an empty cache."""
    # ``fresh_cache`` points the cache at an empty temporary directory, so
    # the model cannot be found locally.
    with pytest.raises(LocalEntryNotFoundError):
        weight_files("bert-base-uncased")