"applications/Colossal-LLaMA-2/train_sft.py" did not exist on "3dbbf83f1c46ae2a3b2947e1a5925c2b8af9f7b1"
test_interleaved.py 5.37 KB
Newer Older
import copy
from functools import partial
from types import MethodType

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all

NUM_LAYER = 8
DIM = 4
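# NUM_LAYER must be divisible by num_model_chunk: the test spawns
# NUM_LAYER // num_model_chunk ranks so each one owns exactly num_model_chunk layers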


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
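
# shape sanity sketch (illustrative only, not part of the test run):
#   model = MlpModel()
#   assert model(torch.rand(2, DIM)).shape == (2, DIM)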


def pp_linear_fwd(
    forward,
    data: torch.Tensor = None,
    input_obj: torch.Tensor = None,
    stage_mgr: PipelineStageManager = None,
    model_chunk_id: int = None,
):
    # each rank hosts several chunks; switch_model_chunk_id tells the stage
    # manager which local chunk this forward call belongs to
    with stage_mgr.switch_model_chunk_id(model_chunk_id):
        if stage_mgr.is_first_stage():
            # the first stage consumes the raw microbatch; wrapping the output
            # in a dict lets the schedule pass it on as the `input_obj` kwarg
            return {"input_obj": forward(data)}
        elif stage_mgr.is_last_stage():
            # the last stage returns a bare tensor, which feeds the criterion
            return forward(input_obj)
        else:
            return {"input_obj": forward(input_obj)}


def run_pp(
    rank: int,
    world_size: int,
    port: int,
    num_microbatch: int,
    batch_size: int,
    num_model_chunk: int,
):
    """
    This test is to examine the correctness of interleaved 1F1B, compared with torch.
    Be aware it contains some hardcodes.
    """
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")

    # create model
    seed_all(1453)
    torch_model = MlpModel().cuda()
    pp_model = copy.deepcopy(torch_model).cuda()

    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(
        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
    )
    schedule = InterleavedSchedule(
        stage_manager=stage_manager,
        num_model_chunks=num_model_chunk,
        num_microbatch=num_microbatch,
    )
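    # with num_model_chunks > 1 the schedule interleaves the chunks so each
    # rank stays busy more of the time, shrinking the bubble of plain 1F1B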

    sharded_model = nn.ModuleList()
    for idx, sub_model in enumerate(pp_model.layers):
        if idx % world_size == rank:
            sub_model._forward = sub_model.forward
            sub_model.forward = MethodType(
                partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)),
                sub_model._forward,
            )
            sharded_model.append(sub_model.cuda())
    assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct"
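    # resulting layout (e.g. world_size=4, num_model_chunk=2): rank r owns
    # global layers r and r + world_size as chunks 0 and 1; in general layer
    # idx lives on rank idx % world_size as local chunk idx // world_size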

    # create optimizer
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5)
    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5))

    # create data; seeding makes every rank draw the same tensor, and the
    # all-reduce guarantees identical inputs even if RNG states were to diverge
    seed_all(115)
    input_list = [torch.rand(batch_size, DIM).cuda()]
    dist.all_reduce(input_list[0])

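    # the extra *args/**kwargs are accepted and ignored so the same callable
    # fits both the plain torch call and the schedule's criterion signature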
    def criterion(x, *args, **kwargs):
        return (x * x).mean()

    # forward and backward
    torch_output = torch_model(input_list[0])
    torch_loss = criterion(torch_output)
    torch_loss.backward()

    pp_ret = schedule.forward_backward_step(
        sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
    )
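    # with return_loss=True the loss (reduced over all microbatches) is
    # reported under pp_ret["loss"]; the test therefore checks it only on the
    # last pipeline stage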

    # check loss
    if stage_manager.is_last_stage(ignore_chunk=True):
        assert torch.allclose(torch_loss, pp_ret["loss"])

    # check gradients
    for i in range(num_model_chunk):
        idx = world_size * i + rank
        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

    # step
    torch_optimizer.step()
    pp_optimizer.step()
    pp_optimizer.zero_grad()
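    # zeroing grads here also sets up the forward-only check below, which
    # asserts that a no-grad pass leaves every gradient None or zero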

    # check updated param
    for i in range(num_model_chunk):
        idx = world_size * i + rank
        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)

    # forward only
    with torch.no_grad():
        torch_output = torch_model(input_list[0])
        torch_loss = criterion(torch_output)

        pp_ret = schedule.forward_backward_step(
            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
        )
        if stage_manager.is_last_stage(ignore_chunk=True):
            assert torch.allclose(torch_loss, pp_ret["loss"])

        # a forward-only pass must not touch gradients: they stay None (never
        # populated, or reset to None by zero_grad) or remain exactly zero
        for layer in sharded_model:
            if layer.weight.grad is None:
                assert layer.bias.grad is None
            else:
                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))


@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 12])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("num_model_chunk", [2, 4])
@rerun_if_address_is_in_use()
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
    assert NUM_LAYER % num_model_chunk == 0
    spawn(
        run_pp,
        nprocs=NUM_LAYER // num_model_chunk,
        num_microbatch=num_microbatch,
        batch_size=batch_size,
        num_model_chunk=num_model_chunk,
    )
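# nprocs = NUM_LAYER // num_model_chunk, so each spawned rank owns exactly
# num_model_chunk layers, e.g. num_model_chunk=2 -> 4 processes for 8 layers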


if __name__ == "__main__":
    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)