weight_utils.py 10.9 KB
Newer Older
1
"""Utilities for downloading and initializing model weights."""
2
import filelock
3
import glob
4
import fnmatch
5
import json
6
import os
JFDuan's avatar
JFDuan committed
7
from collections import defaultdict
8
from typing import Any, Iterator, List, Optional, Tuple
9

10
from huggingface_hub import snapshot_download, HfFileSystem
11
import numpy as np
12
from safetensors.torch import load_file, save_file, safe_open
13
import torch
14
from tqdm.auto import tqdm
15

16
from vllm.config import ModelConfig
JFDuan's avatar
JFDuan committed
17
from vllm.logger import init_logger
18
19
from vllm.model_executor.layers.quantization import (get_quantization_config,
                                                     QuantizationConfig)
JFDuan's avatar
JFDuan committed
20
21
22

logger = init_logger(__name__)

23
24

class Disabledtqdm(tqdm):
    """A ``tqdm`` progress bar that is always disabled.

    Passed as ``tqdm_class`` to ``snapshot_download`` to silence the
    per-file download progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Overwrite any caller-supplied `disable` instead of passing
        # `disable=True` alongside **kwargs: the latter raises a duplicate
        # keyword TypeError when the caller also provides `disable`.
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


JFDuan's avatar
JFDuan committed
30
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return a file lock that guards operations on the given model.

    The lock file lives in ``cache_dir`` (or ``/tmp`` when no cache dir is
    given) and its name is derived from the model identifier, so concurrent
    processes working on the same model serialize on the same lock.
    """
    lock_folder = "/tmp" if cache_dir is None else cache_dir
    lock_name = model_name_or_path.replace("/", "-") + ".lock"
    return filelock.FileLock(os.path.join(lock_folder, lock_name))


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a ``torch.save`` checkpoint (``.bin``) to a safetensors file.

    Deduplicates tensors that share storage (safetensors forbids aliased
    tensors), writes the result to ``sf_filename``, then sanity-checks the
    output both by file size and by reloading and comparing every tensor.

    Raises:
        RuntimeError: If the converted file is more than 1% larger than the
            source, or if any reloaded tensor differs from the original.
    """
    # NOTE: torch.load unpickles arbitrary code; only call this on trusted
    # checkpoint files.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    # Keep only one name per shared storage; drop the aliases.
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # Safetensors requires contiguous tensors.
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # Sanity check: the converted file should not be noticeably larger.
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size difference is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # Sanity check: every tensor must round-trip exactly.
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


85
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
    """Load the quantization config for a model.

    Prefers the ``quantization_config`` section embedded in the HF model
    config; otherwise locates the quantization method's own JSON config
    file next to the model files (downloading the JSON files first for
    remote models).

    Raises:
        ValueError: If no config file, or more than one, is found for the
            configured quantization method.
    """
    quant_cls = get_quantization_config(model_config.quantization)
    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                              None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)

    model_name_or_path = model_config.model
    if os.path.isdir(model_name_or_path):
        hf_folder = model_name_or_path
    else:
        # Remote model: fetch just the JSON config files. The lock keeps
        # concurrent processes from downloading the same files at once.
        with get_lock(model_name_or_path, model_config.download_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          revision=model_config.revision,
                                          allow_patterns="*.json",
                                          cache_dir=model_config.download_dir,
                                          tqdm_class=Disabledtqdm)

    candidates = glob.glob(os.path.join(hf_folder, "*.json"))
    quant_config_files = [
        path for path in candidates if any(
            path.endswith(suffix)
            for suffix in quant_cls.get_config_filenames())
    ]
    if not quant_config_files:
        raise ValueError(
            f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}")

    with open(quant_config_files[0], "r") as f:
        config = json.load(f)
    return quant_cls.from_config(config)


JFDuan's avatar
JFDuan committed
125
126
127
def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    """Locate (and download from the HF hub if needed) a model's weights.

    Args:
        model_name_or_path: Local directory or HF hub model identifier.
        cache_dir: Optional download cache directory.
        load_format: One of ``"auto"``, ``"safetensors"``, ``"pt"``,
            ``"npcache"``.
        fall_back_to_pt: Also accept ``*.pt`` files; some quantized models
            store their weights as ``.pt``.
        revision: Optional HF hub revision.

    Returns:
        ``(weights folder, weight file paths, whether files are safetensors)``

    Raises:
        ValueError: For an unknown ``load_format``.
        RuntimeError: If no weight files can be found.
    """
    is_local = os.path.isdir(model_name_or_path)
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npcache":
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    if fall_back_to_pt:
        allow_patterns += ["*.pt"]

    if not is_local:
        # Before we download we look at what is available on the hub so we
        # only fetch one weight format.
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
        for pattern in allow_patterns:
            if fnmatch.filter(file_list, pattern):
                allow_patterns = [pattern]
                break

        logger.info(f"Using model weights format {allow_patterns}")
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path

    use_safetensors = False
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
        if len(hf_weights_files) > 0:
            # BUGFIX: derive the flag from the pattern that actually
            # matched. Previously `load_format="safetensors"` pre-set the
            # flag to True, so with fall_back_to_pt=True the function could
            # report safetensors even though only *.pt files were found.
            use_safetensors = pattern == "*.safetensors"
            break

    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
201

JFDuan's avatar
JFDuan committed
202
203
204
205

def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
    fall_back_to_pt: Optional[bool] = True,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` pairs for every weight of a model.

    Resolves the weight files via ``prepare_hf_model_weights`` and then
    streams tensors from safetensors files, raw ``torch.load`` checkpoints,
    or an on-disk numpy cache, depending on ``load_format``.
    """
    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        load_format=load_format,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if load_format == "npcache":
        # Currently np_cache only support *.bin checkpoints
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading on subsequent runs.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                names_written = []
                for bin_file in hf_weights_files:
                    checkpoint = torch.load(bin_file, map_location="cpu")
                    for name, tensor in checkpoint.items():
                        with open(os.path.join(np_folder, name), "wb") as f:
                            np.save(f, tensor.cpu().detach().numpy())
                        names_written.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(names_written, f)

        with open(weight_names_file, "r") as f:
            cached_names = json.load(f)

        for name in cached_names:
            with open(os.path.join(np_folder, name), "rb") as f:
                array = np.load(f)
            yield name, torch.from_numpy(array)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():  # noqa: SIM118
                    yield name, f.get_tensor(name)
    else:
        for bin_file in hf_weights_files:
            checkpoint = torch.load(bin_file, map_location="cpu")
            yield from checkpoint.items()
            # Free the CPU copy before loading the next shard.
            del checkpoint
            torch.cuda.empty_cache()
262
263


264
265
266
267
268
269
270
271
272
273
274
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Materialize a safetensors ``PySafeSlice`` into a ``torch.Tensor``.

    A ``PySafeSlice`` supports indexing, which is done before loading the
    actual tensor and can reduce the amount of memory being read. However,
    it does not support more advanced operations such as ``.view()`` or
    ``.t()``; slicing the full range forces a complete load so those
    operations become available. Real tensors pass through unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x
    return x[:]


279
280
281
282
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader: copy the checkpoint tensor into the param.

    The shapes must agree exactly; ``copy_`` would otherwise broadcast
    silently and hide a mismatched checkpoint.
    """
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)
284
285
286
287
288
289
290


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in
    the forward pass. We empirically found that initializing the weights
    with values between -1e-3 and 1e-3 works well for most models.
    """
    for weight in model.state_dict().values():
        # Non-floating-point entries are left untouched.
        if not torch.is_floating_point(weight):
            continue
        weight.data.uniform_(low, high)