from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import (
    check_equal,
    set_seed,
    split_param_col_tp1d,
    split_param_row_tp1d,
    tensor_shard_equal,
)


def run_1d_hybrid_tp(model_name):
    # Fetch the model builder, dataloaders, optimizer class and criterion registered under `model_name`
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    set_seed(1)
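    # Parameters created under ColoInitContext become ColoTensors, so they can be sharded below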
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

        optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1))

        # Make two models have the same init params
        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
            p2.data.copy_(p1.data)
    else:
        model_torch = None
        optimizer_torch = None

    pg = ProcessGroup(tp_degree=world_size)
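    # All ranks form one tensor-parallel group; selected parameters are sharded 1D across it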
    if 'bert' == model_name:
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue

            # num_class = type_vocab_size = 2 | (8, 2)
            if 'classifier' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            # num_class = vocab_size = 30524 | (30524, 8)
            elif 'word_embeddings' in name and 'weight' in name:
                split_param_row_tp1d(p, pg)
            # num_class = seq_len = 512 | (512, 8)
            elif 'position_embeddings' in name and 'weight' in name:
                split_param_row_tp1d(p, pg)
            # num_class = type_vocab_size = 2 | (2, 8)
            elif 'token_type_embeddings' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)

    elif "simple_net" == model_name:
        # A naive way to set the shard spec for every Linear weight and bias
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            if 'embed' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            if 'proj1' in name and ('weight' in name or 'bias' in name):
                split_param_row_tp1d(p, pg)
            if 'proj2' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            if 'classifier' in name and ('weight' in name or 'bias' in name):
                split_param_row_tp1d(p, pg)

    model = model.cuda()
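    # Use eval mode so dropout does not introduce randomness when comparing against the reference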
    model.eval()
    if rank == 0:
        model_torch.eval()

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):

        # Zero grad
        colo_optimizer.zero_grad()
        if rank == 0:
            optimizer_torch.zero_grad()
        torch.distributed.barrier()

        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        # Forward pass (rank-0 data was broadcast to all processes above)
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # Compare the loss with the single-process reference model on rank 0
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch
            assert torch.allclose(loss, loss_torch, rtol=1e-2)
        torch.distributed.barrier()

        loss.backward()
        colo_optimizer.step()

        if rank == 0:
            loss_torch.backward()
            optimizer_torch.step()

            with torch.no_grad():
                # Check that each local shard matches the corresponding slice of the reference parameter
                for p, torch_p in zip(model.parameters(), model_torch.parameters()):
                    assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
        torch.distributed.barrier()
        if i > 5:
            break


# Test the overridden parameters() and named_parameters() member functions
def test_model_parameters():
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # Build a module with 2 Linears (4 parameters) plus one extra parameter, 5 in total
    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
            self.extra_param = torch.nn.Parameter(torch.randn(2))

    with ColoInitContext(device=get_current_device()):
        model = Net()

    param_cnt = 0
    for name, p in model.named_parameters():
        param_cnt += 1
    assert param_cnt == 5

    for name, colo_p in model.named_parameters():
        assert colo_p.is_model_data()

    param_cnt = 0
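    # recurse=False yields only the module's own direct parameters (here just extra_param)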
    for name, p in model.named_parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 1

    param_cnt = 0
    for p in model.fcs[0].parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 2


def test_colo_optimizer():
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Forward pass (this test runs in a single process, so no broadcast is needed)
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        loss.backward()
        colo_optimizer.step()

        if i > 5:
            break


def run_1d_row_tp(model_name: str):
    # Fetch the model builder, dataloaders, optimizer class and criterion registered under `model_name`
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    # A naive way to set the shard spec for every Linear weight
    for mo_name, module in model.named_modules():
        # print(mo_name)
        for pa_name, param in module.named_parameters(recurse=False):
            # print('\t', pa_name, param.shape)
            if not isinstance(param, ColoTensor):
                continue
            if 'weight' in pa_name:
                if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
                    split_param_row_tp1d(param, pg)
                elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
                    split_param_col_tp1d(param, pg)

    model = model.cuda()

    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        # Forward pass (rank-0 data was broadcast to all processes above)
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch
            assert torch.allclose(loss, loss_torch, rtol=1e-2)
        torch.distributed.barrier()

        loss.backward()

        if rank == 0:
            loss_torch.backward()
        torch.distributed.barrier()

        if i > 5:
            break


def _run_pretrain_load():
    from transformers import BertForMaskedLM
    set_seed(1)
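    # Load the same pretrained weights twice: a plain reference copy and one built under ColoInitContext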
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    dict_pretrained = {}
    dict_col = {}
    c_ref = 0
    for name, param in model_pretrained.named_parameters():
        dict_pretrained[name] = param
        c_ref += 1
    c1 = 0
    c2 = 0
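    # Count parameters that were (c1) / were not (c2) turned into ColoParameter under ColoInitContext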
    for name, param in model.named_parameters():
        if isinstance(param, ColoParameter):
            c1 += 1
        else:
            c2 += 1
        dict_col[name] = param
    assert c_ref == c1
    assert c2 == 0
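    # If the reference model ties decoder.bias to predictions.bias, the Colo model must preserve the tie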
    if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
        assert model.cls.predictions.decoder.bias is model.cls.predictions.bias

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])


def run_model_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # The row-TP test below is commented out to keep the test fast
    # for name in ['bert', 'simple_net']:
    #     run_1d_row_tp(name)
    for name in ['bert', 'simple_net']:
        run_1d_hybrid_tp(name)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
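    # Spawn one process per rank; each runs the hybrid TP test against a single-process reference on rank 0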
    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


def run_pretrain_load_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_pretrain_load()


# This test downloads pretrained Hugging Face models from the internet,
# so it is skipped by default and must be triggered manually.
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_pretrain_load(world_size):
    run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    test_model(4)
    # test_pretrain_load(4)