# test_colo_checkpoint.py
import os, shutil
2
3
import torch
import pytest
4
5
from functools import partial

6
7
import torch.multiprocessing as mp
import torch.distributed as dist
8

9
10
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import MultiplicativeLR
11
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
12
13

import colossalai
14
15
16
17
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
18
from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, DistSpecManager, ReplicaSpec
19
20
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
21
from colossalai.nn.optimizer import ColoOptimizer
22

23
from tests.components_to_test.registry import non_distributed_component_funcs
24
25


26
27
28
29
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    """Bind ``weight`` to ``pg`` and give it a TP1D spec sharded on dim ``-1``.

    Args:
        weight: tensor that is re-specced in place through
            ``set_process_group`` and ``set_tensor_spec``.
        pg: process group; its ``tp_world_size()`` is handed to ``ShardSpec``
            as the partition count for the sharded dim.
    """
    # Both halves of the spec are created before the tensor is modified.
    shard_part = ShardSpec([-1], [pg.tp_world_size()])
    compute_part = ComputeSpec(ComputePattern.TP1D)
    weight.set_process_group(pg)
    weight.set_tensor_spec(shard_part, compute_part)
30
31


32
33
34
35
def init_1d_col_linear(weight, pg):
    """Bind ``weight`` to ``pg`` and give it a TP1D spec sharded on dim ``0``.

    Args:
        weight: tensor that is re-specced in place through
            ``set_process_group`` and ``set_tensor_spec``.
        pg: process group; its ``tp_world_size()`` is handed to ``ShardSpec``
            as the partition count for the sharded dim.
    """
    tp_degree = pg.tp_world_size()
    tensor_spec = (ShardSpec([0], [tp_degree]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    dist_part, compute_part = tensor_spec
    weight.set_tensor_spec(dist_part, compute_part)
36
37


38
39
40
41
def init_1d_row_embedding(weight, pg):
    """Bind an embedding ``weight`` to ``pg`` with a TP1D spec sharded on dim ``0``.

    Args:
        weight: tensor that is re-specced in place through
            ``set_process_group`` and ``set_tensor_spec``.
        pg: process group; its ``tp_world_size()`` is handed to ``ShardSpec``
            as the partition count for the sharded dim.
    """
    sharded_dims = [0]
    partitions = [pg.tp_world_size()]
    row_spec = (ShardSpec(sharded_dims, partitions), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(row_spec[0], row_spec[1])
42
43


44
45
46
47
def init_1d_col_embedding(weight, pg):
    """Bind an embedding ``weight`` to ``pg`` with a TP1D spec sharded on dim ``-1``.

    Args:
        weight: tensor that is re-specced in place through
            ``set_process_group`` and ``set_tensor_spec``.
        pg: process group; its ``tp_world_size()`` is handed to ``ShardSpec``
            as the partition count for the sharded dim.
    """
    sharded_dims = [-1]
    partitions = [pg.tp_world_size()]
    col_spec = (ShardSpec(sharded_dims, partitions), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(col_spec[0], col_spec[1])
48
49
50


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    """Apply 1D tensor-parallel specs to ``model``'s ColoTensor parameters by name.

    Parameters that are not ``ColoTensor`` instances are skipped. The spec is
    chosen from substrings of the parameter name:

    * ``embed`` + ``weight``            -> ``init_1d_col_embedding``
    * ``proj1`` + ``weight``/``bias``   -> ``init_1d_col_linear``
    * ``proj2`` + ``weight``            -> ``init_1d_row_linear``
    * ``classifier`` + ``weight``/``bias`` -> ``init_1d_col_linear``

    Args:
        model: object exposing ``named_parameters()``.
        pg: process group forwarded to the per-parameter initialisers.
    """
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        # These are deliberately independent ``if`` statements (not ``elif``):
        # a name that matches several patterns has each matching spec applied
        # in order.
        if 'embed' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        if 'proj1' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
        if 'proj2' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        if 'classifier' in name and ('weight' in name or 'bias' in name):
            init_1d_col_linear(p, pg)
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79


def check_param_equal(model, torch_model):
    """Assert that two models hold the same number of approximately equal parameters.

    Parameters are paired in ``parameters()`` iteration order and compared with
    ``torch.allclose(rtol=1e-3, atol=1e-1)``.

    Args:
        model: object exposing ``parameters()``.
        torch_model: object exposing ``parameters()`` to compare against.

    Raises:
        AssertionError: if the parameter counts differ or any pair is not close.
    """
    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    # zip() stops at the shorter input, so without this check a model that is
    # missing trailing parameters would compare as equal.
    assert len(params) == len(torch_params), \
        "parameter count mismatch: {} vs {}".format(len(params), len(torch_params))
    for p, torch_p in zip(params, torch_params):
        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)


def remove(path):
    """Delete ``path``, which may be given as a relative or absolute path.

    Regular files and symbolic links are unlinked; directories are removed
    recursively.

    Raises:
        ValueError: if ``path`` is neither a file, a link, nor a directory.
    """
    unlinkable = os.path.isfile(path) or os.path.islink(path)
    if unlinkable:
        os.remove(path)
        return
    if not os.path.isdir(path):
        raise ValueError("file {} is not a file or dir.".format(path))
    shutil.rmtree(path)


80
81
82
def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
    """Train a model briefly, checkpoint it, reload into a twin model and compare.

    Two models are built from the same builder. ``model`` is optionally given
    tensor-parallel specs, trained for a few SGD steps and saved to
    ``./checkpoint``; ``model_reload`` is then loaded from that checkpoint and
    both are replicated and compared with ``check_param_equal``.

    Args:
        model_name: key looked up in ``non_distributed_component_funcs``; the
            sharding branches only recognise ``'bert'`` and ``'simple_net'``.
        init_spec_func: callable ``(model, pg)`` used to shard ``simple_net``.
        use_ddp: accepted for signature compatibility; not read here.
        use_mp_reload: when true, tensor-parallel specs are applied to
            ``model`` before training (``model_reload`` is left as built).
        test_scheduler: accepted for signature compatibility; not read here.
        pg: ``ProcessGroup`` handed to the spec initialisers.

    Raises:
        AssertionError: from ``check_param_equal`` if the reloaded parameters
            do not match the trained ones.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _test_dataloader, _optimizer_class, criterion = get_components_func()

    rank = torch.distributed.get_rank()

    # set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)
        model_reload = model_builder(checkpoint=True)

    if use_mp_reload:
        if 'bert' == model_name:
            for name, p in model.named_parameters():
                if not isinstance(p, ColoTensor):
                    continue
                # num_class = type_vocab_size = 2 | (8, 2)
                if 'classifier' in name and 'weight' in name:
                    init_1d_row_linear(p, pg)
                # num_class = vocab_size = 30524 | (30524, 8)
                elif 'word_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = seq_len = 512 | (512, 8)
                elif 'position_embeddings' in name and 'weight' in name:
                    init_1d_row_embedding(p, pg)
                # num_class = type_vocab_size = 2 | (2, 8)
                elif 'token_type_embeddings' in name and 'weight' in name:
                    init_1d_col_embedding(p, pg)
                # Everything else that is not yet tensor-parallel is replicated on pg.
                elif p.process_group.tp_world_size() == 1:
                    p.redistribute(ReplicaSpec(), pg)
        elif "simple_net" == model_name:
            init_spec_func(model, pg)

    model = model.cuda()
    model.train()

    model_reload = model_reload.cuda()
    model_reload.train()

    colo_optimizer = ColoOptimizer(model.named_parameters(), torch.optim.SGD, lr=0.1)

    for i, (data, label) in enumerate(train_dataloader):

        # Zero grad
        colo_optimizer.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # With an external criterion the model returns logits; otherwise the
        # model is called with the labels and its output is used as the loss.
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        loss.backward()
        colo_optimizer.step()

        # Stop after the batch with index 3, i.e. at most four optimisation steps.
        if i > 2:
            break

    # Only rank 0 creates the directory; exist_ok avoids a separate check-then-create.
    if rank == 0:
        os.makedirs('./checkpoint', exist_ok=True)
    save_checkpoint('./checkpoint', 0, model, None, None)
    # Make sure saving has finished on every rank before anyone loads.
    dist.barrier()
    load_checkpoint('./checkpoint', 0, model_reload, None, None)

    # Since model is sharded, we merge them before param checking.
    for p in model.parameters():
        p.to_replicate_()

    for p in model_reload.parameters():
        p.to_replicate_()

    check_param_equal(model, model_reload)

    if rank == 0:
        remove('./checkpoint')

162
163

def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
    """Per-process entry point: launch ColossalAI and run the checkpoint check.

    Initialises the distributed environment on ``localhost:port`` with the
    NCCL backend, builds a ``ProcessGroup`` whose tensor-parallel degree equals
    ``world_size``, and runs ``_run_checkpoint`` for ``'bert'`` and
    ``'simple_net'`` in turn.

    Args:
        rank: rank of this process (supplied as the first positional argument
            by the spawner).
        world_size: total number of processes; also used as ``tp_degree``.
        port: port passed to ``colossalai.launch``.
        use_ddp: forwarded to ``_run_checkpoint``.
        use_mp_reload: forwarded to ``_run_checkpoint``.
        test_scheduler: forwarded to ``_run_checkpoint``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
    for model_name in ['bert', 'simple_net']:
        _run_checkpoint(model_name,
                        init_1d_row_for_linear_weight_spec,
                        use_ddp,
                        use_mp_reload,
                        test_scheduler=test_scheduler,
                        pg=pg)
173
174
175


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
# NOTE: scheduler parametrisation is currently disabled, so under pytest
# ``test_scheduler`` always takes its default of ``None``.
# @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None):
    """Spawn ``world_size`` processes that each execute ``run_dist``.

    Args:
        world_size: number of processes to spawn.
        use_ddp: forwarded to ``run_dist``.
        use_mp_reload: forwarded to ``run_dist``.
        test_scheduler: forwarded to ``run_dist``; defaults to ``None``.
    """
    # mp.spawn supplies the process index as the first positional argument
    # (``rank``); everything else is bound here by keyword.
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_ddp=use_ddp,
                       use_mp_reload=use_mp_reload,
                       test_scheduler=test_scheduler)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # Manual run: two processes, with tensor-parallel specs applied before saving.
    test_checkpoint(2, use_ddp=False, use_mp_reload=True, test_scheduler="torch_cosine")