import copy
from functools import partial
from types import MethodType

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all

DIM = 8
NUM_LAYER = 8


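# reference model: a plain stack of NUM_LAYER linear layers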
class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


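# stage-aware forward wrapper: the first stage consumes the raw batch ("data"),
# intermediate stages pass the inter-stage tensor ("input_obj") along,
# and the last stage returns the final output so the loss can be computed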
def pp_linear_fwd(
    forward,
    data: torch.Tensor = None,
    input_obj: torch.Tensor = None,
    stage_mgr: PipelineStageManager = None,
):
    if stage_mgr.is_first_stage():
        return {"input_obj": forward(data)}
    elif stage_mgr.is_last_stage():
        return forward(input_obj)
    else:
        return {"input_obj": forward(input_obj)}


def examine_pp(num_microbatch: int, batch_size: int):
    """
    This test is to examine the correctness of 1F1B, compared with torch.
    Be aware it contains some hardcodes.
    """
    world_size = dist.get_world_size()
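    # seed every rank identically so each one builds the same reference model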
    seed_all(1453)

    # create models
    torch_model = MlpModel().cuda()

    pp_model = copy.deepcopy(torch_model).cuda()

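    # build a 1D process-group mesh, a stage manager on the pipeline axis, and the 1F1B schedule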
    pg_mesh = ProcessGroupMesh(world_size)
    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)

    rank = dist.get_rank()
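    # shard the copied model: each rank keeps a contiguous block of NUM_LAYER // world_size layers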
    sharded_model = torch.nn.ModuleList()
    num_local_layer = NUM_LAYER // world_size
    for idx, sub_model in enumerate(pp_model.layers):
        if idx // num_local_layer == rank:
            sharded_model.append(sub_model.cuda())
    assert len(sharded_model) == num_local_layer

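    # stage-local forward: run only the layers owned by this rank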
    def custom_fwd(self, x):
        for layer in self._modules.values():
            x = layer(x)
        return x

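    # bind the stage-local forward, then wrap it with pp_linear_fwd so calling
    # sharded_model(...) behaves correctly for this rank's pipeline stage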
    sharded_model._forward = MethodType(custom_fwd, sharded_model)
    sharded_model.forward = MethodType(
        partial(
            pp_linear_fwd,
            stage_mgr=stage_manager,
        ),
        sharded_model._forward,
    )

    # create optimizer
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))

    # create input data (identical on every rank thanks to the all-reduce below)
    seed_all(1453)
    input_list = [torch.rand(batch_size, DIM).cuda()]
    dist.all_reduce(input_list[0])

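    # loss: mean of squared outputs; extra positional/keyword args (e.g. labels) are ignored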
    criterion = lambda x, *arg, **kwargs: (x * x).mean()

    # forward and backward
    torch_output = torch_model(input_list[0])
    torch_loss = criterion(torch_output)
    torch_loss.backward()
    pp_ret = schedule.forward_backward_step(
        sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
    )

    # check loss
    if stage_manager.is_last_stage():
        assert torch.allclose(torch_loss, pp_ret["loss"])

    # check gradients
    for i in range(len(sharded_model)):
        idx = rank * num_local_layer + i
        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

    # step
    torch_optimizer.step()
    pp_optimizer.step()
    pp_optimizer.zero_grad()

    # check updated param
    for i in range(len(sharded_model)):
        idx = rank * num_local_layer + i
        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)

    # forward only
    with torch.no_grad():
        torch_output = torch_model(input_list[0])
        torch_loss = criterion(torch_output)

        pp_ret = schedule.forward_backward_step(
            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
        )
        if stage_manager.is_last_stage():
            assert torch.allclose(torch_loss, pp_ret["loss"])

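        # a forward-only pass must not accumulate gradients: they stay None or remain zero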
        for layer in sharded_model:
            if layer.weight.grad is None:
                assert layer.weight.grad is None and layer.bias.grad is None
            else:
                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))


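# per-process entry point: initialize the distributed environment, then run the correctness check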
def run_dist(
    rank: int,
    world_size: int,
    port: int,
    num_microbatch: int,
    batch_size: int,
):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
    examine_pp(num_microbatch, batch_size)


@pytest.mark.dist
@pytest.mark.parametrize("num_microbatch", [4, 6])
@pytest.mark.parametrize("batch_size", [12])
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_pp(num_microbatch: int, batch_size: int, world_size: int):
    assert NUM_LAYER % world_size == 0
    spawn(
        run_dist,
        world_size,
        num_microbatch=num_microbatch,
        batch_size=batch_size,
    )


if __name__ == "__main__":
    test_pp(num_microbatch=4, batch_size=4, world_size=4)