test_gpt2.py 5.56 KB
Newer Older
1
2
from functools import partial

3
import pytest
4
import torch
ver217's avatar
ver217 committed
5
import torch.multiprocessing as mp
6
from torch.nn.parallel import DistributedDataParallel as DDP
7
8

import colossalai
9
10
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
ver217's avatar
ver217 committed
11
12
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
13
from colossalai.utils.cuda import get_current_device
14
from colossalai.utils.model.colo_init_context import ColoInitContext
15
from tests.components_to_test.registry import non_distributed_component_funcs
16
17
18
19
20
21
22
23
from tests.test_tensor.common_utils import (
    debug_print,
    set_seed,
    split_param_col_tp1d,
    split_param_row_tp1d,
    tensor_equal,
    tensor_shard_equal,
)
ver217's avatar
ver217 committed
24
25


def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and give non-LayerNorm weights a TP1D shard spec on dim 0.

    Every parameter of ``model`` gets ``pg`` as its process group. Only parameters
    whose name contains ``'weight'`` and does not contain ``'ln'`` additionally
    receive ``ShardSpec([0], [pg.tp_world_size()])`` with the ``TP1D`` compute
    pattern. Biases and LayerNorm parameters are not given a tensor spec here.

    Args:
        model: Model whose parameters expose ``set_process_group`` and
            ``set_tensor_spec`` (presumably built under ``ColoInitContext``).
        pg: Process group; its TP world size is the number of shards.
    """
    tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'weight' in n and 'ln' not in n:
            p.set_tensor_spec(*tensor_spec)


def init_1d_col_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and give non-LayerNorm weights/biases a TP1D shard spec on the last dim.

    Every parameter of ``model`` gets ``pg`` as its process group. Parameters whose
    name contains ``'weight'`` or ``'bias'`` and does not contain ``'ln'``
    additionally receive ``ShardSpec([-1], [pg.tp_world_size()])`` with the
    ``TP1D`` compute pattern.

    Args:
        model: Model whose parameters expose ``set_process_group`` and
            ``set_tensor_spec`` (presumably built under ``ColoInitContext``).
        pg: Process group; its TP world size is the number of shards.
    """
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'ln' not in n and ('weight' in n or 'bias' in n):
            p.set_tensor_spec(*spec)


def init_megatron_spec(model, pg: ProcessGroup):
    """Apply a Megatron-style tensor-parallel layout to the GPT-2 test model.

    Every parameter is attached to ``pg``. Parameters are then split by the name
    of the module that owns them:

    * ``mlp.c_fc``: weight and bias are column-split, and the output is left
      sharded (``set_output_replicate(False)``).
    * ``mlp.c_proj``: weight is row-split; the bias is not split here.
    * ``wte`` / ``wpe``: the embedding weight is column-split.
    * any other ``c_attn`` / ``c_proj``: column-split.

    Args:
        model: Model whose modules expose ``named_parameters(recurse=False)`` and
            whose parameters expose ``set_process_group`` / ``compute_spec``.
        pg: Process group used for the TP split.

    Raises:
        RuntimeError: if a ``mlp.c_fc`` module owns a parameter whose name
            contains neither ``'weight'`` nor ``'bias'``.
    """
    for mn, module in model.named_modules():
        # Only the module's own parameters; children are visited by named_modules().
        for pn, param in module.named_parameters(recurse=False):
            param.set_process_group(pg)

            if 'mlp.c_fc' in mn:
                if 'weight' not in pn and 'bias' not in pn:
                    raise RuntimeError(f"unexpected parameter '{pn}' in module '{mn}'")
                split_param_col_tp1d(param, pg)
                # Keep the column-sharded output instead of gathering it;
                # presumably consumed directly by the row-split mlp.c_proj. TODO confirm.
                param.compute_spec.set_output_replicate(False)
            elif 'mlp.c_proj' in mn:
                # Must be tested before the generic 'c_proj' branch below,
                # which would otherwise column-split this layer.
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)
                else:
                    assert 'bias' in pn
            elif 'wte' in mn or 'wpe' in mn:
                assert 'weight' in pn
                split_param_col_tp1d(param, pg)
            elif 'c_attn' in mn or 'c_proj' in mn:
                split_param_col_tp1d(param, pg)


def check_param_equal(model, torch_model, pg: ProcessGroup):
    """Assert that every parameter of ``model`` matches the corresponding ``torch_model`` parameter shard-wise.

    Parameters are paired positionally, so both models must register them in the
    same order. Each pair is compared with ``tensor_shard_equal`` using this
    rank's TP local rank and TP world size.

    Args:
        model: The (possibly sharded) Colo model.
        torch_model: The unsharded reference model.
        pg: Process group supplying the TP rank and size.

    Raises:
        AssertionError: if the TP rank/size is unavailable, the parameter counts
            differ, or any parameter pair does not match.
    """
    # TP coordinates do not change per parameter, so query them once.
    tp_rank = pg.tp_local_rank()
    tp_size = pg.tp_world_size()
    assert tp_rank is not None, f"rank {pg.rank()}: tp_local_rank() is None (tp_world_size={tp_size})"
    assert tp_size is not None, f"rank {pg.rank()}: tp_world_size() is None"

    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    # zip() would silently stop at the shorter list and hide a structural mismatch.
    assert len(params) == len(torch_params), \
        f"parameter count mismatch: {len(params)} (colo) vs {len(torch_params)} (torch)"

    for idx, (p, torch_p) in enumerate(zip(params, torch_params)):
        assert tensor_shard_equal(torch_p, p, tp_rank, tp_size), \
            f"rank {pg.rank()}: parameter #{idx} does not match its reference shard"


