utils.py 4.96 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
import os
6
import struct
7
from functools import cache
8
9
from os import PathLike
from pathlib import Path
10
from typing import Any
11

12
13
from gguf import GGMLQuantizationType

14
import vllm.envs as envs
15
16
17
18
from vllm.logger import init_logger

logger = init_logger(__name__)

19

20
def is_s3(model_or_path: str) -> bool:
    """Return True if ``model_or_path`` uses the ``s3://`` scheme (case-insensitive)."""
    lowered = model_or_path.lower()
    return lowered[: len("s3://")] == "s3://"
22
23


24
25
26
27
28
29
30
31
def is_gcs(model_or_path: str) -> bool:
    """Return True if ``model_or_path`` uses the ``gs://`` scheme (case-insensitive)."""
    lowered = model_or_path.lower()
    return lowered[: len("gs://")] == "gs://"


def is_cloud_storage(model_or_path: str) -> bool:
    """Return True if ``model_or_path`` points at S3 or GCS object storage."""
    # Checked in order S3 then GCS, stopping at the first match.
    return any(check(model_or_path) for check in (is_s3, is_gcs))


32
@cache
33
def check_gguf_file(model: str | PathLike) -> bool:
34
35
36
37
38
39
40
    """Check if the file is a GGUF model."""
    model = Path(model)
    if not model.is_file():
        return False
    elif model.suffix == ".gguf":
        return True

Reid's avatar
Reid committed
41
42
43
44
45
46
47
48
    try:
        with model.open("rb") as f:
            header = f.read(4)

        return header == b"GGUF"
    except Exception as e:
        logger.debug("Error reading file %s: %s", model, e)
        return False
49
50


51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
@cache
def is_remote_gguf(model: str | Path) -> bool:
    """Check if the model is a remote GGUF model.

    A remote GGUF spec looks like ``repo_id:quant_type``. It must not be a
    cloud-storage or http(s) URL, must contain both ``/`` and ``:``, and the
    text after the last ``:`` must be a valid GGUF quant type.
    """
    spec = str(model)
    if is_cloud_storage(spec):
        return False
    if spec.startswith(("http://", "https://")):
        return False
    if "/" not in spec or ":" not in spec:
        return False
    _, quant_type = spec.rsplit(":", 1)
    return is_valid_gguf_quant_type(quant_type)


def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
    """Check if the quant type is a valid GGUF quant type.

    Only names of ``GGMLQuantizationType`` members (aliases included) are
    accepted. Looking the name up in ``__members__`` rather than with
    ``getattr`` keeps arbitrary class attributes such as ``__module__`` or
    ``_member_map_`` from being mistaken for quant types.
    """
    return gguf_quant_type in GGMLQuantizationType.__members__


def split_remote_gguf(model: str | Path) -> tuple[str, str]:
    """Split the model into repo_id and quant type.

    Args:
        model: A ``repo_id:quant_type`` spec accepted by ``is_remote_gguf``.

    Returns:
        ``(repo_id, quant_type)``, split on the last ``:``.

    Raises:
        ValueError: if ``model`` is not a valid remote GGUF spec.
    """
    model = str(model)
    if is_remote_gguf(model):
        repo_id, quant_type = model.rsplit(":", 1)
        return (repo_id, quant_type)
    # Format eagerly: unlike logger calls, exception constructors do not
    # apply %-style formatting to extra positional args.
    valid_types = GGMLQuantizationType._member_names_
    raise ValueError(
        f"Wrong GGUF model or invalid GGUF quant type: {model}.\n"
        "- It should be in repo_id:quant_type format.\n"
        f"- Valid GGMLQuantizationType values: {valid_types}"
    )


def is_gguf(model: str | Path) -> bool:
    """Check if the model is a GGUF model.

    A model qualifies if it is either a local GGUF file or a remote
    ``repo_id:quant_type`` GGUF spec.

    Args:
        model: Model name, path, or Path object to check.

    Returns:
        True if the model is a GGUF model, False otherwise.
    """
    name = str(model)
    # The local-file check runs first; the remote check is only reached
    # when the local one is False.
    return check_gguf_file(name) or is_remote_gguf(name)


102
103
def modelscope_list_repo_files(
    repo_id: str,
    revision: str | None = None,
    token: str | bool | None = None,
) -> list[str]:
    """List files in a modelscope repo.

    Mirrors ``huggingface_hub.list_repo_files``: returns the ``Path`` of
    every entry whose ``Type`` is ``"blob"`` (presumably regular files
    rather than directory entries — confirm against the ModelScope API).
    """
    from modelscope.hub.api import HubApi

    hub = HubApi()
    hub.login(token)

    entries = hub.get_model_files(model_id=repo_id, revision=revision, recursive=True)
    return [entry["Path"] for entry in entries if entry["Type"] == "blob"]
121
122


123
def _maybe_json_dict(path: str | PathLike) -> dict[str, str]:
124
125
126
127
128
129
130
    with open(path) as f:
        try:
            return json.loads(f.read())
        except Exception:
            return dict[str, str]()


131
def _maybe_space_split_dict(path: str | PathLike) -> dict[str, str]:
132
133
134
135
136
137
138
139
140
141
142
    parsed_dict = dict[str, str]()
    with open(path) as f:
        for line in f.readlines():
            try:
                model_name, redirect_name = line.strip().split()
                parsed_dict[model_name] = redirect_name
            except Exception:
                pass
    return parsed_dict


143
144
145
146
147
148
149
150
151
@cache
def maybe_model_redirect(model: str) -> str:
    """
    Use model_redirect to redirect the model name to a local folder.

    The file named by ``envs.VLLM_MODEL_REDIRECT_PATH`` is parsed first as
    JSON and, if that yields nothing, as whitespace-separated
    ``<model_name> <redirect_name>`` lines. The result is memoized per
    ``model`` via ``@cache``, so later edits to the redirect file are not
    observed for models already looked up.

    :param model: hf model name
    :return: the configured redirect target for ``model``, or ``model``
        itself when no redirect file is configured or no entry matches
    """
    model_redirect_path = envs.VLLM_MODEL_REDIRECT_PATH

    if not model_redirect_path:
        return model

    # Only a regular file can hold redirects. A missing path is ignored, and a
    # directory is treated the same way instead of letting open() raise
    # IsADirectoryError below.
    if not Path(model_redirect_path).is_file():
        return model

    redirect_dict = _maybe_json_dict(model_redirect_path) or _maybe_space_split_dict(
        model_redirect_path
    )
    if redirect_model := redirect_dict.get(model):
        logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model)
        return redirect_model

    return model
168
169


170
def parse_safetensors_file_metadata(path: str | PathLike) -> dict[str, Any]:
171
    with open(path, "rb") as f:
172
173
        length_of_metadata = struct.unpack("<Q", f.read(8))[0]
        metadata = json.loads(f.read(length_of_metadata).decode("utf-8"))
174
        return metadata
175
176
177
178
179
180
181
182
183
184


def convert_model_repo_to_path(model_repo: str) -> str:
    """Map a model repository id to its ModelScope cache path.

    Only applies when ``envs.VLLM_USE_MODELSCOPE`` is enabled and
    ``model_repo`` is not already an existing local path; otherwise
    ``model_repo`` is returned unchanged.
    """
    if envs.VLLM_USE_MODELSCOPE and not Path(model_repo).exists():
        # Deferred import: modelscope is only needed on this branch.
        from modelscope.utils.file_utils import get_model_cache_root

        return os.path.join(get_model_cache_root(), model_repo)
    return model_repo