weight_utils.py 4.5 KB
Newer Older
1
"""Utilities for downloading and initializing model weights."""
2
import filelock
3
4
import glob
import json
5
6
import os
from typing import Iterator, List, Optional, Tuple
7

8
from huggingface_hub import snapshot_download
9
import numpy as np
10
import torch
11
from tqdm.auto import tqdm
12

13
14

class Disabledtqdm(tqdm):
    """A ``tqdm`` progress bar that is always constructed disabled.

    Every positional and keyword argument is forwarded to ``tqdm`` unchanged,
    except ``disable``, which is forced to ``True``.
    """

    def __init__(self, *args, **kwargs):
        # Override the key instead of passing ``disable=True`` next to
        # ``**kwargs``: if the caller already supplies ``disable=...``, the
        # latter form raises "TypeError: got multiple values for keyword
        # argument 'disable'".
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


20
21
22
23
24
def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` pairs for every weight of a checkpoint.

    Args:
        model_name_or_path: Either an existing local directory holding
            ``*.bin`` checkpoint files, or an identifier that is handed to
            ``snapshot_download``.
        cache_dir: Forwarded to ``snapshot_download`` and, when given, also
            used as the directory for the lock file. When ``None`` the lock
            file is placed in the system temporary directory.
        use_np_cache: If true, every tensor is first dumped as a numpy file
            into an ``np`` sub-folder of the checkpoint folder (once, guarded
            by the lock) and then read back from there.

    Yields:
        ``(name, tensor)`` tuples. Checkpoint files are visited in sorted
        path order; ``training_args.bin`` is skipped.
    """
    # Function-scope import so the block does not depend on edits to the
    # file-level import section.
    import tempfile

    is_local = os.path.isdir(model_name_or_path)

    # The file lock prevents multiple processes from downloading or
    # converting the same model weights at the same time. It is only needed
    # on those two paths, so a plain local load never touches the lock dir.
    lock = None
    if use_np_cache or not is_local:
        lock_dir = (cache_dir
                    if cache_dir is not None else tempfile.gettempdir())
        if lock_dir:
            # An empty string means "current directory", which needs no
            # creation (and os.makedirs("") would raise).
            os.makedirs(lock_dir, exist_ok=True)
        lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
        lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))

    # Download model weights from huggingface unless they are already local.
    if is_local:
        hf_folder = model_name_or_path
    else:
        with lock:
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.bin",
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)

    # glob order is filesystem-dependent; sort so that the iteration order
    # is reproducible across runs and machines.
    hf_bin_files = sorted(
        x for x in glob.glob(os.path.join(hf_folder, "*.bin"))
        if not x.endswith("training_args.bin"))

    # NOTE(security): torch.load unpickles the file contents. Only point
    # this function at checkpoints from a trusted source.
    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with lock:
            # The names file is written last, so its presence marks a
            # finished conversion.
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_bin_files:
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                    # Drop this shard before the next one is loaded.
                    del state
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    else:
        for bin_file in hf_bin_files:
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            # Release the shard before loading the next one.
            del state
            torch.cuda.empty_cache()
81
82


83
84
85
86
87
88
89
90
def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's slice of a checkpoint tensor into ``param``.

    If ``param_name`` contains any entry of ``column_parallel_weight_names``
    the loaded tensor is sliced along dimension 0; if it contains any entry
    of ``row_parallel_weight_names`` it is sliced along dimension 1. The two
    checks are independent and applied in that order. The shard size is
    taken from ``param``'s own shape along the sliced dimension, and the
    slice offset is ``tensor_model_parallel_rank * shard_size``.

    Args:
        param: Destination tensor; written in place through ``param.data``.
        loaded_weight: Tensor read from the checkpoint.
        param_name: Name used for substring matching and in the error text.
        column_parallel_weight_names: Substrings selecting dim-0 sharding.
        row_parallel_weight_names: Substrings selecting dim-1 sharding.
        tensor_model_parallel_rank: Index of the shard to take.

    Raises:
        AssertionError: If, after slicing, the loaded tensor's shape differs
            from ``param``'s shape.
    """
    if any(p in param_name for p in column_parallel_weight_names):
        shard_size = param.shape[0]
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        loaded_weight = loaded_weight[start_idx:end_idx]

    if any(p in param_name for p in row_parallel_weight_names):
        shard_size = param.shape[1]
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        loaded_weight = loaded_weight[:, start_idx:end_idx]

    # Raised explicitly (same exception type as the former ``assert``) so the
    # check survives ``python -O``; without it ``copy_`` could silently
    # broadcast a wrongly shaped checkpoint tensor.
    if param.shape != loaded_weight.shape:
        raise AssertionError(
            f"{param_name} shape mismatch between model and checkpoint: "
            f"{param.shape} != {loaded_weight.shape}")
    param.data.copy_(loaded_weight)
109
110
111
112
113
114
115


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.

    Args:
        model: Module whose ``state_dict()`` tensors are overwritten in place.
        low: Lower bound passed to ``uniform_``.
        high: Upper bound passed to ``uniform_``.
    """
    for param in model.state_dict().values():
        # state_dict() also contains buffers, some of which are integer
        # tensors (e.g. BatchNorm's ``num_batches_tracked``). ``uniform_``
        # raises on those, so only floating-point tensors are filled.
        if torch.is_floating_point(param):
            param.data.uniform_(low, high)