import os, shutil

import torch
import pytest

from functools import partial

import torch.multiprocessing as mp
import torch.distributed as dist

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import MultiplicativeLR
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

import colossalai

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec

from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.optimizer import ColossalaiOptimizer

from tests.components_to_test.registry import non_distributed_component_funcs

def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)
30
31


32
33
34
35
def init_1d_col_linear(weight, pg):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)
36
37


38
39
40
41
def init_1d_row_embedding(weight, pg):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)
42
43


44
45
46
47
def init_1d_col_embedding(weight, pg):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)
48
49
50


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
51
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
52
53
54
55
56
57
58
59
60
61
62
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'embed' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
        if 'proj2' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79


def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)


def remove(path):
    """ param <path> could either be relative or absolute. """
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        raise ValueError("file {} is not a file or dir.".format(path))


80
81
82
83
84
85
86
87
88
89
90
91
def compare_optims(optim1, optim2):
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for k, p1 in state1.items():
        if k not in state2:
            continue
        p2 = state2[k]
        if isinstance(p1, ColoTensor):
            assert isinstance(p2, ColoTensor)
            assert torch.allclose(p1.to_replicate_(), p2.to_replicate_(), rtol=1e-3, atol=1e-1)


92
93
94
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
95

96
97
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
98

99
    # set_seed(1)
100
    with ColoInitContext(device=get_current_device()):
101
102
        model = model_builder(checkpoint=True)
        model_reload = model_builder(checkpoint=True)
103

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
    if use_mp_reload:
        if 'bert' == model_name:
            for name, p in model.named_parameters():
                if not isinstance(p, ColoTensor):
                    continue
                # num_class = type_vocab_size = 2 | (8, 2)
                if 'classifier' in name and 'weight' in name:
                    init_1d_row_linear(p, pg)
                # num_class = vocab_size = 30524 | (30524, 8)
                elif 'word_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = seq_len = 512 | (512, 8)
                elif 'position_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = type_vocab_size = 2 | (2, 8)
                elif 'token_type_embeddings' in name and 'weight' in name:
                    init_1d_col_embedding(p, pg)
                elif p.process_group.tp_world_size() == 1:
                    p.redistribute(ReplicaSpec(), pg)
        elif "simple_net" == model_name:
            init_spec_func(model, pg)
125

126
    model = model.cuda()
127
128
    model.train()

129
    model_reload = model_reload.cuda()
130
    model_reload.train()
131

132
133
134
135
    opt_class = torch.optim.Adam
    colo_optimizer = ColossalaiOptimizer(opt_class(model.parameters(), lr=0.1))
    colo_optimizer_reload = ColossalaiOptimizer(opt_class(model_reload.parameters(), lr=0.1))
    run_reload = False
136
137

    for i, (data, label) in enumerate(train_dataloader):
138

139
140
141
142
143
144
145
146
147
        # Zero grad
        colo_optimizer.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Bcast rank0 data to all processes
        if criterion:
            output = model(data)
148
            output_reload = model_reload(data)
149
            loss = criterion(output, label)
150
            loss_reload = criterion(output_reload, label)
151
        else:
152
153
            loss = model(data, label)
            loss_reload = model_reload(data, label)
154
155

        loss.backward()
156
157
158
159
160
161
162
163
164
165
166
        loss_reload.backward()

        if run_reload:
            colo_optimizer_reload.zero_grad()
            if criterion:
                output_reload = model_reload(data)
                loss_reload = criterion(output_reload, label)
            else:
                loss_reload = model_reload(data, label)
            loss_reload.backward()
            colo_optimizer_reload.step()
167
168
169
170
171
172

        if i > 2:
            break

    if not os.path.isdir('./checkpoint') and rank == 0:
        os.mkdir('./checkpoint')
173
174
175
    save_checkpoint('./checkpoint', 0, model, colo_optimizer, None)
    dist.barrier()
    load_checkpoint('./checkpoint', 0, model_reload, colo_optimizer_reload, None)
176
177
178
179
180
181
182
183
184
185
    dist.barrier()

    # Since model is sharded, we merge them before param checking.
    for p in model.parameters():
        p.to_replicate_()

    for p in model_reload.parameters():
        p.to_replicate_()

    check_param_equal(model, model_reload)
186
    compare_optims(colo_optimizer, colo_optimizer_reload)
187
188
189
    if rank == 0:
        remove('./checkpoint')

190
191

def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
192
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
193
    pg = ProcessGroup(tp_degree=world_size)
194
    for model_name in ['simple_net', 'bert']:
195
196
197
198
199
200
        _run_checkpoint(model_name,
                        init_1d_row_for_linear_weight_spec,
                        use_ddp,
                        use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=pg)
201
202
203


@pytest.mark.dist
204
@pytest.mark.parametrize('world_size', [1, 2])
205
@pytest.mark.parametrize('use_ddp', [False])
206
@pytest.mark.parametrize('use_mp_reload', [True, False])
207
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
208
@rerun_if_address_is_in_use()
209
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
210
211
212
213
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_ddp=use_ddp,
214
                       use_mp_reload=use_mp_reload,
215
                       test_scheduler=test_scheduler)
216
217
218
219
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
220
    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")