"vscode:/vscode.git/clone" did not exist on "0385b26ebf4a811ca70eafe8590ad5e0529c0595"
test_model.py 11.1 KB
Newer Older
1
import pytest
2
import torch
3
4

import colossalai
5
6
7
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter
8
from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn
9
from colossalai.utils.cuda import get_current_device
10
from colossalai.zero import ColoInitContext
11
from tests.components_to_test.registry import non_distributed_component_funcs
12
13
14
15
16
17
18
from tests.test_tensor.common_utils import (
    check_equal,
    set_seed,
    split_param_col_tp1d,
    split_param_row_tp1d,
    tensor_shard_equal,
)
19

20

def run_1d_hybrid_tp(model_name):
    """Train a 1D hybrid tensor-parallel model and compare it with an unsharded copy.

    The model registered as ``model_name`` is built under ``ColoInitContext`` on every
    rank. Selected parameters are then split row- or column-wise over a tensor-parallel
    process group that spans all ranks. Rank 0 additionally builds a plain PyTorch copy
    whose parameters are copied from the ColoInitContext model before sharding.

    Both models are stepped with SGD (lr=0.1) on the same batches for up to 7
    iterations. On rank 0 the test asserts that:
    - the two losses match (rtol=1e-2);
    - after each step, every reference parameter agrees with the corresponding
      parameter held on this rank (checked with ``tensor_shard_equal``).

    Args:
        model_name: name registered in ``non_distributed_component_funcs``. Sharding
            rules exist for ``'bert'`` and ``'simple_net'``; any other model runs
            without sharding.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if rank == 0:
        model_torch = model_builder(checkpoint=True).cuda()
        optimizer_torch = ColossalaiOptimizer(torch.optim.SGD(model_torch.parameters(), lr=0.1))

        # Make the two models start from identical parameters (copied before sharding).
        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
            p2.data.copy_(p1.data)
    else:
        model_torch = None
        optimizer_torch = None

    pg = ProcessGroup(tp_degree=world_size)
    if model_name == 'bert':
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            # num_class = type_vocab_size = 2 | (8, 2)
            if 'classifier' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            # num_class = vocab_size = 30524 | (30524, 8)
            elif 'word_embeddings' in name and 'weight' in name:
                split_param_row_tp1d(p, pg)
            # num_class = seq_len = 512 | (512, 8)
            elif 'position_embeddings' in name and 'weight' in name:
                split_param_row_tp1d(p, pg)
            # num_class = type_vocab_size = 2 | (2, 8)
            elif 'token_type_embeddings' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)

    elif model_name == "simple_net":
        # A naive way to set the sharding spec for the weights (and some biases) of each layer.
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            if 'embed' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            if 'proj1' in name and ('weight' in name or 'bias' in name):
                split_param_row_tp1d(p, pg)
            if 'proj2' in name and 'weight' in name:
                split_param_col_tp1d(p, pg)
            if 'classifier' in name and ('weight' in name or 'bias' in name):
                split_param_row_tp1d(p, pg)

    model = model.cuda()
    # eval() disables dropout so that the sharded and reference forward passes are
    # deterministic and therefore comparable.
    model.eval()
    if rank == 0:
        model_torch.eval()

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        if rank == 0:
            optimizer_torch.zero_grad()
        torch.distributed.barrier()

        # Bcast rank0 data to all processes so every rank works on the same batch.
        data = data.to(get_current_device())
        label = label.to(get_current_device())
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        # Models without a criterion compute the loss themselves from (data, label).
        if criterion:
            loss = criterion(model(data), label)
        else:
            loss = model(data, label)

        # Reference forward pass and loss comparison, on rank 0 only.
        if rank == 0:
            if criterion:
                loss_torch = criterion(model_torch(data), label)
            else:
                loss_torch = model_torch(data, label)
            assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
        torch.distributed.barrier()

        loss.backward()
        colo_optimizer.step()

        if rank == 0:
            loss_torch.backward()
            optimizer_torch.step()

            # Compare each reference parameter with the (possibly sharded) one on this rank.
            with torch.no_grad():
                for p, torch_p in zip(model.parameters(), model_torch.parameters()):
                    assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())
        torch.distributed.barrier()

        if i > 5:
            break
133

134

135
136
# Test the overridden parameters() and named_parameters() member functions
def test_model_parameters():
    """Check parameter iteration on a module built under ``ColoInitContext``.

    Verifies:
    - the recursive parameter count;
    - that every parameter reports ``is_model_data()``;
    - that ``recurse=False`` returns only parameters owned directly by the module
      (both for the top-level module and for a submodule).
    """
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # A module with two Linear layers (weight + bias each) plus one directly owned
    # parameter: 5 parameters in total.
    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
            self.extra_param = torch.nn.Parameter(torch.randn(2))

    with ColoInitContext(device=get_current_device()):
        model = Net()

    named_params = list(model.named_parameters())
    assert len(named_params) == 5

    for _, colo_p in named_params:
        assert colo_p.is_model_data()

    # Only `extra_param` is owned directly by `Net`.
    assert len(list(model.named_parameters(recurse=False))) == 1

    # The first Linear owns exactly its weight and bias.
    assert len(list(model.fcs[0].parameters(recurse=False))) == 2


169
170
171
172
def test_colo_optimizer():
    """Smoke-test ``ColossalaiOptimizer`` on ``simple_net`` built under ``ColoInitContext``.

    Runs forward, backward and an SGD step (lr=0.1) for up to 7 batches. It only checks
    that these run without raising; no numerical values are compared.
    """
    # NOTE(review): unlike test_model_parameters, this test does not call
    # colossalai.launch; confirm ColoInitContext works without an initialised
    # process group.
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, _, _, criterion = get_components_func()
    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    colo_optimizer = ColossalaiOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Single-process run: the local batch is used directly (no broadcast).
        # Models without a criterion compute the loss themselves from (data, label).
        if criterion:
            loss = criterion(model(data), label)
        else:
            loss = model(data, label)

        loss.backward()
        colo_optimizer.step()

        if i > 5:
            break


197
def run_1d_row_tp(model_name: str):
    """Run forward/backward of a 1D tensor-parallel model and compare losses with a plain copy.

    Builds the model registered as ``model_name`` under ``ColoInitContext`` and splits its
    weights over a tensor-parallel group spanning all ranks:
    - weights of modules whose name contains 'embed' (but neither 'token' nor
      'LayerNorm') are split row-wise;
    - every other weight is split column-wise, unless the module name contains
      'LayerNorm' or 'ln'.

    Rank 0 builds an unsharded reference from the same seed and asserts that the losses
    match (rtol=1e-2) for up to 7 batches. No optimizer step is taken.

    Args:
        model_name: name registered in ``non_distributed_component_funcs``.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, _, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    # Re-seed so the reference model is initialised from the same random stream.
    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True).cuda()

    # A naive way to set the sharding spec for all weights.
    for mo_name, module in model.named_modules():
        for pa_name, param in module.named_parameters(recurse=False):
            if not isinstance(param, ColoTensor) or 'weight' not in pa_name:
                continue
            if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
                split_param_row_tp1d(param, pg)
            elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
                split_param_col_tp1d(param, pg)

    model = model.cuda()
    # eval() disables dropout. Otherwise the sharded and reference forward passes draw
    # different dropout masks and the loss comparison below is nondeterministic.
    # This mirrors run_1d_hybrid_tp.
    model.eval()
    if rank == 0:
        model_torch.eval()

    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Bcast rank0 data to all processes so every rank works on the same batch.
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        # Models without a criterion compute the loss themselves from (data, label).
        if criterion:
            loss = criterion(model(data), label)
        else:
            loss = model(data, label)

        # Reference forward pass and loss comparison, on rank 0 only.
        if rank == 0:
            if criterion:
                loss_torch = criterion(model_torch(data), label)
            else:
                loss_torch = model_torch(data, label)
            assert torch.allclose(loss, loss_torch, rtol=1e-2), f"model_name {model_name} failed"
        torch.distributed.barrier()

        loss.backward()
        if rank == 0:
            loss_torch.backward()
        torch.distributed.barrier()

        if i > 5:
            break


