test_gpt.py 4.52 KB
Newer Older
ver217's avatar
ver217 committed
1
import pytest
2

ver217's avatar
ver217 committed
3
4
5
6
7
import colossalai
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
8
from colossalai.utils.model.colo_init_context import ColoInitContext
9
10
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup

ver217's avatar
ver217 committed
11
from functools import partial
12
from _utils import tensor_equal, tensor_shard_equal, set_seed
13
from tests.components_to_test.registry import non_distributed_component_funcs
14
import torch
15
from torch.nn.parallel import DistributedDataParallel as DDP
16
from colossalai.nn.parallel.data_parallel import ColoDDP
17
18
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
ver217's avatar
ver217 committed
19
20


21
22
def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach a row-wise (dim 0) 1D tensor-parallel spec to selected weights of ``model``.

    Every parameter whose name contains ``'weight'`` and does not contain ``'ln'``
    (presumably the LayerNorm parameters -- verify against the model's naming) gets a
    ``TensorSpec`` that shards dimension 0 into ``pg.tp_world_size()`` pieces with the
    ``ComputePattern.TP1D`` compute pattern. All other parameters are left untouched.

    Args:
        model: module whose parameters expose ``set_tensor_spec`` (e.g. one built
            under ``ColoInitContext``).
        pg: process group supplying the tensor-parallel world size.
    """
    row_shard = distspec.shard(pg, [0], [pg.tp_world_size()])
    row_spec = TensorSpec(row_shard, ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            # Only non-'ln' weights are sharded; biases and 'ln' params stay as they are.
            if 'ln' in name or 'weight' not in name:
                continue
            param.set_tensor_spec(row_spec)
ver217's avatar
ver217 committed
27
28


29
30
def init_1d_col_spec(model, pg: ProcessGroup):
    """Attach a column-wise (last dim) 1D tensor-parallel spec to weights and biases.

    Every parameter whose name contains ``'weight'`` or ``'bias'`` and does not contain
    ``'ln'`` (presumably LayerNorm parameters -- verify against the model's naming) gets
    a ``TensorSpec`` that shards dimension ``-1`` into ``pg.tp_world_size()`` pieces with
    the ``ComputePattern.TP1D`` compute pattern.

    Args:
        model: module whose parameters expose ``set_tensor_spec``.
        pg: process group supplying the tensor-parallel world size.
    """
    col_shard = distspec.shard(pg, [-1], [pg.tp_world_size()])
    col_spec = TensorSpec(col_shard, ComputeSpec(ComputePattern.TP1D))

    def _should_shard(param_name):
        # Weights and biases are sharded; anything named with 'ln' is skipped.
        if 'ln' in param_name:
            return False
        return 'weight' in param_name or 'bias' in param_name

    with DistSpecManager.no_grad():
        for name, param in model.named_parameters():
            if _should_shard(name):
                param.set_tensor_spec(col_spec)
ver217's avatar
ver217 committed
35
36


37
def check_param_equal(model, torch_model, pg: ProcessGroup):
    """Assert that every parameter of ``model`` matches the local shard of ``torch_model``'s.

    Parameters are paired positionally through ``zip`` of the two ``parameters()``
    iterators. Each pair is compared with ``tensor_shard_equal`` using this process's
    tensor-parallel rank and world size.

    Args:
        model: the (possibly sharded) ColoTensor model.
        torch_model: the plain PyTorch reference model holding full tensors.
        pg: process group providing the tensor-parallel rank and world size.

    Raises:
        AssertionError: if the tensor-parallel rank or world size is ``None``, or if
            any parameter pair differs.
    """
    tp_rank = pg.tp_local_rank()
    tp_size = pg.tp_world_size()
    # These values do not depend on the parameter, so validate them once up front
    # instead of on every loop iteration.
    assert tp_rank is not None, f"{pg.rank()} {tp_size} {pg._tp_degree} {tp_rank}"
    assert tp_size is not None
    for idx, (p, torch_p) in enumerate(zip(model.parameters(), torch_model.parameters())):
        assert tensor_shard_equal(torch_p, p, tp_rank, tp_size), \
            f"parameter #{idx} does not match the reference shard (tp rank {tp_rank}/{tp_size})"
ver217's avatar
ver217 committed
42
43


44
def check_grad_equal(model, torch_model, pg: ProcessGroup):
    """Assert that each gradient of ``model`` matches the local shard of ``torch_model``'s.

    Parameters are paired positionally via ``zip``. Their ``.grad`` tensors are compared
    with ``tensor_shard_equal`` using this process's tensor-parallel rank and world size.

    Raises:
        AssertionError: on the first mismatching gradient pair.
    """
    param_pairs = zip(model.parameters(), torch_model.parameters())
    for colo_param, ref_param in param_pairs:
        grads_match = tensor_shard_equal(ref_param.grad, colo_param.grad, pg.tp_local_rank(), pg.tp_world_size())
        assert grads_match
ver217's avatar
ver217 committed
47
48


49
def run_gpt(init_spec_func, use_ddp):
    """Compare a ColoTensor GPT-2 against a plain PyTorch replica for two training batches.

    Builds the registered ``'gpt2'`` test component twice: once under
    ``ColoInitContext`` and once as an ordinary module. It then copies the ColoTensor
    weights into the reference model and applies ``init_spec_func`` to the ColoTensor
    model. Finally it checks that parameters, logits and gradients agree over the first
    two batches of the training dataloader.

    Args:
        init_spec_func: callable ``(model, pg)`` that assigns tensor-parallel specs to
            the ColoTensor model's parameters (e.g. ``init_1d_col_spec``).
        use_ddp: when True, wrap the reference model in torch
            ``DistributedDataParallel`` and the ColoTensor model in ``ColoDDP``. The
            process group then uses a data-parallel degree of 2, provided at least two
            processes are running.
    """
    world_size = torch.distributed.get_world_size()
    dp_degree = 2 if (use_ddp and world_size >= 2) else 1
    pg = ProcessGroup(dp_degree=dp_degree)

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    # Only the builder, the training data and the loss are used by this test.
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    if use_ddp:
        # DDP requires device_ids to name the device the module lives on.
        # ``.cuda()`` above placed it on ``torch.cuda.current_device()``. The global
        # rank only equals that index on a single node, so it is not used here.
        torch_model = DDP(torch_model,
                          device_ids=[torch.cuda.current_device()],
                          process_group=gpc.get_group(ParallelMode.DATA))
        model = ColoDDP(model, process_group=pg)

    # Start both models from identical weights. The copy happens before
    # init_spec_func runs, i.e. before any shard spec is attached.
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    init_spec_func(model, pg)
    check_param_equal(model, torch_model, pg)

    model.train()
    torch_model.train()

    set_seed(pg.tp_local_rank())
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        logits = model(input_ids, attn_mask)
        torch_logits = torch_model(input_ids, attn_mask)
        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"

        loss = criterion(logits, input_ids)
        torch_loss = criterion(torch_logits, input_ids)
        # A ColoDDP-wrapped model is driven through its own backward() method.
        if use_ddp:
            model.backward(loss)
        else:
            loss.backward()
        torch_loss.backward()
        # No optimizer step or zero_grad is done, so on the second batch this compares
        # gradients accumulated over both batches.
        check_grad_equal(model, torch_model, pg)

        if i > 0:
            break
ver217's avatar
ver217 committed
89
90


91
92
93
94
95
def run_dist(rank, world_size, port, use_ddp):
    """Per-process entry point: launch ColossalAI with 1D tensor parallelism, then run GPT.

    With ``use_ddp`` the ranks are split in half, so the tensor-parallel size is
    ``world_size // 2``. Without it, all ranks form the tensor-parallel group. The DDP
    variant is skipped entirely when only one process is running.
    """
    if world_size == 1 and use_ddp:
        return

    if use_ddp:
        tp_world_size = world_size // 2
    else:
        tp_world_size = world_size

    tensor_cfg = dict(mode="1d", size=tp_world_size)
    config = dict(parallel=dict(tensor=tensor_cfg))
    colossalai.launch(config=config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    # NOTE: only the column-sharded layout is exercised; init_1d_row_spec is
    # currently not run here.
    run_gpt(init_1d_col_spec, use_ddp)
ver217's avatar
ver217 committed
99
100
101


@pytest.mark.dist
@pytest.mark.skip("under development")
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()
def test_gpt(world_size, use_ddp):
    """Spawn ``world_size`` worker processes, each running ``run_dist`` on one free port."""
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port, use_ddp=use_ddp)
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    # Direct invocation: pytest markers (including the skip) have no effect on a plain call.
    test_gpt(world_size=4, use_ddp=True)