# test_moe_checkpoint.py: save/load round-trip tests for MoE model and optimizer checkpoints.
1
import importlib
2
import os
3
4
import shutil
import sys
5
6
7
8

import pytest
import torch
import torch.distributed as dist
9
from transformers.models.llama import LlamaConfig
10
11

import colossalai
12
from colossalai.accelerator import get_accelerator
13
14
15
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.moe.manager import MOE_MANAGER
16
from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
17

# Make the OpenMoE example importable as the ``model`` package. It lives under
# ``examples/language/openmoe`` relative to the directory two levels above the
# one containing this file (presumably the repository root — confirm if moved).
sys.path.append(
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "examples/language/openmoe",
    )
)

# Resolve the example's classes dynamically, because ``model`` is only on
# sys.path after the append above.
_modeling_openmoe = importlib.import_module("model.modeling_openmoe")
OpenMoeForCausalLM = _modeling_openmoe.OpenMoeForCausalLM
set_openmoe_args = _modeling_openmoe.set_openmoe_args
OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy


def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20, device=None):
    """Generate one random causal-LM batch.

    Args:
        batch_size: Number of sequences in the batch.
        max_length: Number of tokens per sequence.
        vocab_size: Token ids are drawn uniformly from ``[0, vocab_size)``.
        device: Device for the generated tensors. Defaults to the
            accelerator's current device.

    Returns:
        A dict with ``input_ids`` and ``labels`` (the same
        ``(batch_size, max_length)`` tensor object) and an all-ones
        ``attention_mask`` of the same shape.
    """
    if device is None:
        device = get_accelerator().get_current_device()
    input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=device)
    attention_mask = torch.ones_like(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": input_ids,
    }


def run_fwd_bwd(
    model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
):
    """Run one training forward/backward pass on ``model``.

    Args:
        model: The (possibly boosted) model; switched to train mode.
        data: Input passed to ``model`` on the non-pipeline path (ignored when
            ``pipeline`` is True).
        label: Second positional argument to ``model`` when no ``criterion``
            is given.
        criterion: Optional callable mapping ``model(data).logits`` to a loss.
            When ``None``, ``model(data, label)`` must return the loss itself.
        optimizer: If not ``None``, its ``backward(loss)`` is used; otherwise
            ``loss.backward()`` is called directly.
        enable_autocast: Unused; kept for interface compatibility.
        pipeline: If True, run through ``booster.execute_pipeline`` on a
            one-batch ``DummyDataloader`` built from ``data_gen_fn``.
        booster: Booster used on the pipeline path.
        plugin: Unused; kept for interface compatibility.

    Returns:
        On the pipeline path, whatever ``booster.execute_pipeline`` returns.
        Otherwise the logits (with ``criterion``) or the model-computed loss
        (without).
    """
    model.train()
    if pipeline:
        # execute_pipeline is asked for loss and outputs; no further backward
        # call is made here (presumably the pipeline schedule performs it).
        train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
        return booster.execute_pipeline(
            train_dataloader_iter,
            model,
            lambda x, y: x.loss,
            optimizer,
            return_loss=True,
            return_outputs=True,
        )

    if criterion is not None:
        y = model(data).logits
        loss = criterion(y)
    else:
        # The model computes the loss itself. Bind it to ``y`` too so the
        # function has something to return (previously an UnboundLocalError).
        loss = y = model(data, label)
    loss = loss.float()

    if optimizer is not None:
        optimizer.backward(loss)
    else:
        loss.backward()
    return y


def get_config():
    """Return a tiny OpenMoE-flavoured ``LlamaConfig`` for fast tests.

    Two hidden layers, hidden size 16 and two attention heads keep the model
    small. ``set_openmoe_args`` then adds MoE settings (8 experts,
    ``moe_layer_interval=1``) on top of the base config.
    """
    config = LlamaConfig(
        vocab_size=300,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        head_dim=4,
        dropout_rate=0.0,
        hidden_act="swiglu",
    )
    set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
    return config


