test_gpt2.py 4.47 KB
Newer Older
ver217's avatar
ver217 committed
1
import pytest
2

3
4
5
6
7
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
ver217's avatar
ver217 committed
8
import torch.multiprocessing as mp
9
10

import colossalai
ver217's avatar
ver217 committed
11
12
13
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
14
from colossalai.utils.model.colo_init_context import ColoInitContext
15
from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
16
from colossalai.nn.parallel.data_parallel import ColoDDP
17
18
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
19
from tests.components_to_test.registry import non_distributed_component_funcs
ver217's avatar
ver217 committed
20
21


22
def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach a 1D tensor-parallel spec sharded on dimension 0 to selected weights.

    Every parameter whose name contains ``'weight'`` and does not contain
    ``'ln'`` (presumably LayerNorm modules) gets ``ShardSpec([0],
    [pg.tp_world_size()])`` together with a ``TP1D`` compute spec. Parameters
    without ``'weight'`` in their name (e.g. biases) are left unchanged.

    Args:
        model: Model built under ``ColoInitContext`` whose parameters accept
            ``set_tensor_spec``.
        pg: Process group providing the tensor-parallel world size.
    """
    dist_spec = ShardSpec([0], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            # Skip 'ln' parameters and anything that is not a weight.
            if 'ln' in name or 'weight' not in name:
                continue
            param.set_tensor_spec(dist_spec, compute_spec)
ver217's avatar
ver217 committed
28
29


30
def init_1d_col_spec(model, pg: ProcessGroup):
    """Attach a 1D tensor-parallel spec sharded on the last dimension to weights and biases.

    Every parameter whose name contains ``'weight'`` or ``'bias'`` and does not
    contain ``'ln'`` (presumably LayerNorm modules) gets
    ``ShardSpec([-1], [pg.tp_world_size()])`` together with a ``TP1D`` compute spec.

    Args:
        model: Model built under ``ColoInitContext`` whose parameters accept
            ``set_tensor_spec``.
        pg: Process group providing the tensor-parallel world size.
    """
    dist_spec = ShardSpec([-1], [pg.tp_world_size()])
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    with DistSpecManager.no_grad():
        # Lazily filtered, so parameters are visited in named_parameters() order.
        targets = (param for name, param in model.named_parameters()
                   if 'ln' not in name and any(key in name for key in ('weight', 'bias')))
        for param in targets:
            param.set_tensor_spec(dist_spec, compute_spec)
ver217's avatar
ver217 committed
36
37


38
def check_param_equal(model, torch_model, pg: ProcessGroup):
    """Assert every parameter of ``model`` equals its shard of the matching ``torch_model`` parameter.

    Parameters are paired positionally. Each pair is compared with
    ``tensor_shard_equal(torch_p, p, tp_local_rank, tp_world_size)``, where
    ``torch_p`` comes from ``torch_model`` and ``p`` from ``model``.

    Args:
        model: Model whose parameters may carry tensor-parallel shard specs.
        torch_model: Reference model with the same parameter ordering.
        pg: Process group supplying the tensor-parallel local rank and world size.

    Raises:
        AssertionError: if the TP local rank or world size is ``None``, if the
            two models expose different numbers of parameters, or if any pair differs.
    """
    # These values do not depend on the parameter, so look them up and check them once.
    tp_rank = pg.tp_local_rank()
    tp_world_size = pg.tp_world_size()
    assert tp_rank is not None, f"{pg.rank()} {tp_world_size} {pg._tp_degree} {tp_rank}"
    assert tp_world_size is not None

    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    # zip() would silently stop at the shorter list and skip comparisons.
    assert len(params) == len(torch_params), \
        f"parameter count mismatch: {len(params)} vs {len(torch_params)}"
    for p, torch_p in zip(params, torch_params):
        assert tensor_shard_equal(torch_p, p, tp_rank, tp_world_size)
ver217's avatar
ver217 committed
43
44


45
def check_grad_equal(model, torch_model, pg: ProcessGroup):
    """Assert every gradient of ``model`` equals its shard of the matching ``torch_model`` gradient.

    Parameters are paired positionally. Each pair's ``.grad`` tensors are
    compared with ``tensor_shard_equal(torch_p.grad, p.grad, tp_local_rank,
    tp_world_size)``.

    Args:
        model: Model whose parameters may carry tensor-parallel shard specs.
        torch_model: Reference model with the same parameter ordering.
        pg: Process group supplying the tensor-parallel local rank and world size.

    Raises:
        AssertionError: if the two models expose different numbers of
            parameters, or if any gradient pair differs.
    """
    params = list(model.parameters())
    torch_params = list(torch_model.parameters())
    # zip() would silently stop at the shorter list and skip comparisons.
    assert len(params) == len(torch_params), \
        f"parameter count mismatch: {len(params)} vs {len(torch_params)}"
    # Loop-invariant; look up once.
    tp_rank = pg.tp_local_rank()
    tp_world_size = pg.tp_world_size()
    for p, torch_p in zip(params, torch_params):
        assert tensor_shard_equal(torch_p.grad, p.grad, tp_rank, tp_world_size)
ver217's avatar
ver217 committed
48
49


50
def run_gpt(init_spec_func, use_ddp):
    """Compare a ColoTensor GPT-2 against a plain PyTorch copy on the current rank.

    Builds the ``'gpt2'`` test component twice: once under ``ColoInitContext``
    and once as an ordinary module. It then copies the ColoTensor model's
    initial weights into the reference model and applies ``init_spec_func``
    to the ColoTensor model. Finally it runs forward and backward (no
    optimizer step) on at most two batches, asserting after each that logits
    and gradients match.

    Args:
        init_spec_func: Callable ``(model, pg)`` that attaches tensor specs to
            the ColoTensor model's parameters (e.g. ``init_1d_col_spec``).
        use_ddp: If true, wrap the reference model in ``DistributedDataParallel``
            over the global-context ``ParallelMode.DATA`` group, and the
            ColoTensor model in ``ColoDDP``.
    """
    n_ranks = torch.distributed.get_world_size()
    # Data-parallel degree is 2 only when DDP is requested and at least 2 ranks exist.
    dp_degree = 2 if (use_ddp and n_ranks >= 2) else 1
    pg = ProcessGroup(dp_degree=dp_degree)

    gpt2_components = non_distributed_component_funcs.get_callable('gpt2')
    # The component tuple also carries a test dataloader and an optimizer class;
    # neither is used here.
    model_builder, train_dataloader, _, _, criterion = gpt2_components()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    if use_ddp:
        # NOTE(review): device_ids is given the *global* rank, which matches the
        # local CUDA index only on a single node; confirm before multi-node use.
        torch_model = DDP(torch_model,
                          device_ids=[gpc.get_global_rank()],
                          process_group=gpc.get_group(ParallelMode.DATA))
        model = ColoDDP(model, process_group=pg)

    # Start both models from identical weights, copied before any spec is applied.
    for ref_param, colo_param in zip(torch_model.parameters(), model.parameters()):
        ref_param.data.copy_(colo_param)

    init_spec_func(model, pg)
    check_param_equal(model, torch_model, pg)

    model.train()
    torch_model.train()

    set_seed(pg.tp_local_rank())

    for step, (input_ids, attn_mask) in enumerate(train_dataloader):
        logits = model(input_ids, attn_mask)
        torch_logits = torch_model(input_ids, attn_mask)
        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"

        loss = criterion(logits, input_ids)
        torch_loss = criterion(torch_logits, input_ids)
        # With ColoDDP the backward pass goes through model.backward(); otherwise plain autograd.
        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model, pg)

        # Stop after the second batch (step index 1).
        if step >= 1:
            break
ver217's avatar
ver217 committed
90
91


92
93
94
95
96
def run_dist(rank, world_size, port, use_ddp):
    """Per-process worker: launch ColossalAI with 1D tensor parallelism and run the GPT check.

    With ``use_ddp`` the tensor-parallel size is ``world_size // 2``; otherwise
    it is ``world_size``. A DDP run with a single process returns immediately
    without launching.

    Args:
        rank: Global rank of this process.
        world_size: Total number of processes.
        port: Port on ``localhost`` used for the NCCL rendezvous.
        use_ddp: Whether to combine data parallelism with tensor parallelism.
    """
    if world_size == 1 and use_ddp:
        return
    if use_ddp:
        tp_world_size = world_size // 2
    else:
        tp_world_size = world_size
    tensor_cfg = dict(mode="1d", size=tp_world_size)
    config = dict(parallel=dict(tensor=tensor_cfg))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # Only the column-sharded spec is exercised; init_1d_row_spec is not run here.
    run_gpt(init_1d_col_spec, use_ddp)
ver217's avatar
ver217 committed
100
101
102


@pytest.mark.dist
@pytest.mark.skip("under development")
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
    """Spawn ``world_size`` processes, each running ``run_dist`` on a fresh free port.

    Currently skipped via ``pytest.mark.skip("under development")``.
    """
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port, use_ddp=use_ddp)
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    # Manual entry point: 4 processes with DDP enabled.
    test_gpt(world_size=4, use_ddp=True)