test_gpt2.py 5.36 KB
Newer Older
1
import pytest
2
import torch
3
from torch.nn.parallel import DistributedDataParallel as DDP
4
5

import colossalai
6
7
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
8
from colossalai.testing import rerun_if_address_is_in_use, spawn
9
from colossalai.utils.cuda import get_current_device
10
from colossalai.zero import ColoInitContext
11
from tests.components_to_test.registry import non_distributed_component_funcs
12
13
14
15
16
17
18
19
from tests.test_tensor.common_utils import (
    debug_print,
    set_seed,
    split_param_col_tp1d,
    split_param_row_tp1d,
    tensor_equal,
    tensor_shard_equal,
)
ver217's avatar
ver217 committed
20
21


22
def init_1d_row_spec(model, pg: ProcessGroup):
    """Give ``model`` a 1-D tensor-parallel layout that shards weights along dim 0.

    Every parameter is attached to ``pg``. Only parameters whose name contains
    ``'weight'`` and does not contain ``'ln'`` receive the shard spec; all
    other parameters (biases, LayerNorm params) keep their current spec.
    """
    shard = ShardSpec([0], [pg.tp_world_size()])
    compute = ComputeSpec(ComputePattern.TP1D)
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        is_sharded_weight = 'weight' in name and 'ln' not in name
        if is_sharded_weight:
            param.set_tensor_spec(shard, compute)


def init_1d_col_spec(model, pg: ProcessGroup):
    """Give ``model`` a 1-D tensor-parallel layout that shards along the last dim.

    Every parameter is attached to ``pg``. Weights and biases whose name does
    not contain ``'ln'`` (LayerNorm) receive the shard spec; the remaining
    parameters keep their current spec.
    """
    shard = ShardSpec([-1], [pg.tp_world_size()])
    compute = ComputeSpec(ComputePattern.TP1D)
    for name, param in model.named_parameters():
        param.set_process_group(pg)
        if 'ln' in name:
            # LayerNorm parameters are never sharded.
            continue
        if 'weight' in name or 'bias' in name:
            param.set_tensor_spec(shard, compute)


def init_megatron_spec(model, pg: ProcessGroup):
    """Apply a Megatron-style tensor-parallel layout to a GPT-2 ``model``.

    Every parameter is attached to ``pg``. The split is then chosen by the
    qualified name of the module that directly owns the parameter:

    * ``mlp.c_fc`` weight/bias: column split, and the output is left sharded
      (``set_output_replicate(False)``). Presumably this lets the row-split
      ``mlp.c_proj`` consume the sharded activation (Megatron MLP pattern).
    * ``mlp.c_proj`` weight: row split; its bias keeps its current spec.
    * ``wte`` / ``wpe`` weight: column split.
    * any other ``c_attn`` / ``c_proj`` parameter: column split.

    Branch order matters: ``'c_proj' in mn`` also matches ``mlp.c_proj``, so
    the MLP branches must be tested first.

    Raises:
        RuntimeError: if an ``mlp.c_fc`` module owns a parameter that is
            neither a weight nor a bias.
    """
    for mn, module in model.named_modules():
        # recurse=False: visit each parameter once, under its owning module.
        for pn, param in module.named_parameters(recurse=False):
            param.set_process_group(pg)

            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
                    split_param_col_tp1d(param, pg)
                    param.compute_spec.set_output_replicate(False)
                else:
                    raise RuntimeError(f"unexpected parameter '{pn}' in module '{mn}'")
            elif 'mlp.c_proj' in mn:
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)
                else:
                    assert 'bias' in pn, f"unexpected parameter '{pn}' in module '{mn}'"
            elif 'wte' in mn or 'wpe' in mn:
                assert 'weight' in pn, f"unexpected parameter '{pn}' in module '{mn}'"
                split_param_col_tp1d(param, pg)
            elif 'c_attn' in mn or 'c_proj' in mn:
                split_param_col_tp1d(param, pg)


def check_param_equal(model, torch_model, pg: ProcessGroup):
    """Assert that every parameter of ``model`` matches the one in ``torch_model``.

    Parameters are paired by position, so both models must expose the same
    number of parameters in the same order. Each pair is compared with
    ``tensor_shard_equal``, which receives the tensor-parallel local rank and
    world size taken from ``pg``.
    """
    # The rank info does not depend on the parameter, so validate it once.
    assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}"
    assert pg.tp_world_size() is not None
    tp_rank = pg.tp_local_rank()
    tp_size = pg.tp_world_size()

    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    # zip() would silently drop the tail of the longer list and hide a mismatch.
    assert len(params) == len(torch_params), \
        f"parameter count mismatch: {len(params)} vs {len(torch_params)}"
    for idx, (p, torch_p) in enumerate(zip(params, torch_params)):
        assert tensor_shard_equal(torch_p, p, tp_rank, tp_size), f"parameter #{idx} differs"


