repo_utils.py 8.79 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for model repo interaction."""

import fnmatch
import json
import os
import time
from collections.abc import Callable
from functools import cache
from pathlib import Path
from typing import TypeVar

import huggingface_hub
15
from huggingface_hub import hf_hub_download, try_to_load_from_cache
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from huggingface_hub import list_repo_files as hf_list_repo_files
from huggingface_hub.utils import (
    EntryNotFoundError,
    HfHubHTTPError,
    LocalEntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)

from vllm import envs
from vllm.logger import init_logger

logger = init_logger(__name__)


_R = TypeVar("_R")


def with_retry(
    func: Callable[[], _R],
    log_msg: str,
    max_retries: int = 2,
    retry_delay: int = 2,
) -> _R:
    """Call ``func`` and return its result, retrying on any exception.

    At most ``max_retries`` attempts are made. After each failed attempt
    except the last, the failure is logged at error level and the caller
    sleeps ``retry_delay`` seconds; the delay doubles after every failure.
    When the final attempt fails, the error is logged and the exception is
    re-raised unchanged.

    Raises:
        AssertionError: if ``max_retries`` is less than 1, since no attempt
            is ever made.
    """
    delay = retry_delay
    for attempt_no in range(1, max_retries + 1):
        try:
            return func()
        except Exception as exc:
            out_of_attempts = attempt_no == max_retries
            if out_of_attempts:
                logger.error("%s: %s", log_msg, exc)
                raise
            logger.error(
                "%s: %s, retrying %d of %d", log_msg, exc, attempt_no, max_retries
            )
            time.sleep(delay)
            delay *= 2

    raise AssertionError("Should not be reached")


# @cache doesn't cache exceptions
@cache
def list_repo_files(
    repo_id: str,
    *,
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    """List every file in a model repository.

    If ``repo_id`` is an existing local path, the regular files below it are
    returned as paths relative to it. Otherwise the listing is fetched
    remotely:

    - from ModelScope when ``envs.VLLM_USE_MODELSCOPE`` is set, authenticating
      with the ``MODELSCOPE_API_TOKEN`` environment variable (``token`` and
      ``repo_type`` are not forwarded on this path);
    - from the Hugging Face Hub otherwise.

    In Hugging Face offline mode an empty list is returned instead of
    raising. The whole lookup runs under :func:`with_retry`, and results are
    memoized per argument combination. Because of the cache, callers share
    the returned list object and should not mutate it.
    """

    def _lookup() -> list[str]:
        root = Path(repo_id)
        if root.exists():
            # Local model directory: enumerate regular files on disk.
            return [
                str(entry.relative_to(root))
                for entry in root.rglob("*")
                if entry.is_file()
            ]

        # Remote model: ask the configured hub for its file listing.
        try:
            if not envs.VLLM_USE_MODELSCOPE:
                return hf_list_repo_files(
                    repo_id, revision=revision, repo_type=repo_type, token=token
                )

            from vllm.transformers_utils.utils import modelscope_list_repo_files

            return modelscope_list_repo_files(
                repo_id,
                revision=revision,
                token=os.getenv("MODELSCOPE_API_TOKEN", None),
            )
        except huggingface_hub.errors.OfflineModeIsEnabled:
            # Offline: we only know the file list isn't cached, so report
            # nothing rather than failing.
            return []

    return with_retry(_lookup, "Error retrieving file list")


def list_filtered_repo_files(
    model_name_or_path: str,
    allow_patterns: list[str],
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    """Return the repo files whose basename matches any of ``allow_patterns``.

    Patterns use ``fnmatch`` syntax and are matched against the basename
    only, so ``"*.json"`` also matches ``"sub/dir/config.json"``. Results are
    grouped by pattern, in the order the patterns are given. A file matching
    several patterns is listed once, under the first pattern it matches.

    If the file list cannot be retrieved, an error is logged and an empty
    list is returned instead of raising.
    """
    try:
        all_files = list_repo_files(
            repo_id=model_name_or_path,
            revision=revision,
            token=token,
            repo_type=repo_type,
        )
    except Exception:
        # The underlying exception has already been logged by the retry
        # wrapper inside `list_repo_files`; this is a best-effort fallback.
        logger.error(
            "Error retrieving file list. Please ensure your `model_name_or_path`, "
            "`repo_type`, `token` and `revision` arguments are correctly set. "
            "Returning an empty list."
        )
        return []

    file_list: list[str] = []
    seen: set[str] = set()
    for pattern in allow_patterns:
        for file in all_files:
            # Skip files already matched by an earlier pattern to avoid
            # returning (and later downloading) the same file twice.
            if file not in seen and fnmatch.fnmatch(os.path.basename(file), pattern):
                seen.add(file)
                file_list.append(file)
    return file_list


130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def any_pattern_in_repo_files(
    model_name_or_path: str,
    allow_patterns: list[str],
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> bool:
    """Return whether at least one repo file matches one of ``allow_patterns``.

    Matching follows :func:`list_filtered_repo_files`: ``fnmatch`` patterns
    are applied to each file's basename. If the file list cannot be
    retrieved, that function logs an error and yields no files, so this
    returns ``False``.
    """
    matches = list_filtered_repo_files(
        model_name_or_path=model_name_or_path,
        allow_patterns=allow_patterns,
        revision=revision,
        repo_type=repo_type,
        token=token,
    )
    return bool(matches)


def is_mistral_model_repo(
    model_name_or_path: str,
    revision: str | None = None,
    repo_type: str | None = None,
    token: str | bool | None = None,
) -> bool:
    """Return whether the repo contains ``consolidated*.safetensors`` files.

    The presence of such weight files is the signal this module uses to
    treat a repo as Mistral-style. File-list retrieval errors yield
    ``False``.
    """
    consolidated_weights = ["consolidated*.safetensors"]
    return any_pattern_in_repo_files(
        model_name_or_path,
        consolidated_weights,
        revision=revision,
        repo_type=repo_type,
        token=token,
    )


166
167
168
169
170
171
172
173
def file_exists(
    repo_id: str,
    file_name: str,
    *,
    repo_type: str | None = None,
    revision: str | None = None,
    token: str | bool | None = None,
) -> bool:
    """Return whether ``file_name`` is present in the repo ``repo_id``.

    ``file_name`` must be the path relative to the repo root, exactly as it
    appears in the listing from :func:`list_repo_files`. No pattern matching
    is performed.
    """
    # `list_repo_files` is cached and retried on error, so this is more efficient than
    # huggingface_hub.file_exists default implementation when looking for multiple files.
    # `functools.cache` keys on the exact positional/keyword pattern (including
    # keyword order), so call it the same way `list_filtered_repo_files` does
    # to share cache entries with it.
    file_list = list_repo_files(
        repo_id=repo_id,
        revision=revision,
        token=token,
        repo_type=repo_type,
    )
    return file_name in file_list


# In offline mode the result can be a false negative
def file_or_path_exists(
    model: str | Path, config_name: str, revision: str | None
) -> bool:
    if (local_path := Path(model)).exists():
        return (local_path / config_name).is_file()

    # Offline mode support: Check if config file is cached already
    cached_filepath = try_to_load_from_cache(
        repo_id=model, filename=config_name, revision=revision
    )
    if isinstance(cached_filepath, str):
194
        # The config file exists in cache - we can continue trying to load
195
196
197
198
199
200
        return True

    # NB: file_exists will only check for the existence of the config file on
    # hf_hub. This will fail in offline mode.

    # Call HF to check if the file exists
201
    return file_exists(str(model), config_name, revision=revision)
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229


def get_model_path(model: str | Path, revision: str | None = None):
    if os.path.exists(model):
        return model
    assert huggingface_hub.constants.HF_HUB_OFFLINE
    common_kwargs = {
        "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE,
        "revision": revision,
    }

    if envs.VLLM_USE_MODELSCOPE:
        from modelscope.hub.snapshot_download import snapshot_download

        return snapshot_download(model_id=model, **common_kwargs)

    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=model, **common_kwargs)


def get_hf_file_bytes(
    file_name: str, model: str | Path, revision: str | None = "main"
) -> bytes | None:
    """Get file contents from HuggingFace repository as bytes.

    The file is first looked up without downloading, via
    :func:`try_get_local_file` (``model`` as a local directory, then the
    Hugging Face cache). If it is not found there, it is fetched with
    ``hf_hub_download``. Download errors are not caught and propagate to
    the caller.

    Returns:
        The raw file contents, or ``None`` if the resolved path is not a
        regular file.
    """
    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        file_path = Path(hf_hub_download(model, file_name, revision=revision))

    # At this point `file_path` is always set; still guard against it naming
    # something other than a regular file.
    if file_path.is_file():
        return file_path.read_bytes()

    return None


def try_get_local_file(
    model: str | Path, file_name: str, revision: str | None = "main"
) -> Path | None:
    file_path = Path(model) / file_name
    if file_path.is_file():
        return file_path
    else:
        try:
            cached_filepath = try_to_load_from_cache(
                repo_id=model, filename=file_name, revision=revision
            )
            if isinstance(cached_filepath, str):
                return Path(cached_filepath)
        except ValueError:
            ...
    return None


def get_hf_file_to_dict(
    file_name: str, model: str | Path, revision: str | None = "main"
):
    """
    Load a JSON file for a model and return its parsed contents.

    The file is looked up locally first (``model`` as a directory, then the
    Hugging Face cache) and is downloaded with ``hf_hub_download`` only when
    it is not found there.

    Parameters:
    - file_name (str): The name of the file to load.
    - model (str | Path): A local model directory or the name of the model
      on the Hugging Face Hub.
    - revision (str | None): The specific version of the model.

    Returns:
    - The result of ``json.load`` on the file (typically a dict), or ``None``
      in any of these cases: offline mode is enabled, the repo, revision or
      file cannot be found, the Hub cannot be reached, or the resolved path
      is not a regular file.
    """

    file_path = try_get_local_file(model=model, file_name=file_name, revision=revision)

    if file_path is None:
        try:
            hf_hub_file = hf_hub_download(model, file_name, revision=revision)
        except huggingface_hub.errors.OfflineModeIsEnabled:
            return None
        except (
            RepositoryNotFoundError,
            RevisionNotFoundError,
            EntryNotFoundError,
            LocalEntryNotFoundError,
        ) as e:
            # Pass the exception through a `%s` placeholder: an extra argument
            # with no placeholder makes logging emit a formatting error instead.
            logger.debug("File or repository not found in hf_hub_download: %s", e)
            return None
        except HfHubHTTPError as e:
            logger.warning(
                "Cannot connect to Hugging Face Hub. Skipping file download for '%s':",
                file_name,
                exc_info=e,
            )
            return None
        file_path = Path(hf_hub_file)

    if file_path.is_file():
        # JSON is UTF-8; don't depend on the platform's default encoding.
        with open(file_path, encoding="utf-8") as file:
            return json.load(file)

    return None