weight_utils.py 11.2 KB
Newer Older
1
"""Utilities for downloading and initializing model weights."""
2
import filelock
3
4
import glob
import json
5
import os
JFDuan's avatar
JFDuan committed
6
from collections import defaultdict
7
from typing import Any, Iterator, List, Optional, Tuple
8

9
from huggingface_hub import snapshot_download
JFDuan's avatar
JFDuan committed
10
from safetensors.torch import load_file, save_file, safe_open
11
import numpy as np
12
import torch
13
from tqdm.auto import tqdm
14

JFDuan's avatar
JFDuan committed
15
from vllm.logger import init_logger
16
17
from vllm.model_executor.quantization_utils import get_quant_class
from vllm.model_executor.quantization_utils.base import QuantizationConfig
JFDuan's avatar
JFDuan committed
18
19
20

logger = init_logger(__name__)

21
22

class Disabledtqdm(tqdm):
    """A `tqdm` progress bar that is always disabled.

    Passed as `tqdm_class` to `snapshot_download` in this module so that
    downloads do not print progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Overwrite any caller-supplied `disable` instead of passing a second
        # `disable=` keyword, which would raise
        # "TypeError: got multiple values for keyword argument 'disable'".
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


JFDuan's avatar
JFDuan committed
28
def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return a file lock keyed on the model name.

    Args:
        model_name_or_path: Model identifier; every "/" is replaced by "-"
            to build the lock file name `<name>.lock`.
        cache_dir: Directory in which the lock file is placed. When None,
            the system temporary directory is used.

    Returns:
        A `filelock.FileLock` for the lock file. The lock is not acquired
        here; callers in this module use it as a context manager.
    """
    # Local import so this block does not depend on a file-level import.
    import tempfile

    # `tempfile.gettempdir()` instead of a hard-coded "/tmp": portable, and it
    # still resolves to "/tmp" on a typical Linux setup.
    lock_dir = cache_dir if cache_dir is not None else tempfile.gettempdir()
    # The lock is taken before anything is downloaded, so a fresh cache
    # directory may not exist yet. An empty string means "current directory"
    # and must not be passed to makedirs.
    if lock_dir:
        os.makedirs(lock_dir, exist_ok=True)
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    return lock


def _shared_pointers(tensors):
    """Group the names of tensors that report the same `data_ptr()`.

    Args:
        tensors: Mapping from name to an object exposing `data_ptr()`.

    Returns:
        A list of name lists, one per data pointer shared by more than one
        entry, in order of first appearance. Pointers used by a single entry
        are omitted.
    """
    names_by_ptr = defaultdict(list)
    for name, tensor in tensors.items():
        names_by_ptr[tensor.data_ptr()].append(name)
    return [group for group in names_by_ptr.values() if len(group) > 1]


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a PyTorch checkpoint file into a safetensors file.

    The checkpoint is loaded on CPU. If it contains a "state_dict" entry,
    that entry is used as the tensor mapping. For tensors sharing the same
    data pointer only the first name is kept. The result is written to
    `sf_filename` and then verified against the in-memory tensors.

    Args:
        pt_filename: Path of the PyTorch checkpoint to read.
        sf_filename: Path of the safetensors file to write. Missing parent
            directories are created.

    Raises:
        RuntimeError: If the safetensors file is more than 1% larger than
            the source file, or if any reloaded tensor differs from the one
            that was saved.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only call this on checkpoints from a trusted source.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    # Keep only the first name of every group of tensors that share memory.
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    # A bare file name has an empty dirname; os.makedirs("") would raise
    # FileNotFoundError, so only create a directory when there is one.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # check file size (only growth of the safetensors file is flagged)
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size different is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# TODO(woosuk): Move this to other place.
def get_quant_config(
    quantization: str,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
) -> QuantizationConfig:
    """Load the quantization config of a model for a quantization method.

    Args:
        quantization: Name of the quantization method, resolved to a config
            class through `get_quant_class`.
        model_name_or_path: Local model directory, or a model identifier
            whose `*.json` files are downloaded with `snapshot_download`.
        cache_dir: Download cache directory; also where the lock file goes.

    Returns:
        The object built by the config class's `from_config` from the parsed
        JSON file.

    Raises:
        ValueError: If no JSON file, or more than one, ends with one of the
            names returned by the class's `get_config_filenames()`.
    """
    if os.path.isdir(model_name_or_path):
        config_dir = model_name_or_path
    else:
        # Only the JSON files are fetched; the download runs under the
        # per-model file lock.
        with get_lock(model_name_or_path, cache_dir):
            config_dir = snapshot_download(model_name_or_path,
                                           allow_patterns="*.json",
                                           cache_dir=cache_dir,
                                           tqdm_class=Disabledtqdm)
    json_paths = glob.glob(os.path.join(config_dir, "*.json"))

    quant_cls = get_quant_class(quantization)
    matching_paths = []
    for path in json_paths:
        for suffix in quant_cls.get_config_filenames():
            if path.endswith(suffix):
                matching_paths.append(path)
                break

    if not matching_paths:
        raise ValueError(f"Cannot find the config file for {quantization}")
    if len(matching_paths) > 1:
        raise ValueError(f"Found multiple config files for {quantization}: "
                         f"{matching_paths}")

    with open(matching_paths[0], "r") as f:
        config = json.load(f)
    return quant_cls.from_config(config)


