"""Utilities for downloading and initializing model weights."""
import glob
import json
import os
import tempfile
from collections import defaultdict
from typing import Iterator, List, Optional, Tuple, Any

import filelock
import numpy as np
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open
from tqdm.auto import tqdm

from vllm.logger import init_logger

# Module-level logger for this file.
logger = init_logger(__name__)


class Disabledtqdm(tqdm):
    """A ``tqdm`` subclass whose progress bar is always disabled.

    Used as ``tqdm_class`` for ``snapshot_download`` in this module so that
    weight downloads do not print progress bars.
    """

    def __init__(self, *args, **kwargs):
        # Override rather than add ``disable``: passing it both inside
        # ``**kwargs`` and as an explicit keyword would raise
        # ``TypeError: got multiple values for keyword argument 'disable'``
        # whenever the caller already supplies it. ``kwargs`` is a fresh
        # dict, so this does not mutate anything the caller owns.
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Return a cross-process file lock dedicated to one model.

    Args:
        model_name_or_path: Hugging Face repo id or local path. Every "/" is
            replaced by "-" to form a flat lock-file name.
        cache_dir: Directory that holds the lock file. Defaults to the system
            temporary directory.

    Returns:
        A ``filelock.FileLock`` that has not been acquired yet; callers use it
        as a context manager (``with get_lock(...):``).
    """
    lock_dir = cache_dir if cache_dir is not None else tempfile.gettempdir()
    # Make sure the directory exists so creating the lock file cannot fail
    # with FileNotFoundError (e.g. a fresh ``cache_dir`` that nothing has
    # created yet). An empty string means "current directory"; leave it be.
    if lock_dir:
        os.makedirs(lock_dir, exist_ok=True)
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    return filelock.FileLock(os.path.join(lock_dir, lock_file_name))


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
):
    """Convert a PyTorch ``.bin`` checkpoint into a ``.safetensors`` file.

    If the checkpoint has a top-level ``"state_dict"`` key, only that entry
    is converted. When several names point at the same tensor memory, only
    the first name is kept. The result is validated by file size and by
    reloading it and comparing every tensor.

    Args:
        pt_filename: Path of the source PyTorch checkpoint.
        sf_filename: Path of the safetensors file to write; its parent
            directory is created if needed.

    Raises:
        RuntimeError: if the safetensors file is more than 1% larger than the
            source file, or if a reloaded tensor differs from the original.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects; only call
    # this on checkpoints from a trusted source.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    # Presumably needed because safetensors will not save tensors that share
    # memory: keep the first name of each aliased group, drop the others.
    for shared_weights in _shared_pointers(loaded):
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    # A bare file name has dirname "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # check file size (only growth beyond 1% is treated as an error)
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size difference is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k, pt_tensor in loaded.items():
        if not torch.equal(pt_tensor, reloaded[k]):
            raise RuntimeError(f"The output tensors do not match for key {k}")


def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_safetensors: bool = False,
    fall_back_to_pt: bool = True,
    revision: Optional[str] = None,
):
    """Locate (downloading if necessary) a model's weight files.

    Args:
        model_name_or_path: Local directory, or a Hugging Face repo id that is
            downloaded with ``snapshot_download``.
        cache_dir: Download cache directory; also where the download lock
            file is placed.
        use_safetensors: Look for ``*.safetensors`` instead of ``*.bin``.
        fall_back_to_pt: With ``use_safetensors``, retry with ``*.bin`` files
            when no safetensors files are found.
        revision: Hub revision passed to ``snapshot_download``.

    Returns:
        ``(folder, weight_files, use_safetensors)``. ``weight_files`` is
        sorted, and ``use_safetensors`` reports which format was actually
        found (it is ``False`` after a fallback).

    Raises:
        RuntimeError: if no weight files are found.
    """
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path)
    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
    if not is_local:
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=allow_patterns,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm,
                                          revision=revision)
    else:
        hf_folder = model_name_or_path
    # Sort so the load order does not depend on directory listing order.
    hf_weights_files = sorted(
        glob.glob(os.path.join(hf_folder, allow_patterns)))
    if not use_safetensors:
        # training_args.bin matches "*.bin" but holds no model weights.
        hf_weights_files = [
            x for x in hf_weights_files if not x.endswith("training_args.bin")
        ]

    if not hf_weights_files and use_safetensors and fall_back_to_pt:
        return prepare_hf_model_weights(model_name_or_path,
                                        cache_dir=cache_dir,
                                        use_safetensors=False,
                                        fall_back_to_pt=False,
                                        revision=revision)

    if not hf_weights_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, weight)`` pairs from a model checkpoint.

    Args:
        model_name_or_path: Hugging Face repo id or local directory.
        cache_dir: Download cache directory (also where file locks live).
        load_format: One of
            - ``"auto"``: prefer ``*.safetensors``, fall back to ``*.bin``;
            - ``"safetensors"``: ``*.safetensors`` only;
            - ``"pt"``: ``*.bin`` only;
            - ``"npcache"``: ``*.bin``, dumped once to per-tensor NumPy files
              under ``<model folder>/np`` and read back from there.
        revision: Hub revision passed through to the download.

    Yields:
        ``(name, weight)``. For safetensors checkpoints ``weight`` is a lazy
        safetensors slice (convert with ``convert_pyslice_to_tensor``);
        otherwise it is a CPU ``torch.Tensor``.

    Raises:
        ValueError: for an unknown ``load_format``. Because this is a
            generator, it surfaces on the first ``next()``.
    """
    use_safetensors = False
    use_np_cache = False
    fall_back_to_pt = False
    if load_format == "auto":
        use_safetensors = True
        fall_back_to_pt = True
    elif load_format == "safetensors":
        use_safetensors = True
    elif load_format == "pt":
        pass
    elif load_format == "npcache":
        use_np_cache = True
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
        model_name_or_path,
        cache_dir=cache_dir,
        use_safetensors=use_safetensors,
        fall_back_to_pt=fall_back_to_pt,
        revision=revision)

    if use_np_cache:
        # Currently np_cache only supports *.bin checkpoints
        assert use_safetensors is False

        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        # Use file lock to prevent multiple processes from
        # dumping the same model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            # weight_names.json is written only after every tensor has been
            # dumped, so its existence marks a finished cache.
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    # NOTE(review): torch.load unpickles arbitrary objects.
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                    # Release this shard before loading the next one.
                    del state
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():
                    # get_slice defers reading the data; consumers slice out
                    # their shard first and then materialize it.
                    param = f.get_slice(name)
                    yield name, param
    else:
        for bin_file in hf_weights_files:
            # NOTE(review): torch.load unpickles arbitrary objects.
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()


def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Convert a safetensors ``PySafeSlice`` object into a ``torch.Tensor``.

    A ``PySafeSlice`` supports indexing before the tensor is actually read,
    which can reduce the amount of data loaded into memory. It does not
    support more advanced operations such as ``.view()`` or ``.t()``, so a
    loaded weight must be converted to a real tensor before such operations
    are applied.

    Args:
        x: A ``torch.Tensor`` (returned unchanged) or any object supporting
            ``x[:]`` (e.g. a ``PySafeSlice``).

    Returns:
        ``x`` itself if it is already a tensor, otherwise ``x[:]``.
    """
    if not isinstance(x, torch.Tensor):
        # Full slice materializes the lazily-read data.
        x = x[:]
    return x


