weight_utils.py 11.6 KB
Newer Older
1
"""Utilities for downloading and initializing model weights."""
2
import fnmatch
3
import glob
4
import hashlib
5
import json
6
import os
JFDuan's avatar
JFDuan committed
7
from collections import defaultdict
8
from typing import Any, Iterator, List, Optional, Tuple
9

10
import filelock
11
import numpy as np
12
import torch
13
14
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
15
from tqdm.auto import tqdm
16

17
from vllm.config import ModelConfig
JFDuan's avatar
JFDuan committed
18
from vllm.logger import init_logger
19
20
from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                     get_quantization_config)
JFDuan's avatar
JFDuan committed
21
22
23

logger = init_logger(__name__)

# File locks go in a system-level temp directory so that several users can
# share the same lock without permission errors. Such directories are
# typically cleared on reboot, so stale lock files do not pile up.
# The first non-empty value among $TMPDIR, $TEMP and $TMP wins; otherwise
# "/tmp/" is used.
temp_dir = next(
    (value for value in map(os.environ.get, ("TMPDIR", "TEMP", "TMP"))
     if value),
    "/tmp/",
)
30

31
32

class Disabledtqdm(tqdm):
    """A ``tqdm`` progress bar that is always disabled.

    Passed as ``tqdm_class`` to ``snapshot_download`` so that weight/config
    downloads do not print progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Overwrite ``disable`` inside kwargs instead of passing it as an
        # extra keyword: if the caller already supplies ``disable``, the
        # latter raises "got multiple values for keyword argument 'disable'".
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


JFDuan's avatar
JFDuan committed
38
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return a cross-process file lock dedicated to one model.

    The lock file lives in ``cache_dir`` when given, otherwise in the
    module-level ``temp_dir``. It is used to serialize downloads and on-disk
    conversions of the model's weights across processes.

    Args:
        model_name_or_path: HF Hub repo id or local path of the model.
        cache_dir: Directory that holds the lock file; defaults to
            ``temp_dir``.

    Returns:
        A ``filelock.FileLock`` (not yet acquired); use it as a context
        manager.
    """
    lock_dir = cache_dir or temp_dir
    # Create the lock directory itself (plus missing parents). Creating only
    # ``os.path.dirname(lock_dir)`` leaves e.g. a fresh ``/data/hf`` missing
    # when ``lock_dir`` has no trailing slash, so the lock file could not be
    # created inside it.
    os.makedirs(lock_dir, exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    lock_path = os.path.join(lock_dir, lock_file_name)
    # mode 0o666 is required for the filelock to be shared across users
    return filelock.FileLock(lock_path, mode=0o666)


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a ``torch.save`` checkpoint into a safetensors file.

    If the checkpoint contains a ``"state_dict"`` entry, only that is
    converted. Tensors that share memory are written once (under the first
    name), every tensor is made contiguous, and the result is verified.

    Args:
        pt_filename: Path of the PyTorch checkpoint to read.
        sf_filename: Destination path; missing parent directories are
            created.

    Raises:
        RuntimeError: If the safetensors file is more than 1% larger than the
            source, or if any tensor does not round-trip exactly.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects; only call
    # this on checkpoints from a trusted source.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    # Keep a single name per group of aliased tensors so each buffer is
    # written once (safetensors' save_file rejects tensors sharing memory).
    for shared_weights in _shared_pointers(loaded):
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # check file size
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k, pt_tensor in loaded.items():
        if not torch.equal(pt_tensor, reloaded[k]):
            raise RuntimeError(f"The output tensors do not match for key {k}")


99
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
    """Build the quantization config for ``model_config``.

    A ``quantization_config`` embedded in the HF model config is used when
    present. Otherwise the model directory must contain exactly one JSON file
    whose path ends with one of ``quant_cls.get_config_filenames()``. For a
    remote model, only ``*.json`` files are downloaded to find it.

    Raises:
        ValueError: If no matching config file, or more than one, is found.
    """
    quant_cls = get_quantization_config(model_config.quantization)
    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                              None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)

    model_name_or_path = model_config.model
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_name_or_path, model_config.download_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          revision=model_config.revision,
                                          allow_patterns="*.json",
                                          cache_dir=model_config.download_dir,
                                          tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    config_files = glob.glob(os.path.join(hf_folder, "*.json"))

    # Fetch the accepted file names once; str.endswith accepts a tuple.
    config_suffixes = tuple(quant_cls.get_config_filenames())
    quant_config_files = [
        f for f in config_files if f.endswith(config_suffixes)
    ]
    if not quant_config_files:
        raise ValueError(
            f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}")

    quant_config_file = quant_config_files[0]
    # JSON is UTF-8 by definition; do not depend on the locale's encoding.
    with open(quant_config_file, "r", encoding="utf-8") as f:
        config = json.load(f)
    return quant_cls.from_config(config)


