# mixtral_checkpoint.py
import logging
import os
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.moe import MoECheckpintIO
from colossalai.tensor.moe_tensor.api import get_dp_rank, get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor


class MixtralMoECheckpointIO(MoECheckpintIO):
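    """Checkpoint IO for Mixtral MoE models trained with ColossalAI.

    Translates between HuggingFace Mixtral checkpoints, which store one tensor
    per expert, and the fused layout used here, where each rank's local experts
    are concatenated into a single tensor.
    """
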
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
        """
        Preprocess state_dict before loading and slice the state_dict of MOE tensors.
        """
        model_param_dict = dict(model.named_parameters())
        for name, param in list(state_dict.items()):
            if ".gate.weight" in name:
                new_name = "module." + name.replace(".gate.weight", ".gate_weight")
                state_dict[new_name] = state_dict.pop(name)
            elif ".experts." in name:
                # if is moe tensor
                # in our moe module, expert is cat as one tensor
                # but mixtral's experts is not cat
                # we will insert the loaded expert into the position of cat tensor

                # get model param
                str_idx = name.index(".experts.")
                expert_idx = int(name.split(".")[-3])
                if ".w1." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wi_gate")
                elif ".w2." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wo")
                elif ".w3." in name:
                    model_param_name = name.replace(name[str_idx:], ".experts.wi_up")
                model_param_name = "module." + model_param_name
                # skip for pipeline
                if model_param_name not in model_param_dict:
                    continue
                model_param = model_param_dict[model_param_name]
                assert is_moe_tensor(model_param)
                # get expert range
                ep_rank = get_ep_rank(model_param)
                ep_size = get_ep_size(model_param)
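                # Mixtral has 8 experts in total; each expert-parallel rank holds
                # a contiguous block of 8 // ep_size of them.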
                expert_num = 8 // ep_size
                expert_range = list(range(ep_rank * expert_num, (ep_rank + 1) * expert_num))
                # insert new param
                if expert_idx in expert_range:
                    new_param = model_param
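                    # HF nn.Linear stores weights as (out_features, in_features);
                    # transpose to the (in, out) layout of the fused expert tensor.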
                    new_param[expert_idx - ep_rank * expert_num] = param.transpose(0, 1)
                    state_dict[model_param_name] = new_param
                state_dict.pop(name)
            else:
                new_name = "module." + name
                state_dict[new_name] = state_dict.pop(name)

        dist.barrier()
        return state_dict

    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
        """
        Load a sharded model given the path to the index file of the checkpoint folder.

        Args:
            model (nn.Module): The model to be loaded.
            checkpoint_index_file (Path): Path to the index file of the checkpoint folder.
            strict (bool, optional): Whether to enforce strict name matching when
                                     loading the state_dict. Defaults to False and should
                                     stay False, since the parameters held by one device
                                     may be stored across several shard files.
        """

        # Check whether the checkpoint uses safetensors.
        use_safetensors = "safetensors" in checkpoint_index_file.name

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
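        # Force non-strict loading: the parameters held by one device may be
        # spread across several shard files, so per-file matching cannot be strict.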
        strict = False

        # Load params & buffers into the model.
        # Keep a record of loaded files so that no file is loaded more than once.
        loaded_file = set()

        def _load(name: str):
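            # Load the shard file that stores `name` (at most once), preprocess it
            # with pre_load_model, and merge it into the model.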
            if name not in weight_map:
                raise ValueError(f"{name} is not stored in the checkpoint; please check your checkpointing configuration!")
            filename = weight_map[name]

            # If this param/buffer has been loaded before, directly return.
            if filename in loaded_file:
                return

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
            state_dict = self.pre_load_model(model, state_dict)
            missing_keys = []

            load_state_dict_into_model(
                model,
                state_dict,
                missing_keys=missing_keys,
                strict=strict,
                load_sub_module=True,
            )
            loaded_file.add(filename)

        # Load parameters.
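        # Map each fused/wrapped parameter name back to the names used in the
        # original Mixtral checkpoint so its shard file can be located.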
        for name, _ in model.named_parameters():
            name = name.replace("module.", "")
            name = name.replace(".gate_weight", ".gate.weight")
            if ".experts.wi_gate" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                    _load(new_name)
            elif ".experts.wi_up" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                    _load(new_name)
            elif ".experts.wo" in name:
                for i in range(8):
                    new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                    _load(new_name)
            else:
                _load(name)

        if self.verbose:
            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

    @torch.no_grad()
    def pre_save_model(self, model: nn.Module) -> dict:
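        """
        Build a Mixtral-format state_dict from the sharded model: all-gather each
        fused expert tensor across the expert-parallel group, split it back into
        per-expert weights, and merge pipeline stages onto pp rank 0.
        """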
        torch.cuda.empty_cache()
        state_dict = model.state_dict()
        for name, param in list(model.named_parameters()):
            if ".gate_weight" in name:
                new_name = name.replace(".gate_weight", ".gate.weight")
                state_dict[new_name] = state_dict.pop(name).cpu()
            elif ".experts." in name:
                ep_group = get_ep_group(param)
                ep_rank = get_ep_rank(param)
                ep_size = get_ep_size(param)
                dp_rank = get_dp_rank(param)

                if dp_rank == 0:
                    param = param.data.cuda()
                    all_param = [torch.zeros_like(param) for _ in range(ep_size)]
                    # gather param from every ep rank
                    dist.all_gather(all_param, param, group=ep_group)
                    if ep_rank == 0:
                        all_param = torch.cat(all_param, dim=0)
                        assert all_param.shape[0] == 8
                        for i in range(8):
                            if ".wi_gate" in name:
                                new_name = name.replace(".experts.wi_gate", f".experts.{i}.w1.weight")
                            elif ".wi_up" in name:
                                new_name = name.replace(".experts.wi_up", f".experts.{i}.w3.weight")
                            elif ".wo" in name:
                                new_name = name.replace(".experts.wo", f".experts.{i}.w2.weight")
                            new_name = new_name.replace("module.", "")
                            new_param = all_param[i].transpose(-1, -2)
                            state_dict[new_name] = new_param.cpu()
                        state_dict.pop(name)
            else:
                state_dict[name] = param.cpu()

        # Strip the "module." prefix that the wrapper adds to the remaining keys.
        for name in list(state_dict.keys()):
            new_name = name.replace("module.", "")
            state_dict[new_name] = state_dict.pop(name)

        torch.cuda.empty_cache()
        if self.pp_size > 1:
            if self.dp_rank == 0:
                # Gather the state_dict from every pp rank. Because the checkpoint
                # is large, it is split into up to 30 groups of keys and gathered
                # one group at a time.
                new_state_dict = {}
                state_dict_keys = list(state_dict.keys())
                num_chunks = min(30, len(state_dict_keys))
                chunk_size = (len(state_dict_keys) + num_chunks - 1) // num_chunks
                for i in range(num_chunks):
                    cur_keys = state_dict_keys[i * chunk_size : (i + 1) * chunk_size]
                    cur_state_dict = {k: state_dict[k] for k in cur_keys}
                    out = [None for _ in range(self.pp_size)]
                    dist.all_gather_object(out, cur_state_dict, group=self.pp_group)
                    if self.pp_rank == 0:
                        for o in out:
                            for k, v in o.items():
                                new_state_dict[k] = v.cpu()
                state_dict = new_state_dict
        dist.barrier()
        return state_dict
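

# A minimal usage sketch (hypothetical wiring; the plugin arguments shown here
# are assumptions, so consult the ColossalMoE example in your ColossalAI version
# for the real entry point):
#
#     plugin = MoeHybridParallelPlugin(..., checkpoint_io=MixtralMoECheckpointIO)
#     booster = Booster(plugin=plugin)
#     model, optimizer, *_ = booster.boost(model, optimizer)
#     booster.load_model(model, "mistralai/Mixtral-8x7B-v0.1")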