266
267
268
269
def _run_pretrain_load():
    """Check that loading pretrained BERT under ``ColoInitContext`` matches a plain load.

    Loads ``bert-base-uncased`` twice: once plainly and once under ``ColoInitContext``.
    It then asserts that:
    - every parameter of the ColoInitContext model is a ``ColoParameter``;
    - both models expose the same parameter names;
    - the decoder-bias tie is preserved whenever the plain model has it;
    - all parameter values are equal.

    Requires ``transformers`` and network access to download the checkpoint.
    """
    from transformers import BertForMaskedLM
    set_seed(1)
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    dict_pretrained = dict(model_pretrained.named_parameters())
    dict_col = dict(model.named_parameters())

    # Every parameter created under ColoInitContext must have been converted.
    non_colo = [name for name, param in dict_col.items() if not isinstance(param, ColoParameter)]
    assert not non_colo, f"parameters not converted to ColoParameter: {non_colo}"
    # Same parameter set in both models (shared/tied parameters are listed once by named_parameters).
    assert dict_pretrained.keys() == dict_col.keys(), \
        f"parameter names differ: {set(dict_pretrained) ^ set(dict_col)}"

    if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
        assert model.cls.predictions.decoder.bias is model.cls.predictions.bias

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])


def run_model_dist(rank, world_size, port):
    """Per-process entry point for ``test_model``: init colossalai and run the hybrid TP checks.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        port: port used for the process-group rendezvous on localhost.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # run_1d_row_tp is skipped for speed; re-enable with:
    # for name in ['bert', 'simple_net']:
    #     run_1d_row_tp(name)
    for name in ['bert', 'simple_net']:
        run_1d_hybrid_tp(name)
306

307

308
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
    """Spawn ``world_size`` processes, each running ``run_model_dist``."""
    spawn(run_model_dist, world_size)
313
314
315


def run_pretrain_load_dist(rank, world_size, port):
    """Per-process entry point for ``test_pretrain_load``: init colossalai, then check pretrained loading.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        port: port used for the process-group rendezvous on localhost.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_pretrain_load()


# This test downloads HuggingFace pretrained models from the internet,
# so it is skipped by default and has to be triggered manually.
@pytest.mark.skip(reason='downloads HuggingFace pretrained models; run manually')
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_pretrain_load(world_size):
    """Spawn ``world_size`` processes, each running ``run_pretrain_load_dist``."""
    spawn(run_pretrain_load_dist, world_size)
328

329
330

if __name__ == '__main__':
    # Uncomment to run the other tests manually.
    # test_model_parameters()
    # test_colo_optimizer()
    test_model(4)
    # test_pretrain_load(4)