# test_colo_checkpoint.py
1
import os, shutil
2
3
import torch
import pytest
4
from copy import deepcopy
5
6
from functools import partial

7
8
import torch.multiprocessing as mp
import torch.distributed as dist
9

10
11
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import MultiplicativeLR
12
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
13
14

import colossalai
15
16
17
18
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
19
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup
20
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
21
from colossalai.nn.optimizer import ColossalaiOptimizer
22

23
from tests.components_to_test.registry import non_distributed_component_funcs
24
25


26
27
28
29
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    """Give ``weight`` a 1D tensor-parallel spec sharded on dim -1.

    The shard count is ``pg.tp_world_size()``. ``weight`` is bound to ``pg``
    before the tensor spec is installed.
    """
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
30
31


32
33
34
35
def init_1d_col_linear(weight, pg):
    """Give ``weight`` a 1D tensor-parallel spec sharded on dim 0.

    The shard count is ``pg.tp_world_size()``. ``weight`` is bound to ``pg``
    before the tensor spec is installed.
    """
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
36
37


38
39
40
41
def init_1d_row_embedding(weight, pg):
    """Give an embedding ``weight`` a 1D tensor-parallel spec sharded on dim 0.

    The shard count is ``pg.tp_world_size()``. ``weight`` is bound to ``pg``
    before the tensor spec is installed.
    """
    shard_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
42
43


44
45
46
47
def init_1d_col_embedding(weight, pg):
    """Give an embedding ``weight`` a 1D tensor-parallel spec sharded on dim -1.

    The shard count is ``pg.tp_world_size()``. ``weight`` is bound to ``pg``
    before the tensor spec is installed.
    """
    shard_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_spec, compute_spec)
48
49
50


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    """Apply 1D tensor-parallel specs to the 'simple_net' test model's parameters.

    Parameters that are not ColoTensors are skipped. Rules are matched by
    substring of the parameter name:

    * ``embed`` + ``weight``              -> ``init_1d_col_embedding``
    * ``proj1`` + (``weight`` | ``bias``) -> ``init_1d_col_linear``
    * ``proj2`` + ``weight``              -> ``init_1d_row_linear``
    * ``classifier`` + (``weight`` | ``bias``) -> ``init_1d_col_linear``

    Each rule is checked independently, so a name matching several rules has
    all of them applied, in the order listed.

    Args:
        model: module whose ``named_parameters()`` are sharded in place.
        pg: process group the sharded parameters are bound to.
    """
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        is_weight = 'weight' in name
        is_weight_or_bias = is_weight or 'bias' in name
        if 'embed' in name and is_weight:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and is_weight_or_bias:
            init_1d_col_linear(p, pg)
        if 'proj2' in name and is_weight:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and is_weight_or_bias:
            init_1d_col_linear(p, pg)
63
64
65


def check_param_equal(model, torch_model):
    """Assert that two models hold element-wise identical parameters.

    Parameters are paired in ``named_parameters()`` order.

    Args:
        model: first module.
        torch_model: second module, expected to mirror ``model``'s parameters.

    Raises:
        AssertionError: if the models expose a different number of parameters,
            or if any paired parameters differ in at least one element.
    """
    params = list(model.named_parameters())
    torch_params = list(torch_model.named_parameters())
    # zip() would otherwise silently ignore trailing parameters of the longer model.
    assert len(params) == len(torch_params), \
        "parameter count mismatch: {} vs {}".format(len(params), len(torch_params))
    for (n, p), (_, tp) in zip(params, torch_params):
        assert torch.all(p.data == tp.data), "{} went wrong.\n {} vs {}\n{}".format(n, p, tp, p.shape)
68
69
70
71
72
73
74
75
76
77
78
79


def remove(path):
    """Delete ``path``, which may be a regular file, a symlink, or a directory tree.

    ``path`` may be relative or absolute. A symlink is removed itself (its
    target is left alone).

    Raises:
        ValueError: if ``path`` is none of the above (e.g. it does not exist).
    """
    if os.path.islink(path) or os.path.isfile(path):
        os.remove(path)
        return
    if not os.path.isdir(path):
        raise ValueError("file {} is not a file or dir.".format(path))
    shutil.rmtree(path)


80
81
82
83
84
85
86
def compare_optims(optim1, optim2):
    """Assert that two optimizers hold identical ColoTensor state.

    The check walks ``optim1.state_dict()['state']``:

    * Param indices or state keys missing from ``optim2`` are skipped.
    * Only values that are ColoTensors in ``optim1`` are compared. The matching
      ``optim2`` value must also be a ColoTensor and be exactly equal
      (``rtol=0, atol=0``).
    * Other values (e.g. plain tensors or scalars) are not compared.

    Raises:
        AssertionError: on a type mismatch or a value difference.
    """
    state1 = optim1.state_dict()['state']
    state2 = optim2.state_dict()['state']
    for k, p1 in state1.items():
        if k not in state2:
            continue
        p2 = state2[k]
        for n, t1 in p1.items():
            if n not in p2:
                continue
            t2 = p2[n]
            if isinstance(t1, ColoTensor):
                assert isinstance(t2, ColoTensor), \
                    "state[{}][{!r}] is a ColoTensor in optim1 but {} in optim2".format(k, n, type(t2).__name__)
                assert torch.allclose(t1, t2, rtol=0, atol=0), \
                    "state[{}][{!r}] differs between optimizers".format(k, n)
94
95