def get_model(parallel):
    """Build an OpenMoE model plus Adam optimizer and boost them.

    The caller is expected to have configured ``MOE_MANAGER`` for the same
    mode beforehand (as ``_test_moe_checkpoint`` does).

    Args:
        parallel: Parallel mode, one of ``None``, ``"ep"``, ``"ep_zero"`` or
            ``"hybrid"``. Selects the ``MoeHybridParallelPlugin`` settings.

    Returns:
        ``(model, booster, optimizer)``, where ``model`` and ``optimizer`` are
        the objects returned by ``Booster.boost``.

    Raises:
        ValueError: If ``parallel`` is not a recognized mode.
    """
    # Plugin settings that differ between modes. precision, tp_size and the
    # custom policy are shared and supplied below.
    mode_kwargs = {
        None: dict(pp_size=1, zero_stage=2),
        "ep": dict(pp_size=1, zero_stage=2),
        "ep_zero": dict(pp_size=1, zero_stage=2, extra_dp_size=2),
        "hybrid": dict(pp_size=2, zero_stage=1, microbatch_size=1),
    }
    if parallel not in mode_kwargs:
        raise ValueError(f"Unsupported parallel mode: {parallel!r}")

    config = get_config()
    model = OpenMoeForCausalLM(config)
    optim = torch.optim.Adam(model.parameters())

    plugin = MoeHybridParallelPlugin(
        precision="bf16",
        tp_size=1,
        custom_policy=OpenMoeForCausalLMPolicy(),
        **mode_kwargs[parallel],
    )
    booster = Booster(plugin=plugin)
    model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
    return model, booster, optim


def _setup_moe_manager(parallel):
    """Configure the global ``MOE_MANAGER`` for the requested parallel mode.

    Args:
        parallel: One of ``None``, ``"ep"``, ``"ep_zero"`` or ``"hybrid"``.

    Raises:
        ValueError: If ``parallel`` is not a recognized mode.
    """
    if parallel is None:
        MOE_MANAGER.setup(parallel=None)
    elif parallel == "ep":
        MOE_MANAGER.setup(parallel="EP")
    elif parallel == "ep_zero":
        MOE_MANAGER.setup(parallel="EP", max_ep_size=2)
    elif parallel == "hybrid":
        MOE_MANAGER.setup(
            parallel="EP",
            mode="fixed",
            fixed_dp_size=1,
            fixed_ep_size=2,
            fixed_pp_size=2,
        )
    else:
        raise ValueError(f"Unsupported parallel mode: {parallel!r}")


def _test_moe_checkpoint(rank, parallel):
    """Round-trip model and optimizer checkpoints, sharded and unsharded.

    Three identically configured boosted models are built. Checkpoints saved
    from model/optimizer 1 are loaded into 2 (sharded) and 3 (unsharded), and
    the resulting state dicts are compared against 1.

    Args:
        rank: Unused; the rank is read from ``torch.distributed`` instead.
        parallel: Parallel mode forwarded to ``MOE_MANAGER`` and ``get_model``.
    """
    _setup_moe_manager(parallel)

    model1, booster1, optim1 = get_model(parallel)
    model2, booster2, optim2 = get_model(parallel)
    model3, booster3, optim3 = get_model(parallel)

    # --- parameter checkpoint ---
    # Barriers make sure every rank has finished writing before any rank reads.
    booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
    dist.barrier()
    booster2.load_model(model2, "./tmp_ckpt1")
    booster1.save_model(model1, "./tmp_ckpt1.pth")
    dist.barrier()
    booster3.load_model(model3, "./tmp_ckpt1.pth")
    check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
    check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)

    # --- optimizer checkpoint ---
    # Take one optimizer step first so the optimizer has state to save.
    criterion = lambda x: x.mean()
    device = get_accelerator().get_current_device()
    data = torch.randint(0, 4, (2, 4)).to(device)
    label = torch.randint(0, 4, (2,)).to(device)
    if parallel == "hybrid":
        kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
    else:
        kwargs = {}
    run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
    optim1.step()
    optim1.zero_grad()

    booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
    dist.barrier()
    booster2.load_optimizer(optim2, "./tmp_ckpt2")
    booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
    dist.barrier()
    booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
    check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
    check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)

    # Every rank must be done reading the checkpoints before rank 0 deletes them.
    dist.barrier()
    if dist.get_rank() == 0:
        shutil.rmtree("./tmp_ckpt1")
        shutil.rmtree("./tmp_ckpt2")
        os.remove("./tmp_ckpt1.pth")
        os.remove("./tmp_ckpt2.pth")


def _run_dist(rank, world_size, port, parallel):
    """Per-process entry point used by ``spawn``.

    Launches ColossalAI for this rank (NCCL backend, ``localhost:port``), then
    runs the checkpoint round-trip test for the given parallel mode.
    """
    colossalai.launch(
        config=dict(),
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    _test_moe_checkpoint(rank, parallel)


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@rerun_if_address_is_in_use()
def test_moe_checkpoint(world_size, parallel):
    """Spawn ``world_size`` processes and round-trip MoE checkpoints in each mode."""
    spawn(_run_dist, world_size, parallel=parallel)


if __name__ == "__main__":
    # Manual entry point: run only the pipeline + expert-parallel ("hybrid") case.
    test_moe_checkpoint(world_size=4, parallel="hybrid")