weight_utils.py 15.2 KB
Newer Older
1
"""Utilities for downloading and initializing model weights."""
2
import fnmatch
3
import glob
4
import hashlib
5
import json
6
import os
JFDuan's avatar
JFDuan committed
7
from collections import defaultdict
8
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
9

10
import filelock
11
import huggingface_hub.constants
12
import numpy as np
13
import torch
14
15
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
16
from tqdm.auto import tqdm
17

18
from vllm.config import ModelConfig
JFDuan's avatar
JFDuan committed
19
from vllm.logger import init_logger
20
21
from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                     get_quantization_config)
22
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
JFDuan's avatar
JFDuan committed
23
24
25

logger = init_logger(__name__)

26
27
28
29
30
31
# Lock files live in the system-level temp directory so that multiple users
# can share the same lock without error. Files there are removed
# automatically on reboot, so users are not left with stale lock files.
# The first non-empty value of TMPDIR, TEMP, TMP wins; "/tmp/" otherwise.
temp_dir = next(
    filter(None, map(os.environ.get, ("TMPDIR", "TEMP", "TMP"))),
    "/tmp/",
)
32

33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def enable_hf_transfer():
    """Switch on hf_transfer for Hugging Face Hub downloads when possible.

    Nothing happens if ``HF_HUB_ENABLE_HF_TRANSFER`` is already present in
    the environment, or if the ``hf_transfer`` package cannot be imported.
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        # An explicit environment setting takes precedence.
        return
    try:
        import hf_transfer  # type: ignore # noqa
    except ImportError:
        # hf_transfer is optional; leave the hub constant untouched.
        return
    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True


enable_hf_transfer()


49
class Disabledtqdm(tqdm):
    """A tqdm progress bar that is always disabled.

    Passed as ``tqdm_class`` to ``snapshot_download`` in this file so that
    downloads do not render progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Override rather than pass ``disable=True`` next to ``**kwargs``:
        # if the caller already supplies ``disable``, the latter form fails
        # with "got multiple values for keyword argument 'disable'".
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


JFDuan's avatar
JFDuan committed
55
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return a file lock dedicated to ``model_name_or_path``.

    Args:
        model_name_or_path: Model identifier; "/" is replaced by "-" when
            deriving the lock file name.
        cache_dir: Directory that holds the lock file. Falls back to the
            module-level ``temp_dir`` when empty or None.

    Returns:
        A ``filelock.FileLock`` (not yet acquired) located in the lock
        directory.
    """
    lock_dir = cache_dir or temp_dir
    # Create the lock directory itself. Using os.path.dirname(lock_dir)
    # here would only create its parent, and yields "" for a bare relative
    # name, which makes os.makedirs raise.
    os.makedirs(lock_dir, exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
                             mode=0o666)
    return lock


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a PyTorch checkpoint file into a safetensors file.

    The checkpoint is loaded on CPU. Among tensors that report the same
    ``data_ptr()`` only the first name of each group is kept. The result
    is written to ``sf_filename`` and then verified against the source,
    first by file size and then tensor by tensor.

    Args:
        pt_filename: Path of the checkpoint read with ``torch.load``.
        sf_filename: Destination path; its parent directory is created if
            it does not exist.

    Raises:
        RuntimeError: The safetensors file is more than 1% larger than the
            checkpoint, or a reloaded tensor differs from the original.
    """
    # NOTE(review): torch.load unpickles the file, so this must only be
    # used on checkpoints from a trusted source.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        # Keep the first name of every aliasing group, drop the rest.
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # check file size
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


