test_hub.py 2.85 KB
Newer Older
1
2
3
4
import os
import requests
import tempfile

5
6
import pytest

7
8
9
10
import huggingface_hub.constants
from huggingface_hub import hf_api

import text_generation_server.utils.hub
11
from text_generation_server.utils.hub import (
12
13
14
15
16
17
18
19
20
    weight_hub_files,
    download_weights,
    weight_files,
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)


21
22
23
24
25
26
27
28
29
30
31
32
33
34
@pytest.fixture()
def offline():
    """Switch the hub utilities to offline mode while a test runs.

    Sets ``text_generation_server.utils.hub.HF_HUB_OFFLINE`` to ``True``,
    yields the string ``"offline"``, and afterwards writes back the value
    the flag held before the fixture started.
    """
    hub_module = text_generation_server.utils.hub
    previous_flag = hub_module.HF_HUB_OFFLINE
    hub_module.HF_HUB_OFFLINE = True
    yield "offline"
    # Teardown: put the flag back the way it was found.
    hub_module.HF_HUB_OFFLINE = previous_flag


@pytest.fixture()
def fresh_cache():
    """Redirect the Hugging Face cache settings to an empty temporary directory.

    While the test runs, three places are pointed at the temporary directory:
    ``huggingface_hub.constants.HUGGINGFACE_HUB_CACHE``,
    ``text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE`` and the
    ``HUGGINGFACE_HUB_CACHE`` environment variable.

    On teardown each location gets back the value *it* held beforehand; the
    environment variable is removed again if it was not set before the test.
    """
    hub_module = text_generation_server.utils.hub
    env_key = "HUGGINGFACE_HUB_CACHE"

    with tempfile.TemporaryDirectory() as d:
        # Snapshot every location separately so one cannot leak into another.
        previous_constant = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
        # Fall back to the library constant if the hub module does not expose
        # the attribute yet (keeps the restore step well-defined either way).
        previous_hub_value = getattr(
            hub_module, "HUGGINGFACE_HUB_CACHE", previous_constant
        )
        previous_env = os.environ.get(env_key)  # None means "was not set"

        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
        hub_module.HUGGINGFACE_HUB_CACHE = d
        os.environ[env_key] = d
        try:
            yield
        finally:
            huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = previous_constant
            hub_module.HUGGINGFACE_HUB_CACHE = previous_hub_value
            if previous_env is None:
                # The variable did not exist before: do not leave it behind.
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous_env


@pytest.fixture()
def prefetched():
    """Run ``snapshot_download`` for ``bert-base-uncased`` and yield its id.

    Only files matching ``*.safetensors`` on the ``main`` revision of the
    model repository are requested.
    """
    model_id = "bert-base-uncased"
    download_options = {
        "repo_id": model_id,
        "revision": "main",
        "local_files_only": False,
        "repo_type": "model",
        "allow_patterns": ["*.safetensors"],
    }
    huggingface_hub.snapshot_download(**download_options)
    yield model_id


def test_weight_hub_files_offline_error(offline, fresh_cache):
    """Listing hub weight files offline with an empty cache must raise."""
    # Nothing was prefetched into the fresh cache, so the lookup has to fail.
    expected_failure = pytest.raises(EntryNotFoundError)
    with expected_failure:
        weight_hub_files("gpt2")


def test_weight_hub_files_offline_ok(prefetched, offline):
    """A prefetched model's weight files are listed while offline."""
    # The model was fetched by the ``prefetched`` fixture before going
    # offline, so the file list is expected to come from the local cache.
    expected = ["model.safetensors"]
    assert weight_hub_files(prefetched) == expected
65
66


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def test_weight_hub_files():
    """``bigscience/bloom-560m`` is expected to list one safetensors file."""
    expected = ["model.safetensors"]
    assert weight_hub_files("bigscience/bloom-560m") == expected


def test_weight_hub_files_llm():
    """``bigscience/bloom`` is expected to list 72 numbered shard files."""
    expected = []
    for shard in range(1, 73):
        expected.append(f"model_{shard:05d}-of-00072.safetensors")

    listed = weight_hub_files("bigscience/bloom")
    assert listed == expected


def test_weight_hub_files_empty():
    """Asking for the ``.errors`` extension must raise ``EntryNotFoundError``."""
    expected_failure = pytest.raises(EntryNotFoundError)
    with expected_failure:
        weight_hub_files("bigscience/bloom", extension=".errors")


def test_download_weights():
    """Paths returned by ``download_weights`` equal those from ``weight_files``."""
    model_id = "bigscience/bloom-560m"
    # List the files for the model, download them, then compare against the
    # paths reported by ``weight_files`` for the same model.
    downloaded = download_weights(weight_hub_files(model_id), model_id)
    assert downloaded == weight_files(model_id)


90
def test_weight_files_revision_error():
    """Passing ``revision="error"`` must raise ``RevisionNotFoundError``."""
    expected_failure = pytest.raises(RevisionNotFoundError)
    with expected_failure:
        weight_files("bigscience/bloom-560m", revision="error")
93
94
95


def test_weight_files_not_cached_error(fresh_cache):
    """With a fresh cache, ``weight_files`` must raise ``LocalEntryNotFoundError``."""
    expected_failure = pytest.raises(LocalEntryNotFoundError)
    with expected_failure:
        weight_files("bert-base-uncased")