import os
import random
from functools import partial

import numpy as np
import pytest
import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import (ColoOptimizer, ColoTensor, ComputePattern, ParallelAction, TensorSpec,
                               named_params_with_colotensor)
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import ColoInitContext, free_port
from colossalai.utils.cuda import get_current_device
from tests.components_to_test.registry import non_distributed_component_funcs

# Hack huggingface Bert ModelOutput:
# make its __post_init__ accept our ColoTensor in addition to torch.Tensor
from transformers.file_utils import ModelOutput
from dataclasses import fields


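# ModelOutput.__post_init__ assumes every field value is a torch.Tensor; the
# replacement below applies the same logic but also treats ColoTensor as a
# tensor, so huggingface models can return outputs that hold ColoTensors.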
def _post_init_colotensor(self):
    class_fields = fields(self)
    # Safety and consistency checks
    if len(class_fields) == 0:
        raise ValueError(f"{self.__class__.__name__} has no fields.")
    if not all(field.default is None for field in class_fields[1:]):
        raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")

    first_field = getattr(self, class_fields[0].name)
    other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

    def is_tensor_with_colo(x):
        """
        Tests if `x` is a `ColoTensor` or `torch.Tensor`.
        """
        if isinstance(x, torch.Tensor):
            return True
        return isinstance(x, ColoTensor)

    if other_fields_are_none and not is_tensor_with_colo(first_field):
        if isinstance(first_field, dict):
            iterator = first_field.items()
            first_field_iterator = True
        else:
            try:
                iterator = iter(first_field)
                first_field_iterator = True
            except TypeError:
                first_field_iterator = False

        # if the first field is an iterator of (key, value) pairs,
        # set the corresponding fields from it
        if first_field_iterator:
            for element in iterator:
                if (not isinstance(element, (list, tuple)) or len(element) != 2 or not isinstance(element[0], str)):
                    break
                setattr(self, element[0], element[1])
                if element[1] is not None:
                    self[element[0]] = element[1]
        elif first_field is not None:
            self[class_fields[0].name] = first_field
    else:
        for field in class_fields:
            v = getattr(self, field.name)
            if v is not None:
                self[field.name] = v

ModelOutput.__post_init__ = _post_init_colotensor
# complete the hack


def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
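    # force deterministic cuDNN kernels so results are reproducible across processes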
    torch.backends.cudnn.deterministic = True

def run_1d_hybrid_tp(model_name):
    # Fetch the registered model builder, dataloaders and criterion for `model_name`
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    set_seed(1)
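    # Parameters created under ColoInitContext are ColoTensors, which can later
    # be given a TensorSpec to shard them across the 1D tensor-parallel group.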
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    if model_name == 'bert':
        # Row-shard Linear weights (applied below to the classifier)
        parallel_action_list_row = [
            ParallelAction(priority=1,
                           compute_pattern=ComputePattern.TP1DRow_Linear,
                           parallel_mode=ParallelMode.PARALLEL_1D)
        ]
        spec_linear_row = TensorSpec(parallel_action_list_row)

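        # Column-shard the embedding table (used below for token_type_embeddings)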
        parallel_action_list_embedding_col = [
            ParallelAction(priority=1,
                           compute_pattern=ComputePattern.TP1DCol_Embedding,
                           parallel_mode=ParallelMode.PARALLEL_1D)
        ]
        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)

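        # Row-shard the embedding table (used below for word/position embeddings)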
        parallel_action_list_embedding_row = [
            ParallelAction(priority=1,
                           compute_pattern=ComputePattern.TP1DRow_Embedding,
                           parallel_mode=ParallelMode.PARALLEL_1D)
        ]
        spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)

        for name, p in model.colo_named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            # print(name)
            # num_class = type_vocab_size = 2 | (8, 2)
            if 'classifier' in name and 'weight' in name:
                p.set_spec(spec_linear_row)
            # num_class = vocab_size = 30524 | (30524, 8)
            if 'word_embeddings' in name and 'weight' in name:
                p.set_spec(spec_embedding_row)
            # num_class = seq_len = 512 | (512, 8)
            if 'position_embeddings' in name and 'weight' in name:
                p.set_spec(spec_embedding_row)
            # num_class = type_vocab_size = 2 | (2, 8)
            if 'token_type_embeddings' in name and 'weight' in name:
                p.set_spec(spec_embedding_col)
    elif model_name == "simple_net":
        parallel_action_list_row = [
            ParallelAction(priority=1,
                           compute_pattern=ComputePattern.TP1DRow_Linear,
                           parallel_mode=ParallelMode.PARALLEL_1D)
        ]
        spec_row = TensorSpec(parallel_action_list_row)

        parallel_action_list_col = [
            ParallelAction(priority=1,
                           compute_pattern=ComputePattern.TP1DCol_Linear,
                           parallel_mode=ParallelMode.PARALLEL_1D),
        ]
        spec_col = TensorSpec(parallel_action_list_col)

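        # Column-shard the classifier; do not gather its output (gather_out=False)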
        parallel_action_list_classifier_col = [
            ParallelAction(priority=1,
                           compute_pattern=ComputePattern.TP1DCol_Linear,
                           parallel_mode=ParallelMode.PARALLEL_1D,
                           gather_out=False),
        ]
        spec_classifier_col = TensorSpec(parallel_action_list_classifier_col)

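        # Column-shard the embedding table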
        parallel_action_list_embedding_col = [
            ParallelAction(priority=1,
                           compute_pattern=ComputePattern.TP1DCol_Embedding,
                           parallel_mode=ParallelMode.PARALLEL_1D)
        ]
        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
        # A naive way to set specs: match parameters by substrings of their names
        for name, p in model.colo_named_parameters():
            if not isinstance(p, ColoTensor):
                continue
            if 'embed' in name and 'weight' in name:
                p.set_spec(spec_embedding_col)
            if 'proj1' in name and ('weight' in name or 'bias' in name):
                p.set_spec(spec_col)
            if 'proj2' in name and 'weight' in name:
                p.set_spec(spec_row)
            if 'classifier' in name and ('weight' in name or 'bias' in name):
                p.set_spec(spec_classifier_col)
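    # Rank 0 also builds an unsharded torch model with the same seed as a reference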
    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    model = model.cuda()

    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Broadcast rank0's data and label to all processes in the TP group
        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch

        if rank == 0:
            # print(loss.torch_tensor().item())
            # print('loss torch', loss_torch.item())
            assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_torch.backward()
        if i > 5:
            break


# Test the overridden parameters() and named_parameters() member functions
def test_model_parameters():
    # build a module with 2 Linear layers (4 parameters) plus 1 extra parameter, 5 in total
    class Net(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2))
            self.extra_param = torch.nn.Parameter(torch.randn(2))

    with ColoInitContext(device=get_current_device()):
        model = Net()

    param_cnt = 0
    for name, p in model.named_parameters():
        param_cnt += 1
    assert param_cnt == 5

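    # every parameter created under ColoInitContext should be tagged as model data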
    for name, colo_p in model.colo_named_parameters():
        assert colo_p.is_model_data()

    param_cnt = 0
    for name, p in model.named_parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 1

    param_cnt = 0
    for p in model.fcs[0].parameters(recurse=False):
        param_cnt += 1
    assert param_cnt == 2


def test_colo_optimizer():
    get_components_func = non_distributed_component_funcs.get_callable('simple_net')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    set_seed(1)
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = model_builder(checkpoint=True)

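    # ColoOptimizer wraps a regular torch optimizer so it can step over ColoTensors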
    colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    for i, (data, label) in enumerate(train_dataloader):
        colo_optimizer.zero_grad()
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # forward; some test models take the label and return the loss directly
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        loss.backward()
        colo_optimizer.step()

        if i > 5:
            break


def run_1d_row_tp(model_name: str):
    # Fetch the registered model builder, dataloaders and criterion for `model_name`
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    set_seed(1)
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

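    # Rank 0 keeps an unsharded torch model built with the same seed as a reference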
    set_seed(1)
    if rank == 0:
        model_torch = model_builder(checkpoint=True)
        model_torch = model_torch.cuda()

    # Row-shard the weights of every Linear layer
    parallel_action_list = [
        ParallelAction(priority=1,
                       compute_pattern=ComputePattern.TP1DRow_Linear,
                       parallel_mode=ParallelMode.PARALLEL_1D)
    ]
    spec = TensorSpec(parallel_action_list)

    # Row-shard the embedding tables
    parallel_action_list_embedding_row = [
        ParallelAction(priority=1,
                       compute_pattern=ComputePattern.TP1DRow_Embedding,
                       parallel_mode=ParallelMode.PARALLEL_1D)
    ]
    spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)

    # A naive way to set specs: match parameters by substrings of their names
    for name, p in model.colo_named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
            p.set_spec(spec)
        if 'embed' in name and 'weight' in name:
            p.set_spec(spec_embedding_row)
    model = model.cuda()
    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        # Broadcast rank0's data and label to all processes in the TP group
        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        # For reference
        if rank == 0:
            if criterion:
                output_torch = model_torch(data)
                loss_torch = criterion(output_torch, label)
            else:
                output_torch = model_torch(data, label)
                loss_torch = output_torch

        if rank == 0:
            # print(loss.torch_tensor().item())
            # print('loss torch', loss_torch.item())
            assert torch.allclose(loss.torch_tensor(), loss_torch, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_torch.backward()
        if i > 5:
            break


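# Load the same huggingface checkpoint twice, once normally and once under
# ColoInitContext, then verify the ColoTensor weights match the originals.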
def _run_pretrain_load():
    from _utils import check_equal
    from transformers import BertForMaskedLM
    set_seed(1)
    model_pretrained = BertForMaskedLM.from_pretrained('bert-base-uncased')
    with ColoInitContext(lazy_memory_allocate=False, device=get_current_device()):
        model = BertForMaskedLM.from_pretrained('bert-base-uncased')

    model_pretrained = model_pretrained.cuda()
    model = model.cuda()

    dict_pretrained = {}
    dict_col = {}
    for name, param in model_pretrained.named_parameters():
        dict_pretrained[name] = param
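    # count ColoParameters (c1) vs. plain parameters (c2) after ColoInit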
    c1 = 0
    c2 = 0
    for name, param in model.colo_named_parameters():
        if isinstance(param, ColoParameter):
            c1 = c1 + 1
        else:
            c2 = c2 + 1
        dict_col[name] = param

    for name, param in dict_pretrained.items():
        check_equal(param, dict_col[name])


def run_model_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    for name in ['simple_net']:
        run_1d_row_tp(name)
    for name in ['bert', 'simple_net']:
        run_1d_hybrid_tp(name)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
#@parameterize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


def run_pretrain_load_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_pretrain_load()


# This test case has to download a huggingface pretrained model from the internet,
# so we trigger it manually (the leading underscore hides it from pytest collection).
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def _test_pretrain_load(world_size):
    run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)

if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
    # test_model()
    # _test_pretrain_load(4)
    _run_pretrain_load()