"""Utilities for downloading and initializing model weights."""
import glob
import json
import os
from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple

import filelock
import numpy as np
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm

from vllm.logger import init_logger

logger = init_logger(__name__)


class Disabledtqdm(tqdm):
    """A tqdm subclass that never renders a progress bar.

    Passed as ``tqdm_class`` to ``snapshot_download`` so weight downloads
    stay quiet (e.g. when many workers download concurrently).
    """

    def __init__(self, *args, **kwargs):
        # Force-disable the bar even when the caller supplies its own
        # ``disable`` value; forwarding ``disable=True`` as an extra keyword
        # would raise TypeError on a duplicate keyword argument instead.
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
    """Create an inter-process file lock for the given model.

    The lock serializes downloads/conversions of the same model weights
    across processes.

    Args:
        model_name_or_path: HF model id or local path; "/" is replaced so
            the id can be used as a file name.
        cache_dir: Directory for the lock file; defaults to the system
            temporary directory.

    Returns:
        A ``filelock.FileLock`` instance (not yet acquired).
    """
    import tempfile  # local import: keeps the module import surface unchanged

    # Previously hard-coded to "/tmp", which is not portable; respect
    # TMPDIR and platform conventions via tempfile.gettempdir().
    lock_dir = cache_dir if cache_dir is not None else tempfile.gettempdir()
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    return lock


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing


def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a PyTorch ``.bin`` checkpoint into a ``.safetensors`` file.

    Args:
        pt_filename: Path to the source ``torch.save`` checkpoint.
        sf_filename: Destination path for the safetensors file.

    Raises:
        RuntimeError: If the converted file is more than 1% larger than the
            source, or if any tensor fails the round-trip equality check.
    """
    # NOTE(security): torch.load deserializes via pickle and can execute
    # arbitrary code; only call this on checkpoints from trusted sources.
    loaded = torch.load(pt_filename, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    # safetensors cannot store aliased tensors: keep one name per storage.
    shared = _shared_pointers(loaded)
    for shared_weights in shared:
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    # dirname is "" for a bare file name; os.makedirs("") would raise.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # check file size
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size difference is more than 1%:
         - {sf_filename}: {sf_size}
         - {pt_filename}: {pt_size}
         """)

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_safetensor: bool = False,
):
    """Make the model weights available locally and enumerate them.

    When ``model_name_or_path`` is not a local directory, the checkpoint is
    fetched from the HuggingFace Hub (under a file lock). If safetensors
    files were requested but none are found, falls back to ``*.bin`` files.

    Returns:
        Tuple of (weights folder, list of weight file paths, whether the
        returned files are safetensors).
    """
    pattern = "*.safetensors" if use_safetensor else "*.bin"

    if os.path.isdir(model_name_or_path):
        hf_folder = model_name_or_path
    else:
        # Serialize concurrent downloads of the same model across
        # processes with a file lock.
        with get_lock(model_name_or_path, cache_dir):
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns=pattern,
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)

    weight_files = glob.glob(os.path.join(hf_folder, pattern))
    if not use_safetensor:
        # ``training_args.bin`` holds trainer state, not model weights.
        weight_files = [
            f for f in weight_files if not f.endswith("training_args.bin")
        ]

    if use_safetensor and not weight_files:
        logger.warning("No *.safetensors files found, "
                       "fall back to *.bin files")
        return prepare_hf_model_weights(model_name_or_path,
                                        cache_dir=cache_dir,
                                        use_safetensor=False)

    return hf_folder, weight_files, use_safetensor


def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
    use_safetensor: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, weight)`` pairs for every tensor of a HF checkpoint.

    Depending on the flags, weights come from a numpy on-disk cache, from
    safetensors files (as lazy slices), or from plain ``*.bin`` files.
    """
    hf_folder, hf_weights_files, use_safetensor = prepare_hf_model_weights(
        model_name_or_path, cache_dir=cache_dir, use_safetensor=use_safetensor)

    if use_np_cache:
        # The numpy cache currently only supports *.bin checkpoints.
        assert use_safetensor is False

        # Dump each tensor to its own .npy file once; later runs load the
        # numpy arrays directly, which is faster than torch.load.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")

        # A file lock prevents several processes from dumping the same
        # model weights to numpy at the same time.
        with get_lock(model_name_or_path, cache_dir):
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_weights_files:
                    state_dict = torch.load(bin_file, map_location="cpu")
                    for name, tensor in state_dict.items():
                        with open(os.path.join(np_folder, name), "wb") as f:
                            np.save(f, tensor.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            with open(os.path.join(np_folder, name), "rb") as f:
                array = np.load(f)
            yield name, torch.from_numpy(array)
    elif use_safetensor:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():
                    # Yield lazy slices; callers materialize as needed.
                    yield name, f.get_slice(name)
    else:
        for bin_file in hf_weights_files:
            state_dict = torch.load(bin_file, map_location="cpu")
            for name, tensor in state_dict.items():
                yield name, tensor
            # Free the shard's host memory before loading the next one.
            del state_dict
            torch.cuda.empty_cache()


def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's vocab shard of ``loaded_weight`` into ``param``.

    ``param`` may be padded beyond the checkpoint's vocabulary; only the
    leading rows covered by the checkpoint shard are overwritten, any
    padding rows keep their current values.
    """
    rows_per_rank = param.shape[0]
    begin = rows_per_rank * tensor_model_parallel_rank
    shard = loaded_weight[begin:begin + rows_per_rank]

    # Materialize a PySafeSlice into a real tensor.
    if not isinstance(shard, torch.Tensor):
        shard = shard[:]

    param[:shard.shape[0]].copy_(shard)


def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    """Slice out this rank's shard of ``loaded_weight`` and copy it into ``param``.

    Weights whose name matches an entry in ``column_parallel_weight_names``
    are sharded along dim 0; matches in ``row_parallel_weight_names`` are
    sharded along dim 1. A name matching neither list is copied whole.
    """
    rank = tensor_model_parallel_rank
    if any(p in param_name for p in column_parallel_weight_names):
        size = param.shape[0]
        loaded_weight = loaded_weight[size * rank:size * (rank + 1)]
    if any(p in param_name for p in row_parallel_weight_names):
        size = param.shape[1]
        loaded_weight = loaded_weight[:, size * rank:size * (rank + 1)]

    # Materialize a PySafeSlice into a real tensor.
    if not isinstance(loaded_weight, torch.Tensor):
        loaded_weight = loaded_weight[:]

    assert param.shape == loaded_weight.shape, (
        f"{param_name} shape mismatch between model and checkpoint: "
        f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.
    """
    for tensor in model.state_dict().values():
        tensor.data.uniform_(low, high)