import os
import shutil
from copy import deepcopy
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR

import colossalai
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs


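# Helpers that attach a 1D tensor-parallel (TP1D) shard spec to a single parameter,
# splitting it along one dimension across the tensor-parallel process group.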
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_col_linear(weight, pg):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_row_embedding(weight, pg):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


def init_1d_col_embedding(weight, pg):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)


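# Used as init_spec_func for the 'simple_net' model: shard its embedding, projection
# and classifier parameters with TP1D specs before checkpointing.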
def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'embed' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
        if 'proj2' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)


def check_param_equal(model, torch_model):
    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)


def remove(path):
    """ param <path> could either be relative or absolute. """
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        raise ValueError("file {} is not a file or dir.".format(path))


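# Check that two optimizers hold bit-identical state tensors (rtol=0, atol=0) for every
# state key they have in common, e.g. Adam's moment estimates after a checkpoint round trip.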
def compare_optims(optim1, optim2):
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for k, p1 in state1.items():
        if k not in state2:
            continue
        p2 = state2[k]
        for n, t1 in p1.items():
            if n not in p2:
                continue
            t2 = p2[n]
            if isinstance(t1, ColoTensor):
                assert isinstance(t2, ColoTensor)
                assert torch.allclose(t1, t2, rtol=0, atol=0)


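# Train for a few steps, save a checkpoint of the model and optimizer, load it into
# fresh copies, and verify that parameters and optimizer state match exactly.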
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

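    # With use_mp_reload, shard the parameters with TP1D specs before checkpointing so the
    # save/load path is exercised on tensor-parallel (model-parallel) tensors.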
    if use_mp_reload:
        if 'bert' == model_name:
            for name, p in model.named_parameters():
                if not isinstance(p, ColoTensor):
                    continue
                # num_class = type_vocab_size = 2 | (8, 2)
                if 'classifier' in name and 'weight' in name:
                    init_1d_row_linear(p, pg)
                # num_class = vocab_size = 30524 | (30524, 8)
                elif 'word_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = seq_len = 512 | (512, 8)
                elif 'position_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = type_vocab_size = 2 | (2, 8)
                elif 'token_type_embeddings' in name and 'weight' in name:
                    init_1d_col_embedding(p, pg)
                elif p.process_group.tp_world_size() == 1:
                    p.set_process_group(pg)
        elif "simple_net" == model_name:
            init_spec_func(model, pg)

    model_reload = deepcopy(model)
    model = model.cuda()
    model.eval()

    model_reload = model_reload.cuda()
    model_reload.eval()

    opt_class = torch.optim.Adam
    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))

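    # Run a few optimizer steps so that optimizer state (e.g. Adam moments) exists
    # before the checkpoint is saved.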
    for i, (data, label) in enumerate(train_dataloader):

        # Zero grad
        colo_optimizer.zero_grad()
        colo_optimizer_reload.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Bcast rank0 data to all processes
        dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
        dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())

        if criterion:
            output = model(data)
            output_reload = model_reload(data)
            loss = criterion(output, label)
            loss_reload = criterion(output_reload, label)
        else:
            loss = model(data, label)
            loss_reload = model_reload(data, label)

        loss.backward()
        loss_reload.backward()

        colo_optimizer.step()
        colo_optimizer_reload.step()

        if i > 2:
            break

    if not os.path.isdir('./checkpoint') and rank == 0:
        os.mkdir('./checkpoint')
    dist.barrier()

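    # Round trip: save from the trained model/optimizer, then restore into the reload copies.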
    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)

    check_param_equal(model, model_reload)
    compare_optims(colo_optimizer, colo_optimizer_reload)

    if rank == 0:
        remove('./checkpoint')
    dist.barrier()


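# Per-process entry point: launch the distributed environment and build a ProcessGroup
# whose tensor-parallel degree spans all ranks.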
def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)

    # the BERT data loader runs in DDP mode, so the input data is not replicated across ranks in the TP context
    for model_name in ['bert']:
        _run_checkpoint(model_name,
                        init_1d_row_for_linear_weight_spec,
                        use_ddp,
                        use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_ddp=use_ddp,
                       use_mp_reload=use_mp_reload,
                       test_scheduler=test_scheduler)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")