116
# TODO(woosuk): Move this to another place.
117
118
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
    """Build the quantization config for ``model_config.quantization``.

    The ``quantization_config`` attribute of the HF model config is used
    when it is present. Otherwise the model's ``*.json`` files are searched
    (after downloading them for a non-local model) for exactly one file
    whose name ends with one of the quantization class's config filenames.

    Args:
        model_config: Supplies the quantization method, HF config, model
            name or path, revision and download directory.

    Returns:
        The object produced by the quantization class's ``from_config``.

    Raises:
        ValueError: No matching config file was found, or more than one.
    """
    quant_cls = get_quantization_config(model_config.quantization)
    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                              None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)
    model_name_or_path = model_config.model
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, model_config.download_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          revision=model_config.revision,
                                          allow_patterns="*.json",
                                          cache_dir=model_config.download_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    # Keep only the files named like one of this method's config files.
    quant_config_files = [
        f for f in config_files if any(
            f.endswith(x) for x in quant_cls.get_config_filenames())
    ]
    if len(quant_config_files) == 0:
        raise ValueError(
            f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}")

    quant_config_file = quant_config_files[0]
    with open(quant_config_file, "r") as f:
        config = json.load(f)
    return quant_cls.from_config(config)


JFDuan's avatar
JFDuan committed
156
157
158
def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    """Locate the weight files of a model, downloading them if necessary.

    Args:
        model_name_or_path: Local directory or Hugging Face Hub repo id.
        cache_dir: Passed to ``snapshot_download`` as the cache directory
            and to ``get_lock`` for the download lock.
        load_format: One of "auto", "safetensors", "pt", "npcache" or
            "tensorizer".
        fall_back_to_pt: Whether "*.pt" is tried after the patterns of the
            selected format.
        revision: Revision used when listing and downloading Hub files.

    Returns:
        ``(hf_folder, hf_weights_files, use_safetensors)``: the folder that
        was searched, the files matching the first pattern that produced
        any match, and whether those files are safetensors files.

    Raises:
        ValueError: ``load_format`` is not one of the known formats.
        RuntimeError: No weight file was found (never raised for
            "tensorizer", which may return an empty file list).
    """
    # Download model weights from huggingface.
    is_local = (os.path.isdir(model_name_or_path)
                and load_format != "tensorizer")
    use_safetensors = False
    # Some quantized models use .pt files for storing the weights.
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    elif load_format == "tensorizer":
        allow_patterns = ["*.tensors"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    # Skip the fallback when "*.pt" is already the requested pattern, so
    # the same pattern is not listed (and globbed) twice.
    if fall_back_to_pt and "*.pt" not in allow_patterns:
        allow_patterns += ["*.pt"]

    if not is_local and load_format != "tensorizer":
        # Before we download we look at what is available:
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

        # depending on what is available we download different things
        for pattern in allow_patterns:
            matching = fnmatch.filter(file_list, pattern)
            if len(matching) > 0:
                allow_patterns = [pattern]
                break

        logger.info(f"Using model weights format {allow_patterns}")
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
    # Patterns are tried in priority order; the first one with at least
    # one match decides which files are used.
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
        if len(hf_weights_files) > 0:
            if pattern == "*.safetensors":
                use_safetensors = True
            break
    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if load_format == "tensorizer":
        # Returned without the emptiness check below.
        return hf_folder, hf_weights_files, use_safetensors

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
238

JFDuan's avatar
JFDuan committed
239
240
241
242

def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: Union[Tuple, str] = "auto",
    revision: Optional[str] = None,
    fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Iterate over the weights of a model as ``(name, tensor)`` pairs.

    The weight files are resolved with ``prepare_hf_model_weights`` and
    then read according to ``load_format``:

    * "npcache": every tensor of the checkpoint files is dumped once with
      ``np.save`` into an ``np`` sub-folder of the model folder, together
      with a ``weight_names.json`` index, and later read back from there.
    * "tensorizer": tensors are streamed on CPU through
      ``TensorDeserializer`` using the parameters in ``load_format.params``.
    * safetensors files: tensors are read one by one with ``safe_open``.
    * otherwise: every checkpoint file is loaded with ``torch.load`` on CPU.

    Args:
        model_name_or_path: Local directory or Hugging Face Hub repo id.
        cache_dir: Download cache directory, also used for the file locks.
        load_format: Weight format selector; the "tensorizer" branch also
            reads a ``params`` attribute from it.
        revision: Hub revision to download.
        fall_back_to_pt: Forwarded to ``prepare_hf_model_weights``.

    Yields:
        Tuples of weight name and tensor.
    """
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        load_format=load_format,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if load_format == "npcache":
        # Currently np_cache only support *.bin checkpoints
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            # The index file is written last, so its presence marks a
            # finished dump.
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    # NOTE(review): torch.load unpickles the file; only
                    # load checkpoints from a trusted source.
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif load_format == "tensorizer":
        from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
                                                           open_stream,
                                                           tensorizer_warning)
        tensorizer_args = load_format.params
        tensorizer_warning(
            "Deserializing HuggingFace models is not optimized for "
            "loading on vLLM, as tensorizer is forced to load to CPU. "
            "Consider deserializing a vLLM model instead for faster "
            "load times. See the examples/tensorize_vllm_model.py example "
            "script for serializing vLLM models.")

        deserializer_args = tensorizer_args.deserializer_params
        stream_params = tensorizer_args.stream_params
        stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
        with TensorDeserializer(stream, **deserializer_args,
                                device="cpu") as state:
            for name, param in state.items():
                yield name, param
        del state
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():  # noqa: SIM118
                    param = f.get_tensor(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            # Drop the loaded checkpoint before moving to the next file.
            del state
            torch.cuda.empty_cache()
319
320


321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
def kv_cache_scales_loader(
        filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
        model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
    """
    Read KV cache scaling factors that were previously serialized to disk
    and return the (layer, scaling factor) items for the given TP rank.

    The file is parsed as JSON and validated with QuantParamSchema; the
    model type, layer count, TP rank and TP size are handed to the
    validation as context. The per-rank mapping is then taken from
    `kv_cache.scaling_factor[tp_rank]`.

    Any failure is logged and an empty list is returned instead, so that
    no KV cache scales are loaded and the defaults of 1.0 apply.
    Keep this function in sync with the output of examples/fp8/extract_scales.py
    """
    validation_context = {
        "model_type": model_type,
        "num_hidden_layers": num_hidden_layers,
        "tp_rank": tp_rank,
        "tp_size": tp_size,
    }
    try:
        with open(filename) as scales_file:
            raw_schema = json.load(scales_file)
            validated = QuantParamSchema.model_validate(
                raw_schema, context=validation_context)
            return validated.kv_cache.scaling_factor[tp_rank].items()
    except FileNotFoundError:
        logger.error(f"File or directory '{filename}' not found.")
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON in file '{filename}'.")
    except Exception as err:
        logger.error(f"An error occurred while reading '{filename}': {err}")
    # Only reachable after one of the handlers above has run: fall back to
    # an empty iterable, which ultimately means scales of 1.0.
    logger.warning("Defaulting to KV cache scaling factors = 1.0 "
                   f"for all layers in TP rank {tp_rank} "
                   "as an error occurred during loading.")
    return []


361
362
363
364
365
366
367
368
369
370
371
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """convert PySafeSlice object from safetensors to torch.Tensor

    PySafeSlice object supports indexing, which is done before loading the
    actual tensor and can reduce the amount of memory being read into the
    memory. However, it does not support more advanced functionalities
    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
    tensor with these more complicated operators, we need to convert to
    tensor first.

    A value that already is a ``torch.Tensor`` is returned unchanged;
    anything else is replaced by the result of the full slice ``x[:]``.
    """
    if not isinstance(x, torch.Tensor):
        x = x[:]
    return x


376
377
378
379
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader.

    Copies ``loaded_weight`` into ``param`` in place.

    Raises:
        AssertionError: The two tensors do not have the same size.
    """
    assert param.size() == loaded_weight.size(), (
        f"Shape mismatch: parameter {tuple(param.size())} vs "
        f"loaded weight {tuple(loaded_weight.size())}")
    param.data.copy_(loaded_weight)
381
382
383
384
385
386
387


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.

    Args:
        model: Module whose ``state_dict()`` tensors are overwritten in
            place.
        low: Lower bound passed to ``uniform_``.
        high: Upper bound passed to ``uniform_``.
    """
    for param in model.state_dict().values():
        # Non-floating-point entries are left untouched.
        if torch.is_floating_point(param):
            param.data.uniform_(low, high)