"vllm/vscode:/vscode.git/clone" did not exist on "0a2f4c0793988d3cf0d47b5f771fb38231db4b2b"
weight_utils.py 11.1 KB
Newer Older
1
"""Utilities for downloading and initializing model weights."""
2
import filelock
3
import glob
4
import fnmatch
5
import json
6
import os
JFDuan's avatar
JFDuan committed
7
from collections import defaultdict
8
from typing import Any, Iterator, List, Optional, Tuple
9

10
from huggingface_hub import snapshot_download, HfFileSystem
11
import numpy as np
12
from safetensors.torch import load_file, save_file, safe_open
13
import torch
14
from tqdm.auto import tqdm
15

16
from vllm.config import ModelConfig
JFDuan's avatar
JFDuan committed
17
from vllm.logger import init_logger
18
19
from vllm.model_executor.layers.quantization import (get_quantization_config,
                                                     QuantizationConfig)
JFDuan's avatar
JFDuan committed
20
21
22

logger = init_logger(__name__)

# Per-user cache root. Per the XDG Base Directory spec, an unset *or empty*
# XDG_CACHE_HOME must fall back to ~/.cache; ``os.getenv(name, default)``
# would return "" for an empty variable and make the lock directory relative
# to the current working directory.
_xdg_cache_home = (os.getenv('XDG_CACHE_HOME')
                   or os.path.expanduser('~/.cache'))
# Default directory for the lock files created by ``get_lock`` when no
# cache_dir is given.
_vllm_filelocks_path = os.path.join(_xdg_cache_home, 'vllm/locks/')

26
27

class Disabledtqdm(tqdm):
    """A ``tqdm`` progress bar that is always disabled.

    Passed as ``tqdm_class`` to ``snapshot_download`` so that downloads do not
    print progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Force ``disable=True`` even if the caller already supplied a
        # ``disable`` keyword. Passing it both through ``**kwargs`` and
        # explicitly would raise TypeError (multiple values for a keyword).
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


JFDuan's avatar
JFDuan committed
33
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return a file lock that serializes work on one model across processes.

    Args:
        model_name_or_path: Hugging Face repo id or local path of the model.
            Every ``/`` is replaced with ``-`` to form the lock file name.
        cache_dir: Directory in which to place the lock file. Defaults to the
            vLLM lock directory under the user's cache directory.

    Returns:
        A ``filelock.SoftFileLock`` (not yet acquired); callers use it as a
        context manager.
    """
    lock_dir = cache_dir if cache_dir is not None else _vllm_filelocks_path
    # Create the lock directory itself. Taking ``os.path.dirname`` of it only
    # worked for paths ending in a slash (like the default); for a
    # user-supplied cache_dir such as "/data/models" it created "/data" and
    # left the lock file's parent directory missing.
    os.makedirs(lock_dir, exist_ok=True)
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.SoftFileLock(os.path.join(lock_dir, lock_file_name))
    return lock


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a PyTorch checkpoint file to a ``.safetensors`` file.

    Tensors that share a data pointer are de-duplicated (only the first key of
    each group is kept, since safetensors does not store tensors that share
    memory), every tensor is made contiguous, and the result is written to
    ``sf_filename``. The written file is then checked against the source.

    Args:
        pt_filename: Path of the source checkpoint, read with ``torch.load``.
            A top-level ``"state_dict"`` entry is unwrapped if present.
        sf_filename: Destination path. Missing parent directories are created.

    Raises:
        RuntimeError: If the safetensors file is more than 1% larger than the
            source file, or if any reloaded tensor differs from the original.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only call this on
    # checkpoints from a trusted source.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    for shared_weights in _shared_pointers(loaded):
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    # dirname is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # check file size
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


89
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
    """Build the quantization config for ``model_config``.

    The config comes from the ``quantization_config`` attribute of the HF
    model config when present. Otherwise the model's top-level ``*.json``
    files are searched for a name ending in one of the method's
    ``get_config_filenames()``; for a non-local model the JSON files are
    downloaded first while holding the model's file lock.

    Raises:
        ValueError: If no matching config file, or more than one, is found.
    """
    quant_cls = get_quantization_config(model_config.quantization)
    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                              None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)

    model_name_or_path = model_config.model
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, model_config.download_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          revision=model_config.revision,
                                          allow_patterns="*.json",
                                          cache_dir=model_config.download_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    # Query the candidate names once instead of once per JSON file.
    config_filenames = quant_cls.get_config_filenames()
    quant_config_files = [
        f for f in config_files
        if any(f.endswith(x) for x in config_filenames)
    ]
    if not quant_config_files:
        raise ValueError(
            f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}")

    quant_config_file = quant_config_files[0]
    # JSON files are UTF-8; don't depend on the platform's locale encoding.
    with open(quant_config_file, "r", encoding="utf-8") as f:
        config = json.load(f)
    return quant_cls.from_config(config)