def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's slice (along dim 0) of a vocab weight into ``param``.

    With ``n = param.shape[0]``, rows ``[rank * n, (rank + 1) * n)`` of
    ``loaded_weight`` are copied into the leading rows of ``param``. If the
    checkpoint has fewer rows in that range (e.g. the last shard of a
    vocabulary that ``param`` is presumably padded beyond), the remaining
    rows of ``param`` keep their current values.

    Args:
        param: Destination shard, written in place.
        loaded_weight: Full checkpoint weight, as a tensor or lazy slice.
        tensor_model_parallel_rank: Index of this rank's shard.
    """
    shard_size = param.shape[0]
    start_idx = tensor_model_parallel_rank * shard_size
    end_idx = (tensor_model_parallel_rank + 1) * shard_size
    # Slice before converting so a lazy slice only reads this rank's rows.
    loaded_weight = loaded_weight[start_idx:end_idx]
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    # Write through ``.data`` (as load_tensor_parallel_weights does) so an
    # nn.Parameter that requires grad can be filled in place without autograd
    # rejecting the in-place op on a view of a leaf.
    param.data[:loaded_weight.shape[0]].copy_(loaded_weight)


def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's shard of a checkpoint weight into ``param``.

    Both name lists are checked independently by substring match against
    ``param_name``:

    - a column-parallel match keeps rows ``[rank * n, (rank + 1) * n)``,
      with ``n = param.shape[0]``;
    - a row-parallel match keeps columns ``[rank * n, (rank + 1) * n)``,
      with ``n = param.shape[1]``.

    With no match the whole weight is copied.

    Args:
        param: Destination tensor, written in place.
        loaded_weight: Full checkpoint weight, as a tensor or lazy slice.
        param_name: Name used for matching and in the error message.
        column_parallel_weight_names: Substrings marking dim-0-sharded weights.
        row_parallel_weight_names: Substrings marking dim-1-sharded weights.
        tensor_model_parallel_rank: Index of this rank's shard.

    Raises:
        AssertionError: if the (sharded) checkpoint weight's shape differs
            from ``param.shape``.
    """
    for p in column_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[0]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[start_idx:end_idx]
            break
    for p in row_parallel_weight_names:
        if p in param_name:
            shard_size = param.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            loaded_weight = loaded_weight[:, start_idx:end_idx]
            break

    # Slicing above happens before conversion so a lazy slice only reads
    # this rank's portion of the checkpoint.
    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.

    Only floating-point entries of ``model.state_dict()`` (parameters and
    buffers) are overwritten. Integer and bool buffers (e.g. index tables)
    are left untouched, because ``uniform_`` is not defined for them.

    Args:
        model: Module whose tensors are filled in place.
        low: Lower bound of the uniform distribution.
        high: Upper bound of the uniform distribution.
    """
    # state_dict() tensors share storage with the module's own tensors, so
    # the in-place fill updates the model itself.
    for param in model.state_dict().values():
        if torch.is_floating_point(param):
            param.data.uniform_(low, high)