def check_grad_equal(model, torch_model, pg: ProcessGroup):
    """Assert that every gradient of ``model`` matches the corresponding ``torch_model`` gradient shard-wise.

    Parameters are paired positionally, so both models must register them in the
    same order. Each ``.grad`` pair is compared with ``tensor_shard_equal`` using
    this rank's TP local rank and TP world size.

    Args:
        model: The (possibly sharded) Colo model after backward.
        torch_model: The unsharded reference model after backward.
        pg: Process group supplying the TP rank and size.

    Raises:
        AssertionError: if the parameter counts differ or any gradient pair
            does not match.
    """
    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    # zip() would silently stop at the shorter list and hide a structural mismatch.
    assert len(params) == len(torch_params), \
        f"parameter count mismatch: {len(params)} (colo) vs {len(torch_params)} (torch)"

    tp_rank = pg.tp_local_rank()
    tp_size = pg.tp_world_size()
    for idx, (p, torch_p) in enumerate(zip(params, torch_params)):
        assert tensor_shard_equal(torch_p.grad, p.grad, tp_rank, tp_size), \
            f"rank {pg.rank()}: gradient of parameter #{idx} does not match its reference shard"


def run_gpt(init_spec_func, use_ddp):
    """Compare a ColoTensor GPT-2 against a plain PyTorch copy on this rank.

    Builds the ``gpt2`` test component twice (once under ``ColoInitContext``) and
    copies the Colo model's weights into the torch model. It then applies
    ``init_spec_func`` to set up tensor parallelism on the Colo model and checks
    parameters, logits and gradients for up to two training batches.

    Args:
        init_spec_func: Callable ``(model, pg)`` that assigns tensor-parallel
            specs, e.g. ``init_megatron_spec``.
        use_ddp: Wrap the models in ``ColoDDP`` / torch ``DDP``. A DP degree of 2
            is used only when at least two ranks exist.
    """
    world_size = torch.distributed.get_world_size()

    # build a PG with TP and DP hybrid
    pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    if use_ddp:
        # Pin DDP to this process's current CUDA device. The global rank
        # (pg.rank()) only equals the device index on a single node.
        torch_model = DDP(torch_model, device_ids=[get_current_device()], process_group=pg.dp_process_group())
        model = ColoDDP(model, process_group=pg)

    # make sure torch_model and model have the same parameter values
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    init_spec_func(model, pg)
    check_param_equal(model, torch_model, pg)

    # eval mode disables dropout so both forward passes are comparable
    model.eval()
    torch_model.eval()
    set_seed(pg.dp_local_rank())
    torch.distributed.barrier()
    for i, (input_ids, _label) in enumerate(train_dataloader):
        colo_input = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        logits = model(colo_input)
        torch_logits = torch_model(input_ids)
        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"

        # The criterion is given input_ids as targets; the loader's label is unused.
        loss = criterion(logits, input_ids)
        torch_loss = criterion(torch_logits, input_ids)
        if use_ddp:
            # ColoDDP is driven through its own backward() entry point.
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model, pg)

        # Two batches are enough for the comparison.
        if i > 0:
            break
    # NOTE(review): presumably resets the seed so later code is not affected by
    # the per-DP-rank seed set above -- confirm.
    set_seed(313)


def run_dist(rank, world_size, port, use_ddp):
    """Per-process entry point spawned by ``test_gpt``.

    Returns immediately for the DDP variant on a single rank. Otherwise it
    initialises the distributed environment with ``colossalai.launch`` (NCCL,
    localhost) and runs the GPT comparison with the Megatron spec.

    Args:
        rank: Global rank of this process (supplied by ``mp.spawn``).
        world_size: Total number of processes.
        port: Free port for the rendezvous on ``localhost``.
        use_ddp: Forwarded to ``run_gpt``.
    """
    if use_ddp and world_size == 1:
        return
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # The pure 1D row/col variants are disabled to keep the test fast.
    # run_gpt(init_1d_row_spec, use_ddp)
    # run_gpt(init_1d_col_spec, use_ddp)
    run_gpt(init_megatron_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
    """Spawn ``world_size`` processes that each run ``run_dist`` on a freshly chosen free port."""
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # Manual entry point: run the 4-process, non-DDP configuration directly.
    test_gpt(4, use_ddp=False)