JFDuan's avatar
JFDuan committed
129
130
131
def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    """Locate (downloading if needed) the weight files of a model.

    Args:
        model_name_or_path: Local directory, or a Hugging Face repo id.
        cache_dir: Download cache directory passed to ``snapshot_download``
            and used for the download lock.
        load_format: One of "auto" (safetensors, then bin), "safetensors",
            "pt", or "npcache" (bin files).
        fall_back_to_pt: Also accept ``*.pt`` files when nothing matches the
            format-specific patterns.
        revision: Repo revision (branch, tag or commit) for remote models.

    Returns:
        ``(hf_folder, hf_weights_files, use_safetensors)``: the folder holding
        the weights, the files of the first pattern that matched anything, and
        whether those files are safetensors files.

    Raises:
        ValueError: If ``load_format`` is unknown.
        RuntimeError: If no weight files are found.
    """
    is_local = os.path.isdir(model_name_or_path)
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    # Some quantized models use .pt files for storing the weights.
    if fall_back_to_pt and "*.pt" not in allow_patterns:
        allow_patterns += ["*.pt"]

    if not is_local:
        # Before we download we look at what is available:
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

        # Download only the first format that is present in the repo.
        for pattern in allow_patterns:
            matching = fnmatch.filter(file_list, pattern)
            if len(matching) > 0:
                allow_patterns = [pattern]
                break

        logger.info(f"Using model weights format {allow_patterns}")
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path

    # Take the files of the first pattern that matches anything. Report
    # safetensors only when the files actually found are safetensors: with
    # load_format="safetensors" and the *.pt fallback, a folder holding only
    # *.pt files must not be flagged as safetensors.
    hf_weights_files: List[str] = []
    use_safetensors = False
    for pattern in allow_patterns:
        hf_weights_files = glob.glob(os.path.join(hf_folder, pattern))
        if hf_weights_files:
            use_safetensors = pattern == "*.safetensors"
            break

    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
205

JFDuan's avatar
JFDuan committed
206
207
208
209

def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
    fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` pairs for every weight of a model.

    The weight files are located with ``prepare_hf_model_weights`` and then
    read in one of three ways:

    * ``load_format == "npcache"``: the checkpoints are dumped once into an
      ``np`` sub-folder (one numpy file per weight plus ``weight_names.json``)
      while holding the model's file lock, and the weights are then read back
      from that folder.
    * safetensors files: each tensor is read with ``safe_open``.
    * otherwise: each checkpoint is ``torch.load``-ed onto the CPU.
    """
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        load_format=load_format,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if load_format == "npcache":
        # The numpy cache is only built from *.bin checkpoints.
        assert use_safetensors is False

        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")

        # Hold the model's lock so concurrent processes do not dump the same
        # weights at the same time. The names file is written after all the
        # per-weight files, and its existence skips the dump.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                dumped_names = []
                for bin_file in hf_weights_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        with open(os.path.join(np_folder, name), "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        dumped_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(dumped_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            with open(os.path.join(np_folder, name), "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
        return

    if use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():  # noqa: SIM118
                    yield name, f.get_tensor(name)
        return

    for bin_file in hf_weights_files:
        state = torch.load(bin_file, map_location="cpu")
        for name, param in state.items():
            yield name, param
        del state
        torch.cuda.empty_cache()
266
267


268
269
270
271
272
273
274
275
276
277
278
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Materialize a safetensors ``PySafeSlice`` as a ``torch.Tensor``.

    A ``PySafeSlice`` can be indexed before any data is read, which limits how
    much memory is touched, but it lacks tensor methods such as ``.view()``
    or ``.t()``. Call this before applying such operations. Tensors are
    returned unchanged; any other object is fully sliced with ``[:]``.
    """
    if isinstance(x, torch.Tensor):
        return x
    return x[:]


283
284
285
286
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param`` in place.

    Raises:
        AssertionError: If the two tensors do not have the same size. The
            message reports both sizes so that checkpoint/model mismatches
            are easy to diagnose.
    """
    # NOTE(review): this check is an assert, so it is stripped under
    # ``python -O``; copy_ may then broadcast instead of failing.
    assert param.size() == loaded_weight.size(), (
        f"Weight size mismatch: parameter has size {param.size()}, but the "
        f"loaded weight has size {loaded_weight.size()}")
    param.data.copy_(loaded_weight)
288
289
290
291
292
293
294


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Fill a model's floating-point weights with uniform random values.

    Random (rather than constant) weights are needed for accurate performance
    measurements, and the values must not cause NaNs in the forward pass;
    the default range [-1e-3, 1e-3] was found empirically to work for most
    models. Entries of ``model.state_dict()`` that are not floating point are
    left untouched.

    Args:
        model: Module whose state-dict tensors are overwritten in place.
        low: Lower bound of the uniform distribution.
        high: Upper bound of the uniform distribution.
    """
    float_tensors = (tensor for tensor in model.state_dict().values()
                     if torch.is_floating_point(tensor))
    for tensor in float_tensors:
        tensor.data.uniform_(low, high)