96
97
98
def _apply_bert_tp_specs(model, pg: ProcessGroup):
    """Apply 1D tensor-parallel specs to the BERT test model's ColoTensor parameters."""
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        # num_class = type_vocab_size = 2 | (8, 2)
        if 'classifier' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        # num_class = vocab_size = 30524 | (30524, 8)
        elif 'word_embeddings' in name and 'weight' in name:
            init_1d_row_embedding(p, pg)
        # num_class = seq_len = 512 | (512, 8)
        elif 'position_embeddings' in name and 'weight' in name:
            init_1d_row_embedding(p, pg)
        # num_class = type_vocab_size = 2 | (2, 8)
        elif 'token_type_embeddings' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        elif p.process_group.tp_world_size() == 1:
            # Any other parameter still on a TP-size-1 group is rebound to ``pg``.
            p.set_process_group(pg)


def _train_in_lockstep(model, model_reload, optimizer, optimizer_reload, train_dataloader, criterion):
    """Run identical training steps on both models, stopping after the batch with index 3.

    If ``criterion`` is truthy, the models return outputs and the loss is
    ``criterion(output, label)``. Otherwise the models are called as
    ``model(data, label)`` and return the loss directly.
    """
    device = get_current_device()
    for i, (data, label) in enumerate(train_dataloader):
        optimizer.zero_grad()
        optimizer_reload.zero_grad()

        # No broadcast happens here: each rank uses the batch its own dataloader yields.
        # NOTE(review): TP correctness assumes every rank sees the same batch.
        data = data.to(device)
        label = label.to(device)

        if criterion:
            output = model(data)
            output_reload = model_reload(data)
            loss = criterion(output, label)
            loss_reload = criterion(output_reload, label)
        else:
            loss = model(data, label)
            loss_reload = model_reload(data, label)

        loss.backward()
        loss_reload.backward()

        optimizer.step()
        optimizer_reload.step()

        if i > 2:
            break


def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    """Round-trip a model and its optimizer through save/load_checkpoint and verify equality.

    Steps:

    1. Build ``model_name`` from the component registry under ColoInitContext
       and optionally apply tensor-parallel specs.
    2. Deep-copy the model and train both copies in lockstep for a few steps,
       so the Adam optimizers carry state.
    3. Save the first model/optimizer to ``./checkpoint`` and load it into the
       copy.
    4. Compare parameters and optimizer state, then delete the directory.

    Args:
        model_name: registry key. With ``use_mp_reload``, 'bert' and
            'simple_net' receive tensor-parallel specs.
        init_spec_func: ``callable(model, pg)`` used to shard 'simple_net'.
        use_ddp: currently unused by this function.
        use_mp_reload: whether to apply tensor-parallel specs before copying.
        test_scheduler: currently unused by this function.
        pg: ProcessGroup used for tensor-parallel sharding.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    rank = dist.get_rank()

    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if use_mp_reload:
        if model_name == 'bert':
            _apply_bert_tp_specs(model, pg)
        elif model_name == 'simple_net':
            init_spec_func(model, pg)

    # Clone the (possibly sharded) model so both copies start from identical parameters.
    model_reload = deepcopy(model)
    model = model.cuda()
    model.eval()
    model_reload = model_reload.cuda()
    model_reload.eval()

    colo_optimizer = ColossalaiOptimizer(torch.optim.Adam(model.parameters(), lr=0.1))
    colo_optimizer_reload = ColossalaiOptimizer(torch.optim.Adam(model_reload.parameters(), lr=0.1))

    _train_in_lockstep(model, model_reload, colo_optimizer, colo_optimizer_reload, train_dataloader, criterion)

    ckpt_dir = './checkpoint'
    if rank == 0 and not os.path.isdir(ckpt_dir):
        os.mkdir(ckpt_dir)
    # Make sure the directory exists on every rank before save_checkpoint runs.
    dist.barrier()

    save_checkpoint(ckpt_dir, 0, model, colo_optimizer, None)
    load_checkpoint(ckpt_dir, 0, model_reload, colo_optimizer_reload, None)

    check_param_equal(model, model_reload)
    compare_optims(colo_optimizer, colo_optimizer_reload)

    if rank == 0:
        remove(ckpt_dir)
    # Keep ranks in step so no rank moves on while rank 0 is still deleting the directory.
    dist.barrier()
181

182
183

def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    """Per-process entry point spawned by ``test_checkpoint``.

    Launches ColossalAI on localhost with the NCCL backend and builds a
    ProcessGroup with ``tp_degree=world_size``. It then runs the checkpoint
    round-trip for each model under test.

    Args:
        rank: this process's rank (supplied by ``mp.spawn``).
        world_size: total number of processes.
        port: rendezvous port on localhost.
        use_ddp, use_mp_reload, test_scheduler: forwarded to ``_run_checkpoint``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)

    # TODO(haichen) add BERT in the test.
    # BERT's data loader runs in DDP mode, so its input data is not replicated across the TP group.
    for model_name in ['simple_net']:
        _run_checkpoint(model_name,
                        init_1d_row_for_linear_weight_spec,
                        use_ddp,
                        use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=pg)
195
196
197


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    """Spawn ``world_size`` processes that each run ``run_dist``.

    ``test_scheduler`` is forwarded unchanged. Its parametrization is currently
    disabled (see the commented-out decorator above), so under pytest it stays
    ``None``.
    """
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_ddp=use_ddp,
                       use_mp_reload=use_mp_reload,
                       test_scheduler=test_scheduler)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # Manual run: 2 processes with tensor-parallel specs applied before the reload copy.
    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")