test_model.py 12.1 KB
Newer Older
1
import pytest
2
3
4
from functools import partial
from _utils import tensor_shard_equal, set_seed

5
import torch
6
import torch.multiprocessing as mp
7
8
9

from colossalai.tensor.colo_parameter import ColoParameter
import colossalai
10
from colossalai.testing import rerun_if_address_is_in_use
11
12
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
13
from colossalai.utils.model.colo_init_context import ColoInitContext
14
from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \
15
    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
16
from colossalai.nn.optimizer import ColoOptimizer
17
18

from tests.components_to_test.registry import non_distributed_component_funcs
19

20

21
22
def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
    """Shard ``weight`` along its last dimension across the tensor-parallel group.

    Args:
        weight: parameter whose process group and tensor spec are replaced.
        pg: process group; ``pg.tp_world_size()`` gives the number of shards.
    """
    dist_spec = distspec.shard([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(dist_spec, compute_spec)


def init_1d_col_linear(weight, pg):
    """Shard ``weight`` along dimension 0 across the tensor-parallel group.

    Args:
        weight: parameter whose process group and tensor spec are replaced.
        pg: process group; ``pg.tp_world_size()`` gives the number of shards.
    """
    num_shards = pg.tp_world_size()
    shard_spec = distspec.shard([0], [num_shards])
    pattern_spec = ComputeSpec(ComputePattern.TP1D)
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(shard_spec, pattern_spec)


def init_1d_row_embedding(weight, pg):
    """Shard an embedding ``weight`` along dimension 0 across the tensor-parallel group.

    Args:
        weight: parameter whose process group and tensor spec are replaced.
        pg: process group; ``pg.tp_world_size()`` gives the number of shards.
    """
    tp_size = pg.tp_world_size()
    new_spec = (distspec.shard([0], [tp_size]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        dist_part, compute_part = new_spec
        weight.set_tensor_spec(dist_part, compute_part)


def init_1d_col_embedding(weight, pg):
    """Shard an embedding ``weight`` along its last dimension across the tensor-parallel group.

    Args:
        weight: parameter whose process group and tensor spec are replaced.
        pg: process group; ``pg.tp_world_size()`` gives the number of shards.
    """
    last_dim_shard = distspec.shard([-1], [pg.tp_world_size()])
    tp1d = ComputeSpec(ComputePattern.TP1D)
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(last_dim_shard, tp1d)


def run_1d_hybrid_tp(model_name):
    """Train ``model_name`` with a hybrid row/column 1D tensor-parallel layout and
    compare it against an unsharded reference model on rank 0.

    Parameters are sharded by name for the ``'bert'`` and ``'simple_net'``
    components (see the branches below). Any other name runs without sharding.
    Rank 0 also builds a plain copy of the model with the same initial
    parameters and feeds it the same broadcast batches. It asserts that the
    losses agree (``rtol=1e-2``). After each SGD step it checks every parameter
    against the reference with ``tensor_shard_equal``. At most 7 batches are
    processed.

    Args:
        model_name: key registered in ``non_distributed_component_funcs``.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, _, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    def _forward_loss(net, data, label):
        # Without a criterion the model's forward output is used as the loss.
        if criterion:
            return criterion(net(data), label)
        return net(data, label)

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if rank == 0:
        # Unsharded reference model; only rank 0 performs the comparison.
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()
        colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)

        # Make two models have the same init params. This happens before any
        # tensor spec is set on ``model`` below.
        for p1, p2 in zip(model.parameters(), model_torch.parameters()):
            p2.data.copy_(p1.data)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    if 'bert' == model_name:
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            # num_class = type_vocab_size = 2 | (8, 2)
            if 'classifier' in name and 'weight' in name:
                init_1d_row_linear(p, pg)
            # num_class = vocab_size = 30524 | (30524, 8)
            if 'word_embeddings' in name and 'weight' in name:
                init_1d_row_embedding(p, pg)
            # num_class = seq_len = 512 | (512, 8)
            if 'position_embeddings' in name and 'weight' in name:
                init_1d_row_embedding(p, pg)
            # num_class = type_vocab_size = 2 | (2, 8)
            if 'token_type_embeddings' in name and 'weight' in name:
                init_1d_col_embedding(p, pg)
    elif "simple_net" == model_name:
        # A naive way to set spec for all weights in Linear
        for name, p in model.named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            if 'embed' in name and 'weight' in name:
                init_1d_col_embedding(p, pg)
            if 'proj1' in name and ('weight' in name or 'bias' in name):
                init_1d_col_linear(p, pg)
            if 'proj2' in name and 'weight' in name:
                init_1d_row_linear(p, pg)
            if 'classifier' in name and ('weight' in name or 'bias' in name):
                init_1d_col_linear(p, pg)

    model = model.cuda()
    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    for i, (data, label) in enumerate(train_dataloader):
        model.eval()
        colo_optimizer.zero_grad()
        if rank == 0:
            model_torch.eval()
            colo_optimizer_torch.zero_grad()

        # Bcast rank0 data to all processes so every rank sees the same batch.
        data = data.to(get_current_device())
        label = label.to(get_current_device())
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        loss = _forward_loss(model, data, label)

        # For reference
        if rank == 0:
            loss_torch = _forward_loss(model_torch, data, label)
            with torch.no_grad():
                assert torch.allclose(loss, loss_torch, rtol=1e-2)

        loss.backward()
        colo_optimizer.step()

        if rank == 0:
            loss_torch.backward()
            colo_optimizer_torch.step()

            with torch.no_grad():
                # Compare each (possibly sharded) parameter with the full reference parameter.
                for p, torch_p in zip(model.parameters(), model_torch.parameters()):
                    assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())

        if i > 5:
            break


# Test the overridden parameters() and named_parameters() member functions
@pytest.mark.skip
def test_model_parameters():
    """Check parameter enumeration on a model built under ``ColoInitContext``.

    Verifies three things:
    - ``named_parameters()`` yields all 5 parameters, each flagged as model data.
    - ``named_parameters(recurse=False)`` on the top-level module yields only
      its own parameter.
    - ``parameters(recurse=False)`` on a submodule yields only that submodule's
      parameters.
    """
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # Build a module with 2 Linear layers (weight + bias each) plus one extra
    # parameter: 5 parameters in total.
    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
            self.extra_param = torch.nn.Parameter(torch.randn(2))

    with ColoInitContext(device=get_current_device()):
        model = Net()

    assert len(list(model.named_parameters())) == 5

    for _, colo_p in model.named_parameters():
        assert colo_p.is_model_data()

    # Only ``extra_param`` is registered directly on the top-level module.
    assert len(list(model.named_parameters(recurse=False))) == 1

    # The first Linear owns exactly its weight and bias.
    assert len(list(model.fcs[0].parameters(recurse=False))) == 2


@pytest.mark.skip
def test_colo_optimizer():
    """Smoke-test ``ColoOptimizer`` wrapping ``torch.optim.SGD`` in a single process.

    Builds the ``simple_net`` component under ``ColoInitContext`` and runs at
    most 7 zero_grad/forward/backward/step iterations. There is no reference
    comparison; the test only checks that the loop runs.
    """
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, _, _, criterion = get_components_func()
    set_seed(1)
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = model_builder(checkpoint=True)

    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Single process, so no broadcast is needed. Without a criterion the
        # model's forward output is used as the loss.
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            loss = model(data, label)

        loss.backward()
        colo_optimizer.step()

        if i > 5:
            break


def run_1d_row_tp(model_name: str):
    """Shard ``model_name`` with a row-wise 1D tensor-parallel layout and compare
    losses against an unsharded reference model on rank 0.

    Two name-based rules set the tensor specs:
    - Weights that are not LayerNorm/``ln`` or embedding weights go through
      ``init_1d_row_linear``.
    - Embedding weights go through ``init_1d_row_embedding``.

    Rank 0 builds the reference from the same seed and asserts that the losses
    agree (``rtol=1e-2``) for at most 7 batches. Backward is run on both models,
    but no optimizer step is taken, so parameters are not compared.

    Args:
        model_name: key registered in ``non_distributed_component_funcs``.
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, _, criterion = get_components_func()
    rank = torch.distributed.get_rank()

    def _forward_loss(net, data, label):
        # Without a criterion the model's forward output is used as the loss.
        if criterion:
            return criterion(net(data), label)
        return net(data, label)

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    # Re-seed so the rank-0 reference model starts from the same RNG state as
    # ``model``. No parameters are copied; the shared seed is the only thing
    # keeping the two initializations aligned.
    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    # Naive name-based sharding rule for every weight.
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
            init_1d_row_linear(p, pg)
        if 'embed' in name and 'weight' in name:
            init_1d_row_embedding(p, pg)

    model = model.cuda()

    for i, (data, label) in enumerate(train_dataloader):
        # Bcast rank0 data to all processes so every rank sees the same batch.
        data = data.to(get_current_device())
        label = label.to(get_current_device())
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        loss = _forward_loss(model, data, label)

        # For reference
        if rank == 0:
            loss_torch = _forward_loss(model_torch, data, label)
            assert torch.allclose(loss, loss_torch, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_torch.backward()

        if i > 5:
            break


def _run_pretrain_load():
    """Check that ``from_pretrained`` inside ``ColoInitContext`` matches a plain load.

    Loads ``bert-base-uncased`` twice, once normally and once inside
    ``ColoInitContext``, and asserts that:
    - every parameter of the context-built model is a ``ColoParameter``;
    - both models expose the same number of named parameters;
    - a tied decoder/prediction bias stays tied;
    - each parameter matches its counterpart according to ``check_equal``.

    Requires network access (or a local cache) to fetch the checkpoint.
    """
    from _utils import check_equal
    from transformers import BertForMaskedLM
    set_seed(1)
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    # Parameter names are dotted module paths, so each dict has one entry per
    # parameter yielded by named_parameters().
    dict_pretrained = dict(model_pretrained.named_parameters())
    dict_col = dict(model.named_parameters())

    non_colo = [name for name, param in dict_col.items() if not isinstance(param, ColoParameter)]
    assert not non_colo, f'parameters not converted to ColoParameter: {non_colo}'
    assert len(dict_col) == len(dict_pretrained), \
        f'parameter count mismatch: {len(dict_col)} != {len(dict_pretrained)}'

    # Weight tying must survive the ColoInitContext conversion.
    if model_pretrained.cls.predictions.decoder.bias is model_pretrained.cls.predictions.bias:
        assert model.cls.predictions.decoder.bias is model.cls.predictions.bias

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])


def run_model_dist(rank, world_size, port):
    """Per-process body of ``test_model``: launch colossalai, then run the TP checks."""
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    row_tp_models = ('simple_net',)
    hybrid_tp_models = ('bert', 'simple_net')
    for model_name in row_tp_models:
        run_1d_row_tp(model_name)
    for model_name in hybrid_tp_models:
        run_1d_hybrid_tp(model_name)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("under development")
@rerun_if_address_is_in_use()
def test_model(world_size):
    """Spawn ``world_size`` workers running ``run_model_dist`` on one shared free port."""
    # mp.spawn passes the process index as the first positional argument (rank).
    port = free_port()
    worker = partial(run_model_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)


def run_pretrain_load_dist(rank, world_size, port):
    """Per-process body of ``test_pretrain_load``: launch colossalai, then run the load check."""
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    _run_pretrain_load()


# This test downloads HuggingFace pretrained models from the internet,
# so it is skipped by default and must be triggered manually.
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_pretrain_load(world_size):
    """Spawn ``world_size`` workers running ``run_pretrain_load_dist`` on one shared free port."""
    shared_port = free_port()
    mp.spawn(partial(run_pretrain_load_dist, world_size=world_size, port=shared_port), nprocs=world_size)


if __name__ == '__main__':
    # Other entry points that can be run by hand:
    #   test_model_parameters(), test_colo_optimizer(), test_pretrain_load(4)
    test_model(4)