# test_zero_optim.py
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
8
from colossalai.utils.model.colo_init_context import ColoInitContext
9
from colossalai.gemini import ChunkManager
10
from functools import partial
ver217's avatar
ver217 committed
11
from _utils import tensor_equal, set_seed, tensor_shard_equal
12
13
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
14
from colossalai.nn.parallel import ZeroDDP
15
16
17
18
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
ver217's avatar
ver217 committed
19
from colossalai.gemini.gemini_mgr import GeminiManager
20
from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, ProcessGroup, ColoTensor
21
22


23
def check_param_equal(model, torch_model, pg: ProcessGroup):
    """Assert that the parameters of ``model`` match those of ``torch_model``.

    Parameters are paired positionally through ``named_parameters()`` of both
    models. A parameter of ``model`` whose storage is empty is skipped
    (presumably its payload currently lives elsewhere, e.g. in a chunk that is
    not resident on this rank -- TODO confirm). Every other parameter must be
    ``torch.float16`` and must match the reference parameter, after the
    reference is cast to the same dtype and device, according to
    ``tensor_shard_equal``.

    Args:
        model: Model under test.
        torch_model: Reference model to compare against.
        pg: Process group providing the tensor-parallel local rank and world
            size forwarded to ``tensor_shard_equal``.

    Raises:
        AssertionError: If a compared parameter is not fp16 or does not match
            the reference.
    """
    for (n, p), (_, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        # Only parameters that actually hold data locally can be compared.
        if p.storage().size() > 0:
            assert p.dtype == torch.float16
            assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
                                      pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'
29
30


31
def check_grad_equal(model, torch_model, pg: ProcessGroup):
    """Assert that the gradients of ``model`` match those of ``torch_model``.

    Parameters are paired positionally through ``named_parameters()`` of both
    models. Parameters of ``model`` without a gradient are skipped; for the
    rest, the reference gradient is cast to the dtype and device of the tested
    gradient and compared with ``tensor_shard_equal``.

    Args:
        model: Model under test.
        torch_model: Reference model to compare against.
        pg: Process group providing the tensor-parallel local rank and world
            size forwarded to ``tensor_shard_equal``; its rank is included in
            the failure message.

    Raises:
        AssertionError: If a compared gradient does not match the reference.
    """
    for (n, p), (_, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        if p.grad is not None:
            assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
                                      pg.tp_local_rank(), pg.tp_world_size()), \
                f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
37
38
39


def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    """Run one forward and backward pass and return the logits.

    Args:
        model: Callable invoked as ``model(input_ids, attn_mask)``.
        criterion: Callable invoked as ``criterion(logits, input_ids)``; the
            input ids double as the targets.
        optimizer: Object exposing ``zero_grad()`` and ``backward(loss)``. The
            backward pass is driven through the optimizer rather than through
            ``loss.backward()``.
        input_ids: Model input, also used as the loss target.
        attn_mask: Attention mask passed to the model.

    Returns:
        The model output converted with ``.float()``.
    """
    optimizer.zero_grad()
    logits = model(input_ids, attn_mask)
    # The loss is computed on the ``.float()`` version of the model output.
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


48
def init_1d_row_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and shard selected weights along dim 0.

    Every parameter gets its process group set to ``pg``. Parameters whose
    name contains ``'weight'`` but not ``'ln'`` additionally receive a tensor
    spec built from ``ShardSpec([0], [pg.tp_world_size()])`` and
    ``ComputeSpec(ComputePattern.TP1D)``.

    Args:
        model: Model whose parameters expose ``set_process_group`` and
            ``set_tensor_spec``.
        pg: Process group to attach; its tensor-parallel world size is the
            number of shards.
    """
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        # Names containing 'ln' are left unsharded (presumably layer norms -- TODO confirm).
        if 'weight' in n and 'ln' not in n:
            p.set_tensor_spec(*spec)
ver217's avatar
ver217 committed
54
55


56
def init_1d_col_spec(model, pg: ProcessGroup):
    """Attach ``pg`` to every parameter and shard selected ones along the last dim.

    Every parameter gets its process group set to ``pg``. Parameters whose
    name contains ``'weight'`` or ``'bias'`` but not ``'ln'`` additionally
    receive a tensor spec built from ``ShardSpec([-1], [pg.tp_world_size()])``
    and ``ComputeSpec(ComputePattern.TP1D)``.

    Args:
        model: Model whose parameters expose ``set_process_group`` and
            ``set_tensor_spec``.
        pg: Process group to attach; its tensor-parallel world size is the
            number of shards.
    """
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        # Names containing 'ln' are left unsharded (presumably layer norms -- TODO confirm).
        if 'ln' not in n and ('weight' in n or 'bias' in n):
            p.set_tensor_spec(*spec)
ver217's avatar
ver217 committed
62
63


64
65
66
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
    """Compare a ZeroDDP + ZeroOptimizer GPT-2 against an Apex-AMP + DDP reference.

    Two copies of the 'gpt2' test component are built with identical weights.
    One is wrapped in ``ZeroDDP`` and driven by ``ZeroOptimizer(HybridAdam)``;
    the other is converted with Apex AMP (``O2``) and wrapped in torch ``DDP``.
    For the first three batches of the training dataloader the logits,
    gradients and updated parameters of both are asserted to agree.

    Callers in this file pass only ``tp_init_spec_func``; the first three
    arguments are supplied by the ``parameterize`` decorators.

    Args:
        use_chunk: When true, a chunk size is searched with
            ``ChunkManager.search_chunk_size``; otherwise ``None`` is used.
        use_zero: Forwarded to ``ChunkManager`` as ``enable_distributed_storage``.
        placement_policy: Placement policy name given to ``GeminiManager``.
        tp_init_spec_func: Optional callable ``(model, pg)`` that assigns
            tensor-parallel specs to the model parameters before wrapping.

    Raises:
        AssertionError: If logits, gradients or parameters diverge.
    """
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, _test_dataloader, _optimizer_class, criterion = get_components_func()

    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    # Start both models from exactly the same weights.
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()

    # world size, dp = 2, tp =2, construct a hybrid parallelism.
    if world_size == 4:
        pg = ProcessGroup(tp_degree=2)
    else:
        pg = ProcessGroup(tp_degree=world_size)

    if tp_init_spec_func:
        tp_init_spec_func(model, pg)

    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    # Loss scale 1 here matches loss_scale=1 in the Apex config below.
    optim = ZeroOptimizer(optim, model, initial_scale=1)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    print(chunk_manager)
    check_param_equal(model, torch_model, pg)

    # Both models run in eval mode (presumably to rule out dropout randomness -- TODO confirm).
    model.eval()
    torch_model.eval()

    # Re-seed with the data-parallel local rank before drawing batches.
    set_seed(pg.dp_local_rank())
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        # Only the first three batches are checked.
        if i > 2:
            break
        input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert tensor_equal(logits, torch_logits)
        check_grad_equal(model, torch_model, pg)
        optim.step()
        torch_optim.step()
        check_param_equal(model, torch_model, pg)
124
125
126


def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI, then run the GPT checks.

    The column-sharded configuration is always run. With four processes the
    row-sharded configuration is run afterwards as well.

    Args:
        rank: Rank of this process.
        world_size: Total number of processes.
        port: Port on ``localhost`` used for the rendezvous.
    """
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_gpt(tp_init_spec_func=init_1d_col_spec)
    if world_size == 4:
        run_gpt(tp_init_spec_func=init_1d_row_spec)
134
135
136
137
138
139
140
141
142
143
144
145


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    """Spawn ``world_size`` worker processes, each executing ``run_dist``."""
    rendezvous_port = free_port()
    worker = partial(run_dist, world_size=world_size, port=rendezvous_port)
    # mp.spawn passes the process index as the first positional argument (rank).
    mp.spawn(worker, nprocs=world_size)


# When executed directly (outside pytest), run only the 4-process configuration.
if __name__ == '__main__':
    test_gpt(4)