prepare_weights.py

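"""Split a bigscience/bloom checkpoint into one state dict per tensor-parallel rank."""
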
import torch

from pathlib import Path
from tqdm import tqdm

MODEL_NAME = "bigscience/bloom"


def match_suffix(text, suffix):
    """Return True if `text` ends with `suffix` (suffixes used below are never empty)."""
    return text.endswith(suffix)


def prepare_weights(hub_path: Path, save_path: Path, tp_world_size: int):
    # One output file per tensor-parallel rank. Note that the "/" in MODEL_NAME
    # places the shards in a "bigscience/" subdirectory of save_path.
    save_paths = [
        save_path / f"{MODEL_NAME}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
    ]

    if all(p.exists() for p in save_paths):
        print("Weights are already prepared")
        return save_paths

    # One partial state dict per tensor-parallel rank, filled incrementally below
    shards_state_dicts = [{} for _ in range(tp_world_size)]
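    # Sharding scheme (Megatron-style tensor parallelism):
    #   - column-parallel parameters (query_key_value, dense_h_to_4h,
    #     word_embeddings, lm_head) are split along dim 0
    #   - row-parallel weights (self_attention.dense, dense_4h_to_h) are split
    #     along dim 1; their biases live on rank 0 only, presumably so they are
    #     added exactly once after the cross-rank reduction
    #   - every other parameter is replicated on all ranks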

    for weight_path in tqdm(hub_path.glob("*.bin")):
        # Load one hub shard at a time and free its entries as we go to keep
        # peak CPU memory bounded
        state_dict = torch.load(weight_path, map_location="cpu")

        keys = list(state_dict.keys())
        for state_name in keys:
            state = state_dict[state_name]
            # Column-parallel parameters: shard along the output dimension (dim 0)
            if any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.query_key_value.weight",
                    "self_attention.query_key_value.bias",
                    "mlp.dense_h_to_4h.weight",
                    "mlp.dense_h_to_4h.bias",
                    "word_embeddings.weight",
                    "lm_head.weight",
                ]
            ):
                output_size = state.shape[0]
                assert output_size % tp_world_size == 0
                block_size = output_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=0)
                assert len(sharded_weights) == tp_world_size
                for tp_rank, shard in enumerate(sharded_weights):
                    assert shard.shape[0] == block_size
                    # lm_head keeps its name; other parameters are re-keyed
                    # under the "transformer." prefix
                    if match_suffix(state_name, "lm_head.weight"):
                        shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
                    else:
                        shards_state_dicts[tp_rank][
                            "transformer." + state_name
                        ] = shard.detach().clone()
            # Row-parallel weights: shard along the input dimension (dim 1).
            # Note that "lm_head.weight" is also listed here, but it is already
            # caught by the dim-0 branch above, so this entry never matches.
            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.weight",
                    "mlp.dense_4h_to_h.weight",
                    "lm_head.weight",
                ]
            ):
                input_size = state.shape[1]
                assert input_size % tp_world_size == 0
                block_size = input_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=1)
                assert len(sharded_weights) == tp_world_size
                for tp_rank, shard in enumerate(sharded_weights):
                    assert shard.shape[1] == block_size
                    if match_suffix(state_name, "lm_head.weight"):
                        shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
                    else:
                        shards_state_dicts[tp_rank][
                            "transformer." + state_name
                        ] = shard.detach().clone()
            # Row-parallel biases: rank 0 keeps the actual values, the other
            # ranks get zeros
            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.bias",
                    "mlp.dense_4h_to_h.bias",
                ]
            ):
                shards_state_dicts[0][
                    "transformer." + state_name
                ] = state.detach().clone()
                for tp_rank in range(1, tp_world_size):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = torch.zeros_like(state)
            else:
                # We duplicate parameters across tp ranks
                for tp_rank in range(tp_world_size):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = state.detach().clone()

            del state_dict[state_name]  # delete key from state_dict
            del state  # delete tensor

    # Save one state dict per tensor-parallel rank
    for save_path, shard_state_dict in zip(save_paths, shards_state_dicts):
        save_path.parent.mkdir(parents=True, exist_ok=True)
        if save_path.exists():
            print(f"Skipping {save_path} as it already exists")
        else:
            torch.save(shard_state_dict, save_path)

    return save_paths


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()

    parser.add_argument("--hub-path", required=True, type=str)
    parser.add_argument("--save-path", required=True, type=str)
    parser.add_argument("--world-size", required=True, type=int)
    args = parser.parse_args()

    prepare_weights(Path(args.hub_path), Path(args.save_path), args.world_size)
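
# Example invocation (paths below are illustrative):
#   python prepare_weights.py \
#       --hub-path /path/to/local/bigscience/bloom/checkpoint \
#       --save-path /data/bloom-tp-shards \
#       --world-size 8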