JFDuan's avatar
JFDuan committed
139
140
141
def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    """Locate the weight files of a model, downloading them if needed.

    Args:
        model_name_or_path: Local directory, or HF Hub repo id to download.
        cache_dir: Download cache directory (also used for the lock file).
        load_format: ``"auto"`` (prefer ``*.safetensors``, then ``*.bin``),
            ``"safetensors"``, ``"pt"`` or ``"npcache"`` (``*.bin``).
        fall_back_to_pt: Also accept ``*.pt`` files when the preferred
            formats are absent.
        revision: HF Hub revision to download.

    Returns:
        ``(hf_folder, hf_weights_files, use_safetensors)``, where
        ``use_safetensors`` is True exactly when the returned files matched
        ``*.safetensors``.

    Raises:
        ValueError: If ``load_format`` is unknown.
        RuntimeError: If no weight files are found.
    """
    # Candidate weight-file patterns, most preferred first.
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    # Some quantized models use .pt files for storing the weights.
    if fall_back_to_pt and "*.pt" not in allow_patterns:
        allow_patterns.append("*.pt")

    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Check which formats the repo actually provides and download only
        # the most preferred one that is available.
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
        for pattern in allow_patterns:
            if fnmatch.filter(file_list, pattern):
                allow_patterns = [pattern]
                break

        logger.info("Using model weights format %s", allow_patterns)
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path

    # Take the files of the first pattern that matches anything (sorted for
    # a deterministic load order). ``use_safetensors`` must follow the
    # pattern that actually matched: with load_format="safetensors" and the
    # *.pt fallback, the files found may be pickles, not safetensors.
    hf_weights_files: List[str] = []
    use_safetensors = False
    for pattern in allow_patterns:
        hf_weights_files = sorted(glob.glob(os.path.join(hf_folder, pattern)))
        if hf_weights_files:
            use_safetensors = pattern == "*.safetensors"
            break

    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
215

JFDuan's avatar
JFDuan committed
216
217
218
219

def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
    fall_back_to_pt: bool = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` pairs for every weight of a model.

    Weight files are located by ``prepare_hf_model_weights``. Tensors are then
    read from safetensors files, unpickled from ``.bin``/``.pt`` files, or,
    for ``load_format="npcache"``, served from a per-tensor NumPy cache in
    ``<hf_folder>/np`` that is built on first use.

    Args:
        model_name_or_path: Local directory or HF Hub repo id.
        cache_dir: Download cache directory (also used for the lock file).
        load_format: See ``prepare_hf_model_weights``.
        revision: HF Hub revision to download.
        fall_back_to_pt: Also accept ``*.pt`` files as a last resort.
    """
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        load_format=load_format,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if load_format == "npcache":
        # Currently np_cache only support *.bin checkpoints
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            # weight_names.json is written only after every tensor has been
            # dumped, so its presence marks a complete cache.
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    # NOTE(review): torch.load unpickles arbitrary objects;
                    # only use with trusted checkpoints.
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            # NOTE(review): Tensor.numpy() has no bfloat16
                            # equivalent, so bf16 checkpoints presumably fail
                            # here — confirm.
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                    # Drop this shard before loading the next so that only
                    # one shard's state dict is alive at a time.
                    del state
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():  # noqa: SIM118
                    param = f.get_tensor(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            # NOTE(review): torch.load unpickles arbitrary objects; only use
            # with trusted checkpoints.
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
276
277


278
279
280
281
282
283
284
285
286
287
288
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Materialize a safetensors ``PySafeSlice`` as a ``torch.Tensor``.

    A ``PySafeSlice`` can be indexed before the data is loaded, which limits
    how much is read into memory. However, it lacks tensor operations such as
    ``.view()`` or ``.t()``. Call this before applying such operations.
    Anything that is already a ``torch.Tensor`` is returned as-is (same
    object); otherwise the full slice ``x[:]`` is taken.
    """
    return x if isinstance(x, torch.Tensor) else x[:]


293
294
295
296
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param`` in place.

    Raises:
        ValueError: If the two tensors differ in shape. The check is explicit
            rather than an ``assert`` so that it still runs under
            ``python -O``; without it, ``copy_`` could silently broadcast a
            mis-shaped checkpoint tensor into the parameter.
    """
    if param.size() != loaded_weight.size():
        raise ValueError(
            f"Shape mismatch when loading weight: parameter has shape "
            f"{tuple(param.size())} but the loaded weight has shape "
            f"{tuple(loaded_weight.size())}")
    param.data.copy_(loaded_weight)
298
299
300
301
302
303
304


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Fill every floating-point tensor of ``model`` with uniform noise.

    Random (rather than empty) weights are needed for accurate performance
    measurements, and they must be small enough not to produce NaNs in the
    forward pass. Empirically, values in ``[-1e-3, 1e-3)`` work well for most
    models. Every entry of ``model.state_dict()`` (parameters and buffers)
    with a floating-point dtype is overwritten in place; other dtypes (e.g.
    integer buffers) are left untouched.
    """
    float_tensors = (tensor for tensor in model.state_dict().values()
                     if torch.is_floating_point(tensor))
    for tensor in float_tensors:
        tensor.data.uniform_(low, high)