from abc import ABC, abstractmethod
import os
import shutil

import torch
import torch.nn as nn
import pytest

from functools import partial

import torch.multiprocessing as mp
import torch.distributed as dist

from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR

import colossalai

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR


class DummyDataGenerator(ABC):
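    """Abstract iterator that yields a fixed number of synthetically generated batches."""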

    def __init__(self, length=10):
        self.length = length

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        else:
            raise StopIteration

    def __len__(self):
        return self.length


class DummyDataLoader(DummyDataGenerator):
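    """Generates random feature tensors paired with random integer class labels."""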

    def __init__(self, batch_size, category, feature_size, length=10):
        super().__init__(length)
        self.batch_size = batch_size
        self.category = category
        self.feature_size = feature_size

    def generate(self):
        image_dict = {}
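        # Features drawn uniformly from [-1, 1); labels are random class indices.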
        image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
        image_dict['label'] = torch.randint(self.category, (self.batch_size,),
                                            dtype=torch.int64,
                                            device=get_current_device())
        return image_dict


class MLP(nn.Module):
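    """Two-layer perceptron with a ReLU between the linear layers."""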

    def __init__(self, in_features, out_features, hidden_features=None):
        super().__init__()
        if hidden_features is None:
            hidden_features = out_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
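    """Shard every 'weight' parameter along its last dim across the tensor-parallel
    group (1D tensor-parallel row sharding)."""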
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


def check_param_equal(model, torch_model):
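    """Assert that corresponding parameters of the two models match within a loose tolerance."""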
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)


def remove(path):
    """ param <path> could either be relative or absolute. """
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        raise ValueError("file {} is not a file or dir.".format(path))


def run_checkpoint(init_spec_func, use_ddp, use_mp_reload, test_scheduler, pg):
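    """Save model/optimizer/scheduler state and load it back into an independently
    initialized copy, then verify that the parameters agree."""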
    num_epoch = 5
    warmup_epoch = 2

    batch = 3
    feature = 32
    category = 16

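    # Build both models under ColoInitContext so their parameters are ColoTensors
    # and can later be given tensor-parallel shard specs.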
    with ColoInitContext(device=get_current_device()):
        model = MLP(feature, category)

    with ColoInitContext(device=get_current_device()):
        model_reload = MLP(feature, category)

    model = model.cuda()
    model_reload = model_reload.cuda()
    if use_ddp:
        model = ColoDDP(model, pg)
        model_reload = ColoDDP(model_reload, pg)

    init_spec_func(model, pg)
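    # Optionally shard the reload model as well, so loading is exercised both into
    # a sharded layout and into a replicated one.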
    if use_mp_reload:
        init_spec_func(model_reload, pg)

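    # Two Adam optimizers with identical hyperparameters; load_checkpoint must
    # restore the saved optimizer state into the reload copy.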
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer_reload = torch.optim.Adam(model_reload.parameters(),
                                        lr=0.001,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0)

    lr_scheduler = None
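    # Exercise three scheduler flavours (ColossalAI's warmup cosine plus two stock
    # torch schedulers) so each state_dict format goes through the checkpoint path.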
    if test_scheduler == 'colossalai_cosine_warmup':
        lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=num_epoch, warmup_steps=warmup_epoch)
        lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload,
                                                      total_steps=num_epoch,
                                                      warmup_steps=warmup_epoch)
    elif test_scheduler == 'torch_cosine':
        lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=num_epoch)
        lr_scheduler_reload = CosineAnnealingLR(optimizer=optimizer_reload, T_max=num_epoch)
    elif test_scheduler == 'torch_lambda':
        lr_lambda = lambda epoch: 0.95
        lr_scheduler = MultiplicativeLR(optimizer=optimizer, lr_lambda=lr_lambda)
        lr_scheduler_reload = MultiplicativeLR(optimizer=optimizer_reload, lr_lambda=lr_lambda)
    else:
        raise ValueError(f"{test_scheduler} is not a supported scheduler type")

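    # Round trip: save on every rank, barrier so all shards are on disk, then reload.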
    save_checkpoint('./checkpoint', 0, model, optimizer, lr_scheduler)
    dist.barrier()
    load_checkpoint('./checkpoint', 0, model_reload, optimizer_reload, lr_scheduler_reload)

    # The parameters are sharded; gather them back to replicated tensors before comparing.
    for p in model.parameters():
        p.to_replicate_()

    for p in model_reload.parameters():
        p.to_replicate_()

    check_param_equal(model, model_reload)


def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
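    """Per-process entry point: launch the colossalai runtime and run the checkpoint round trip."""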
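    # With a single rank there is no data-parallel dimension to exercise; skip.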
    if use_ddp and world_size == 1:
        return
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=tp_world_size)
    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, use_mp_reload, test_scheduler=test_scheduler, pg=pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('use_ddp', [True, False])
@pytest.mark.parametrize('use_mp_reload', [True, False])
@pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda'])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler):
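    """Spawn one process per rank and run the checkpoint test end to end."""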
    os.makedirs('./checkpoint', exist_ok=True)

    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       use_ddp=use_ddp,
                       use_mp_reload=use_mp_reload,
                       test_scheduler=test_scheduler)
    mp.spawn(run_func, nprocs=world_size)
    remove('./checkpoint')


if __name__ == '__main__':
    test_checkpoint(2, True, False, "torch_cosine")