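"""Tests for models built with ColoTensor parameters.

Covers parameter traversal under ColoInitContext, ColoOptimizer stepping,
1D row and hybrid tensor parallelism, and loading HuggingFace pretrained
weights into a ColoTensor model.
"""
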
import pytest
from functools import partial

import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.nn.optimizer import ColoOptimizer
from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \
    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

from tests.components_to_test.registry import non_distributed_component_funcs
from _utils import tensor_shard_equal, set_seed


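# Helpers that attach a 1D tensor-parallel (TP1D) shard spec to a parameter:
# "row" linear and "col" embedding shard the last dim, while "col" linear and
# "row" embedding shard dim 0, each split across the TP world size of `pg`.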
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(*spec)


def init_1d_col_linear(weight: ColoTensor, pg: ProcessGroup):
    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(*spec)


def init_1d_row_embedding(weight: ColoTensor, pg: ProcessGroup):
    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(*spec)


def init_1d_col_embedding(weight: ColoTensor, pg: ProcessGroup):
    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(*spec)


def run_1d_hybrid_tp(model_name):
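    """Shard parameters with mixed row/col 1D TP specs, train a few steps, and
    compare the loss and updated parameters against an unsharded torch model
    kept on rank 0."""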
    # Build the components (model, dataloaders, criterion) registered for model_name
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()
        colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)

        # Make the two models start from the same init params
        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
            p2.data.copy_(p1.data)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    if model_name == 'bert':
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            # print(name)
            # num_class = type_vocab_size = 2 | (8, 2)
            # TODO(jiaruifang) there is a bug if the following two lines are uncommented
            # if 'classifier' in name and 'weight' in name:
            #     init_1d_row_linear(p, pg)
            # num_class = vocab_size = 30524 | (30524, 8)
            if 'word_embeddings' in name and 'weight' in name:
                init_1d_row_embedding(p, pg)
            # num_class = seq_len = 512 | (512, 8)
            if 'position_embeddings' in name and 'weight' in name:
                init_1d_row_embedding(p, pg)
            # num_class = type_vocab_size = 2 | (2, 8)
            if 'token_type_embeddings' in name and 'weight' in name:
                init_1d_col_embedding(p, pg)
    elif model_name == 'simple_net':
        # A naive way to set a spec for all weights in Linear
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            if 'embed' in name and 'weight' in name:
                init_1d_col_embedding(p, pg)
            if 'proj1' in name and ('weight' in name or 'bias' in name):
                init_1d_col_linear(p, pg)
            if 'proj2' in name and 'weight' in name:
                init_1d_row_linear(p, pg)
            if 'classifier' in name and ('weight' in name or 'bias' in name):
                init_1d_col_linear(p, pg)

    model = model.cuda()
    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    for i, (data, label) in enumerate(train_dataloader):
        # eval mode keeps the forward deterministic (e.g. no dropout), so the losses are comparable
        model.eval()
        colo_optimizer.zero_grad()
        if rank == 0:
            model_torch.eval()
            colo_optimizer_torch.zero_grad()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Bcast rank0 data to all processes
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch

        if rank == 0:
            with torch.no_grad():
                assert torch.allclose(loss, loss_torch, rtol=1e-2)

        loss.backward()
        colo_optimizer.step()

        if rank == 0:
            loss_torch.backward()
            colo_optimizer_torch.step()

            with torch.no_grad():
                # each shard must match the corresponding slice of the torch param
                for p, torch_p in zip(model.parameters(), model_torch.parameters()):
                    assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())

        # a few steps are enough for the comparison
        if i > 5:
            break


# Test the overridden parameters() and named_parameters() member functions
def test_model_parameters():
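    """The Net below has two Linear layers (4 params) plus one extra Parameter:
    named_parameters() should yield 5 params, 1 with recurse=False, and 2 for
    the first Linear alone."""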
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build a module with 2 Linear layers (4 params) plus one extra Parameter, 5 params in total.
    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
            self.extra_param = torch.nn.Parameter(torch.randn(2))

    with ColoInitContext(device=get_current_device()):
        model = Net()

    param_cnt = 0
    for name, p in model.named_parameters():
        param_cnt += 1
    assert param_cnt == 5

    for name, colo_p in model.named_parameters():
        assert colo_p.is_model_data()

    param_cnt = 0
    for name, p in model.named_parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 1

    param_cnt = 0
    for p in model.fcs[0].parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 2


# @pytest.mark.skip
def test_colo_optimizer():
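    """Smoke-test ColoOptimizer on simple_net: a few zero_grad/forward/backward/step cycles."""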
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    set_seed(1)
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = model_builder(checkpoint=True)

    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # run the forward pass; some registered models compute the loss internally
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        loss.backward()
        colo_optimizer.step()

        if i > 5:
            break


def run_1d_row_tp(model_name: str):
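    """Row-shard Linear and embedding weights for 1D TP and check the loss
    against an unsharded torch model kept on rank 0."""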
    # Build the components (model, dataloaders, criterion) registered for model_name
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    # A naive way to set a spec for all weights in Linear
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
            init_1d_row_linear(p, pg)
        if 'embed' in name and 'weight' in name:
            init_1d_row_embedding(p, pg)

    model = model.cuda()

    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Bcast rank0 data to all processes
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch

        if rank == 0:
            assert torch.allclose(loss, loss_torch, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_torch.backward()

        # a few steps are enough for the comparison
        if i > 5:
            break


def _run_pretrain_load():
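    """Load a pretrained HuggingFace BERT inside ColoInitContext and verify that
    every parameter becomes a ColoParameter whose value matches the reference."""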
    from _utils import check_equal
    from transformers import BertForMaskedLM
    set_seed(1)
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    dict_pretrained = {}
    dict_col = {}
    c_ref = 0
    for name, param in model_pretrained.named_parameters():
        dict_pretrained[name] = param
        c_ref += 1

    c1 = 0
    c2 = 0
    for name, param in model.named_parameters():
        if isinstance(param, ColoParameter):
            c1 += 1
        else:
            c2 += 1
        dict_col[name] = param

    assert c_ref == c1
    assert c2 == 0
    if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
        assert model.cls.predictions.decoder.bias is model.cls.predictions.bias

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])


def run_model_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for name in ['simple_net']:
        run_1d_row_tp(name)
    for name in ['bert', 'simple_net']:
        run_1d_hybrid_tp(name)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
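    """Spawn world_size workers and run the row/hybrid TP checks in each."""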
    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


def run_pretrain_load_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_pretrain_load()


# This test case has to download HuggingFace pretrained models from the internet,
# so we trigger it manually.
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_pretrain_load(world_size):
    run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    test_model(4)
    # test_pretrain_load(4)