def check_grad_equal(model, torch_model, pg: ProcessGroup):
    """Assert that every gradient of ``model`` matches the one in ``torch_model``.

    Parameters are paired by position, so both models must expose the same
    number of parameters in the same order. Each pair of ``.grad`` tensors is
    compared with ``tensor_shard_equal`` using the tensor-parallel local rank
    and world size taken from ``pg``.
    """
    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    # zip() would silently drop the tail of the longer list and hide a mismatch.
    assert len(params) == len(torch_params), \
        f"parameter count mismatch: {len(params)} vs {len(torch_params)}"
    tp_rank = pg.tp_local_rank()
    tp_size = pg.tp_world_size()
    for idx, (p, torch_p) in enumerate(zip(params, torch_params)):
        assert tensor_shard_equal(torch_p.grad, p.grad, tp_rank, tp_size), \
            f"gradient of parameter #{idx} differs"


def run_gpt(init_spec_func, use_ddp):
    """Check a ColossalAI GPT-2 against a plain PyTorch copy for two training steps.

    Both models are built from the same ``gpt2`` component builder. The
    reference model receives the ColossalAI model's initial weights. The
    ColossalAI model's parameters are then re-specified by ``init_spec_func``.
    For the first two batches of the training dataloader, the logits are
    compared, then both losses are back-propagated and the gradients are
    compared. No ``zero_grad`` is called, so gradients accumulate identically
    in both models.

    Args:
        init_spec_func: callable ``(model, pg)`` that sets the process group
            and tensor specs on the ColossalAI model's parameters.
        use_ddp: if true, wrap the reference model in torch ``DDP`` and the
            ColossalAI model in ``ColoDDP`` before comparing.
    """
    n_ranks = torch.distributed.get_world_size()

    # Hybrid TP x DP process group; DP degree 2 only when DDP is requested
    # and at least two ranks are available.
    dp_degree = 2 if (use_ddp and n_ranks >= 2) else 1
    group = ProcessGroup(dp_degree=dp_degree)

    # Seeding by TP rank (set_seed(group.tp_local_rank())) would give every
    # process of one TP group the same seed; it is currently disabled.

    builder, train_loader, _test_loader, _optim_cls, criterion = (
        non_distributed_component_funcs.get_callable('gpt2')()
    )

    with ColoInitContext(device=get_current_device()):
        colo_model = builder()
    colo_model = colo_model.cuda()
    ref_model = builder().cuda()

    if use_ddp:
        ref_model = DDP(ref_model, device_ids=[group.rank()], process_group=group.dp_process_group())
        colo_model = ColoDDP(colo_model, process_group=group)

    # Make both models start from identical weights.
    for ref_param, colo_param in zip(ref_model.parameters(), colo_model.parameters()):
        ref_param.data.copy_(colo_param)

    init_spec_func(colo_model, group)
    check_param_equal(colo_model, ref_model, group)

    # eval() turns dropout off so the two forward passes are comparable.
    colo_model.eval()
    ref_model.eval()
    set_seed(group.dp_local_rank())
    torch.distributed.barrier()

    for step, (input_ids, _label) in enumerate(train_loader):
        colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(group))
        colo_logits = colo_model(colo_input)
        ref_logits = ref_model(input_ids)
        assert tensor_equal(ref_logits, colo_logits), f"{ref_logits - colo_logits}"

        colo_loss = criterion(colo_logits, input_ids)
        ref_loss = criterion(ref_logits, input_ids)
        if not use_ddp:
            colo_loss.backward()
        else:
            colo_model.backward(colo_loss)
        ref_loss.backward()
        check_grad_equal(colo_model, ref_model, group)

        # Only the first two batches are checked.
        if step >= 1:
            break

    # Reset the RNG to a fixed seed before returning.
    set_seed(313)


def run_dist(rank, world_size, port, use_ddp):
    """Per-process entry point spawned by ``test_gpt``.

    The DDP variant is skipped entirely when only one process is launched.
    Otherwise ColossalAI is launched over NCCL on localhost and the
    Megatron-layout GPT-2 comparison is run.
    """
    skip = use_ddp and world_size == 1
    if skip:
        return

    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # init_1d_row_spec / init_1d_col_spec are not exercised here to keep the
    # test fast; only the Megatron layout is checked.
    run_gpt(init_megatron_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
    """Launch ``world_size`` processes that each run ``run_dist`` with the given DDP flag."""
    spawn(run_dist, world_size, use_ddp=use_ddp)


if __name__ == '__main__':
    # Manual run: four processes, DDP disabled.
    test_gpt(4, use_ddp=False)