from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor, ColoTensorSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, set_seed
from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, debug_print

def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and 1D row-shard (dim 0) each
    non-LayerNorm weight across the tensor-parallel group."""
    row_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        # LayerNorm weights stay replicated; everything else named *weight* is sharded.
        if 'weight' in name and 'ln' not in name:
            param.set_tensor_spec(*row_spec)


def init_1d_col_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and 1D column-shard (last dim) all
    weights and biases except those belonging to LayerNorm."""
    col_spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        # Shard both weights and biases, but leave LayerNorm params replicated.
        if ('weight' in name or 'bias' in name) and 'ln' not in name:
            param.set_tensor_spec(*col_spec)


def init_megatron_spec(model, pg: ProcessGroup):
    """Apply a Megatron-LM style 1D sharding layout to a GPT2 model.

    MLP ``c_fc`` is column-split with its output kept sharded (so the
    following row-split ``c_proj`` consumes it without an all-gather);
    embeddings (``wte``/``wpe``) and the remaining ``c_fc``/``c_proj``
    projections are column-split.
    """
    for mod_name, mod in model.named_modules():
        for param_name, param in mod.named_parameters(recurse=False):
            param.set_process_group(pg)

            if 'mlp.c_fc' in mod_name:
                if 'weight' in param_name or 'bias' in param_name:
                    split_param_col_tp1d(param, pg)
                    # Keep the activation sharded between c_fc and c_proj.
                    param.compute_spec.set_output_replicate(False)
                else:
                    # mlp.c_fc is expected to own only weight/bias parameters.
                    raise RuntimeError
            elif 'mlp.c_proj' in mod_name:
                if 'weight' in param_name:
                    split_param_row_tp1d(param, pg)
                else:
                    # The c_proj bias is left replicated.
                    assert 'bias' in param_name
            elif 'wte' in mod_name or 'wpe' in mod_name:
                # Embedding tables only have a weight; column-split it.
                assert 'weight' in param_name
                split_param_col_tp1d(param, pg)
            elif 'c_fc' in mod_name or 'c_proj' in mod_name:
                # Remaining projection layers (outside mlp.*) are column-split.
                split_param_col_tp1d(param, pg)


def check_param_equal(model, torch_model, pg: ProcessGroup):
    """Assert each (possibly sharded) parameter matches its full-size torch twin."""
    for colo_p, torch_p in zip(model.parameters(), torch_model.parameters()):
        # Sanity-check the TP topology before comparing shards.
        assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}1"
        assert pg.tp_world_size() is not None
        assert tensor_shard_equal(torch_p, colo_p, pg.tp_local_rank(), pg.tp_world_size())


def check_grad_equal(model, torch_model, pg: ProcessGroup):
    """Assert each sharded gradient matches the corresponding full torch gradient."""
    for colo_p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p.grad, colo_p.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_gpt(init_spec_func, use_ddp):
    """Run a GPT2 forward/backward and compare a tensor-parallel ColossalAI
    model against a plain torch reference, parameter by parameter."""
    world_size = torch.distributed.get_world_size()

    # build a PG with TP and DP hybrid
    dp_degree = 2 if (use_ddp and world_size >= 2) else 1
    pg = ProcessGroup(dp_degree=dp_degree)

    # set seed make processes of the same tp group use the same seed
    # set_seed(pg.tp_local_rank())

    components = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, _, _, criterion = components()

    # make sure torch_model and model has the same parameter values
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    if use_ddp:
        torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
        model = ColoDDP(model, process_group=pg)

    # Copy the colo model's weights into the torch reference before sharding.
    for torch_p, colo_p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(colo_p)

    init_spec_func(model, pg)

    check_param_equal(model, torch_model, pg)

    # close the dropout in eval mode
    model.eval()
    torch_model.eval()
    set_seed(pg.dp_local_rank())
    torch.distributed.barrier()

    for step, (input_ids, attn_mask) in enumerate(train_dataloader):
        sharded_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        logits = model(sharded_input, attn_mask)
        torch_logits = torch_model(input_ids, attn_mask)
        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"

        loss = criterion(logits, input_ids)
        torch_loss = criterion(torch_logits, input_ids)

        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model, pg)

        # Two batches are enough to validate the layout.
        if step > 0:
            break
    set_seed(313)


def run_dist(rank, world_size, port, use_ddp):
    """Per-process entry point spawned by ``mp.spawn``."""
    # DDP over a single process adds nothing; skip that combination.
    if use_ddp and world_size == 1:
        return
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # Comments below tests for speed concern
    # run_gpt(init_1d_row_spec, use_ddp)
    # run_gpt(init_1d_col_spec, use_ddp)
    run_gpt(init_megatron_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
    """Spawn ``world_size`` processes and run the GPT2 TP/DP comparison on each."""
    spawn_fn = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(spawn_fn, nprocs=world_size)


if __name__ == '__main__':
    # Manual run: 4 processes, tensor parallel only (no DDP).
    test_gpt(4, use_ddp=False)