weight_utils.py 3.8 KB
Newer Older
1
import filelock
2
3
import glob
import json
4
5
import os
from typing import Iterator, List, Optional, Tuple
6

7
from huggingface_hub import snapshot_download
8
import numpy as np
9
import torch
10
from tqdm.auto import tqdm
11

12
13

class Disabledtqdm(tqdm):
    """A ``tqdm`` progress bar that is always disabled.

    Intended to be handed to APIs that accept a ``tqdm_class`` argument so
    that they stay silent.
    """

    def __init__(self, *args, **kwargs):
        # Override instead of appending ``disable=True`` to the call: if the
        # caller already supplies ``disable=...`` a second keyword of the same
        # name would raise ``TypeError`` before the bar is even constructed.
        kwargs["disable"] = True
        super().__init__(*args, **kwargs)


19
20
21
22
23
def hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` pairs from a model's ``*.bin`` checkpoints.

    If ``model_name_or_path`` is an existing directory it is read as is;
    otherwise it is handed to ``snapshot_download`` (restricted to ``*.bin``
    files) while holding a file lock, so that concurrent processes do not
    download the same weights at the same time.

    Being a generator, nothing (including the download) happens until the
    first item is requested.

    Args:
        model_name_or_path: Local checkpoint directory, or an identifier
            understood by ``snapshot_download``.
        cache_dir: Forwarded to ``snapshot_download``. It is also where the
            lock file lives; ``/tmp`` is used for the lock when it is None.
        use_np_cache: When true, every tensor is converted once to a NumPy
            file under ``<checkpoint folder>/np`` and later calls read from
            that cache instead of calling ``torch.load``.

    Yields:
        The parameter name and its tensor, loaded on CPU.
    """
    # Prepare file lock directory to prevent multiple processes from
    # downloading the same model weights at the same time.
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))

    is_local = os.path.isdir(model_name_or_path)
    if not is_local or use_np_cache:
        # The lock is about to be acquired. Its file can only be created if
        # the parent directory exists, which is not guaranteed for a
        # cache_dir that has never been used before.
        os.makedirs(lock_dir, exist_ok=True)

    # Download model weights from huggingface.
    if is_local:
        hf_folder = model_name_or_path
    else:
        with lock:
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.bin",
                                          cache_dir=cache_dir,
                                          tqdm_class=Disabledtqdm)

    # Checkpoint folders produced by training runs can contain ``*.bin``
    # files that are not model state dicts; loading them would either fail
    # on ``.items()`` or yield entries that are not weights.
    non_weight_files = ("training_args.bin", "optimizer.bin")
    # Sorted so that the iteration order does not depend on the file system.
    hf_bin_files = sorted(
        path for path in glob.glob(os.path.join(hf_folder, "*.bin"))
        if os.path.basename(path) not in non_weight_files)

    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, 'np')
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, 'weight_names.json')
        with lock:
            # The names file is written last, so its presence marks a
            # completed conversion.
            if not os.path.exists(weight_names_file):
                weight_names = []
                for bin_file in hf_bin_files:
                    # NOTE(security): torch.load unpickles the file; only
                    # use checkpoints from sources you trust.
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                    # Drop the shard before the next one is loaded.
                    del state
                with open(weight_names_file, 'w') as f:
                    json.dump(weight_names, f)

        with open(weight_names_file, 'r') as f:
            weight_names = json.load(f)

        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    else:
        for bin_file in hf_bin_files:
            # NOTE(security): torch.load unpickles the file; only use
            # checkpoints from sources you trust.
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            # Drop the shard before the next one is loaded.
            del state


77
78
79
80
81
82
83
84
def load_tensor_parallel_weights(
    param: torch.Tensor,
    loaded_weight: torch.Tensor,
    param_name: str,
    column_parallel_weight_names: List[str],
    row_parallel_weight_names: List[str],
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's shard of ``loaded_weight`` into ``param`` in place.

    The shard size is taken from ``param`` itself. A parameter whose name
    contains any entry of ``column_parallel_weight_names`` is sliced along
    dim 0; one whose name contains any entry of
    ``row_parallel_weight_names`` is sliced along dim 1. A name matching
    neither list is copied whole, and a name matching both is sliced along
    both dims.

    Args:
        param: Destination tensor holding this rank's shard.
        loaded_weight: Full (unsharded) tensor read from the checkpoint.
        param_name: Name of the parameter, matched by substring.
        column_parallel_weight_names: Substrings marking dim-0 sharding.
        row_parallel_weight_names: Substrings marking dim-1 sharding.
        tensor_model_parallel_rank: Index of the shard to select.

    Raises:
        AssertionError: If the selected slice does not have ``param``'s
            shape.
    """
    rank = tensor_model_parallel_rank

    if any(key in param_name for key in column_parallel_weight_names):
        shard_size = param.shape[0]
        begin = shard_size * rank
        loaded_weight = loaded_weight[begin:begin + shard_size]

    if any(key in param_name for key in row_parallel_weight_names):
        shard_size = param.shape[1]
        begin = shard_size * rank
        loaded_weight = loaded_weight[:, begin:begin + shard_size]

    # Name the offending parameter: with hundreds of weights in a
    # checkpoint a bare assertion gives no hint which one is wrong.
    assert param.shape == loaded_weight.shape, (
        f"Shape mismatch for {param_name!r}: parameter has shape "
        f"{tuple(param.shape)} but the loaded weight shard has shape "
        f"{tuple(loaded_weight.shape)}")
    param.data.copy_(loaded_weight)
102
103
104
105
106
107
108
109
110


def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
) -> None:
    """Fill the model's floating-point tensors with uniform random values.

    Every floating-point tensor in ``model.state_dict()`` is overwritten in
    place with samples drawn by ``uniform_(low, high)``.

    Args:
        model: Module whose state is overwritten.
        low: Lower bound passed to ``uniform_``.
        high: Upper bound passed to ``uniform_``.
    """
    for tensor in model.state_dict().values():
        # The state dict also holds buffers, which may be integer or bool
        # typed (e.g. step counters); ``uniform_`` cannot fill those, so
        # they are left untouched.
        if torch.is_floating_point(tensor):
            tensor.data.uniform_(low, high)