import torch
import os
import tempfile
import json

from typing import BinaryIO
from joblib import Parallel, delayed
from functools import partial
from pathlib import Path
from tqdm import tqdm

from huggingface_hub import hf_hub_url
from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status


def match_suffix(text, suffix):
    return text[-len(suffix) :] == suffix


def http_get(
    url: str,
    temp_file: BinaryIO,
    *,
    timeout=10.0,
    max_retries=0,
):
    """
    Download a remote file. Do not gobble up errors, and will return errors tailored to the Hugging Face Hub.
    """
    r = _request_wrapper(
        method="GET",
        url=url,
        stream=True,
        timeout=timeout,
        max_retries=max_retries,
    )
    hf_raise_for_status(r)
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            temp_file.write(chunk)


def cache_download_url(url: str, root_dir: Path):
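    """
    Download `url` into `root_dir` if it is not already cached and return the
    local file path. The download goes through a temporary file and is moved
    into place with `os.replace`.
    """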
    filename = root_dir / url.split("/")[-1]

    if not filename.exists():
        temp_file_manager = partial(
            tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False
        )
        with temp_file_manager() as temp_file:
            http_get(url, temp_file)

        os.replace(temp_file.name, filename)
    return filename


def prepare_weights(
    model_name: str, cache_path: Path, save_path: Path, tp_world_size: int
):
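    """
    Download the checkpoint of `model_name` into `cache_path`, re-shard it for
    `tp_world_size` tensor parallel ranks, and save one state_dict per rank
    under `save_path`. Returns the list of shard file paths.
    """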
    save_paths = [
        save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
    ]

    if all(save_path.exists() for save_path in save_paths):
        print("Weights are already prepared")
        return save_paths

    cache_path.mkdir(parents=True, exist_ok=True)
    if model_name == "bigscience/bloom-560m":
        url = hf_hub_url(model_name, filename="pytorch_model.bin")
        cache_download_url(url, cache_path)

    elif model_name == "bigscience/bloom":
        url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
        index_path = cache_download_url(url, cache_path)
        with index_path.open("r") as f:
            index = json.load(f)

        # Get unique file names
        weight_files = list(set(index["weight_map"].values()))
        urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]

        Parallel(n_jobs=5)(
            delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    shards_state_dicts = [{} for _ in range(tp_world_size)]

    for weight_path in tqdm(Path(cache_path).glob("*.bin")):
        state_dict = torch.load(weight_path, map_location="cpu")

        keys = list(state_dict.keys())
        for state_name in keys:
            state = state_dict[state_name]
            if any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.query_key_value.weight",
                    "self_attention.query_key_value.bias",
                    "mlp.dense_h_to_4h.weight",
                    "mlp.dense_h_to_4h.bias",
                    "word_embeddings.weight",
                ]
            ):
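                # Column parallel style: shard along the output dimension (dim 0), one block per rank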
                output_size = state.shape[0]
                assert output_size % tp_world_size == 0
                block_size = output_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=0)
                assert len(sharded_weights) == tp_world_size

                for tp_rank, shard in enumerate(sharded_weights):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = shard.detach().clone()

            elif match_suffix(state_name, "lm_head.weight"):
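                # lm_head is sharded along dim 0 as well, but stored without the
                # "transformer." prefix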
                output_size = state.shape[0]
                assert output_size % tp_world_size == 0
                block_size = output_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=0)
                assert len(sharded_weights) == tp_world_size

                for tp_rank, shard in enumerate(sharded_weights):
                    shards_state_dicts[tp_rank][state_name] = shard.detach().clone()

            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.weight",
                    "mlp.dense_4h_to_h.weight",
                ]
            ):
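                # Row parallel style: shard along the input dimension (dim 1), one block per rank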
                input_size = state.shape[1]
                assert input_size % tp_world_size == 0
                block_size = input_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=1)
                assert len(sharded_weights) == tp_world_size
                for tp_rank, shard in enumerate(sharded_weights):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = shard.detach().clone()

            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.bias",
                    "mlp.dense_4h_to_h.bias",
                ]
            ):
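                # Keep the bias on rank 0 and store zeros on the other ranks so it
                # is only added once when the row parallel outputs are reduced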
                shards_state_dicts[0][
                    "transformer." + state_name
                ] = state.detach().clone()
                for tp_rank in range(1, tp_world_size):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = torch.zeros_like(state)

            else:
                # We duplicate parameters across tp ranks
                for tp_rank in range(tp_world_size):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = state.detach().clone()

            del state_dict[state_name]  # delete key from state_dict
            del state  # delete tensor
        del state_dict

    # Save one state_dict per tensor parallel rank
    for shard_path, shard_state_dict in zip(save_paths, shards_state_dicts):
        shard_path.parent.mkdir(parents=True, exist_ok=True)
        if shard_path.exists():
            print(f"Skipping {shard_path} as it already exists")
        else:
            torch.save(shard_state_dict, shard_path)

    return save_paths
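

# Minimal usage sketch (not part of the original module): the paths and the
# tensor parallel world size below are hypothetical placeholders.
if __name__ == "__main__":
    prepare_weights(
        "bigscience/bloom-560m",
        cache_path=Path("/data/bloom-cache"),  # hypothetical download cache
        save_path=Path("/data/bloom-shards"),  # hypothetical shard output dir
        tp_world_size=2,  # hypothetical number of tensor parallel ranks
    )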