JFDuan's avatar
JFDuan committed
118
119
120
def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_safetensors: bool = False,
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
    """Locate the weight files of a model, downloading them if needed.

    Args:
        model_name_or_path: Local model directory, or a model identifier to
            fetch with `snapshot_download`.
        cache_dir: Download cache directory; also where the lock file goes.
        use_safetensors: Look for `*.safetensors` files when True, otherwise
            for `*.bin` and `*.pt` files.
        fall_back_to_pt: When True and no safetensors file is found, retry
            once looking for `*.bin` / `*.pt` files.
        revision: Revision forwarded to `snapshot_download`.

    Returns:
        A tuple `(hf_folder, hf_weights_files, use_safetensors)`: the folder
        holding the files, the matching file paths, and whether those files
        are safetensors (False after a fallback to `*.bin` / `*.pt`).

    Raises:
        RuntimeError: If no weight file is found.
    """
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path)
    if use_safetensors:
        allow_patterns = ["*.safetensors"]
    else:
        # Some quantized models use .pt files for storing the weights.
        allow_patterns = ["*.bin", "*.pt"]
    if not is_local:
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path
    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
    if not use_safetensors:
        # Training-state files match *.bin / *.pt but hold no model weights;
        # returning them would make the loader treat them as checkpoints.
        non_weight_suffixes = (
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        )
        hf_weights_files = [
            x for x in hf_weights_files
            if not x.endswith(non_weight_suffixes)
        ]

    if not hf_weights_files and use_safetensors and fall_back_to_pt:
        # No safetensors found: retry exactly once with *.bin / *.pt.
        return prepare_hf_model_weights(model_name_or_path,
                                        cache_dir=cache_dir,
                                        use_safetensors=False,
                                        fall_back_to_pt=False,
                                        revision=revision)

    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
163

JFDuan's avatar
JFDuan committed
164
165
166
167

def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Iterate over the `(name, weight)` pairs of a model checkpoint.

    Args:
        model_name_or_path: Local model directory or model identifier.
        cache_dir: Download cache directory; also where the lock file goes.
        load_format: One of
            "auto" (safetensors, falling back to `*.bin` / `*.pt`),
            "safetensors", "pt", or
            "npcache" (`*.bin` / `*.pt` files dumped once to numpy files
            under `<hf_folder>/np` and read back from there).
        revision: Revision forwarded to `prepare_hf_model_weights`.

    Yields:
        `(name, weight)` pairs. On the safetensors path the weight is the
        object returned by `get_slice`, not a materialized tensor.

    Raises:
        ValueError: If `load_format` is not one of the values above. As this
            is a generator, the error surfaces on first iteration.
    """
    if load_format not in ("auto", "safetensors", "pt", "npcache"):
        raise ValueError(f"Unknown load_format: {load_format}")
    use_np_cache = load_format == "npcache"
    want_safetensors = load_format in ("auto", "safetensors")
    allow_pt_fallback = load_format == "auto"

    hf_folder, weight_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        use_safetensors=want_safetensors,
        fall_back_to_pt=allow_pt_fallback,
        revision=revision)

    if use_np_cache:
        # Currently np_cache only support *.bin checkpoints
        assert use_safetensors is False

        # Dump every tensor to its own numpy file for faster loading. The
        # list of names is written last and doubles as the "cache exists"
        # marker.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        names_path = os.path.join(np_folder, "weight_names.json")
        # Hold the file lock so only one process dumps the weights.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(names_path):
                dumped_names = []
                for ckpt_path in weight_files:
                    state = torch.load(ckpt_path, map_location="cpu")
                    for name, param in state.items():
                        with open(os.path.join(np_folder, name), "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        dumped_names.append(name)
                with open(names_path, "w") as f:
                    json.dump(dumped_names, f)

        with open(names_path, "r") as f:
            cached_names = json.load(f)

        for name in cached_names:
            with open(os.path.join(np_folder, name), "rb") as f:
                array = np.load(f)
            yield name, torch.from_numpy(array)
        return

    if use_safetensors:
        for st_path in weight_files:
            with safe_open(st_path, framework="pt") as f:
                for name in f.keys():
                    yield name, f.get_slice(name)
        return

    for ckpt_path in weight_files:
        state = torch.load(ckpt_path, map_location="cpu")
        for name, param in state.items():
            yield name, param
        # Drop the loaded checkpoint before moving on to the next file.
        del state
        torch.cuda.empty_cache()
238
239


240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Turn a safetensors PySafeSlice into a torch.Tensor.

    A PySafeSlice can be indexed before the underlying data is read, which
    limits how much is loaded into memory, but it lacks tensor operations
    such as `.view()` or `.t()`. Call this before applying such operations.

    Args:
        x: A `torch.Tensor`, or an object that yields its full content when
            sliced with `[:]`.

    Returns:
        `x` itself when it is already a `torch.Tensor`, otherwise `x[:]`.
    """
    if isinstance(x, torch.Tensor):
        return x
    return x[:]


JFDuan's avatar
JFDuan committed
255
256
257
258
259
260
261
262
263
def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's shard of a checkpoint weight into `param`.

    The shard is rows `[rank * shard_size, (rank + 1) * shard_size)` of
    `loaded_weight`, where `shard_size` is `param.shape[0]`. The shard may
    have fewer rows than `param` (the slice is clipped at the end of
    `loaded_weight`); only the leading rows of `param` are then overwritten
    and the rest are left untouched.

    Args:
        param: Destination tensor, modified in place.
        loaded_weight: Full checkpoint weight (`torch.Tensor` or
            `PySafeSlice`), sliced along dimension 0.
        tensor_model_parallel_rank: Rank of this shard.
    """
    shard_size = param.shape[0]
    start_idx = tensor_model_parallel_rank * shard_size
    end_idx = (tensor_model_parallel_rank + 1) * shard_size
    # Slice first so a PySafeSlice only reads this rank's rows.
    loaded_weight = loaded_weight[start_idx:end_idx]
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    # Write through `.data`, as load_tensor_parallel_weights does: an in-place
    # copy into a view of a leaf Parameter that requires grad raises.
    param.data[:loaded_weight.shape[0]].copy_(loaded_weight)


268
269
def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    """Copy a checkpoint weight into `param`, sharding it when required.

    If `param_name` contains any of `column_parallel_weight_names`, the
    weight is sliced along dimension 0 using `param.shape[0]` as shard size.
    If it contains any of `row_parallel_weight_names`, the weight is sliced
    along dimension 1 using `param.shape[1]`. The two checks are independent.
    Otherwise the weight is copied whole.

    Args:
        param: Destination tensor; written through `param.data`.
        loaded_weight: Full checkpoint weight (`torch.Tensor` or
            `PySafeSlice`).
        param_name: Name of the parameter being loaded.
        column_parallel_weight_names: Substrings marking weights sharded
            along dimension 0.
        row_parallel_weight_names: Substrings marking weights sharded along
            dimension 1.
        tensor_model_parallel_rank: Rank of this shard.

    Raises:
        AssertionError: If the (sharded) weight's shape differs from
            `param.shape`.
    """
    rank = tensor_model_parallel_rank

    if any(key in param_name for key in column_parallel_weight_names):
        rows = param.shape[0]
        loaded_weight = loaded_weight[rank * rows:(rank + 1) * rows]

    if any(key in param_name for key in row_parallel_weight_names):
        cols = param.shape[1]
        loaded_weight = loaded_weight[:, rank * cols:(rank + 1) * cols]

    # Slicing is done before materializing so that a PySafeSlice only reads
    # the shard it needs.
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)
296
297
298
299
300
301
302


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.

    Args:
        model: Module whose `state_dict()` tensors are overwritten in place.
        low: Lower bound of the uniform distribution.
        high: Upper bound of the uniform distribution.
    """
    for param in model.state_dict().values():
        # state_dict() also contains buffers, which may be integer tensors
        # (e.g. counters). `uniform_` is only defined for floating-point
        # tensors and would raise on those, so leave them untouched.
        if torch.is_floating_point(param):
            param.data.uniform_(low, high)