from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer

from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
from tests.test_tensor.model.test_gpt2 import init_megatron_spec


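# Compare every fp16 parameter shard managed by ZeroDDP against the matching
# shard of the fp32 reference model (cast to the shard's dtype and device).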
def check_param_equal(model, torch_model, pg: ProcessGroup):
    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        # parameters whose chunk is stored on another rank have an empty local storage
        if p.storage().size() > 0:
            assert p.dtype == torch.float16
            assert tensor_shard_equal(tp.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
                                      pg.tp_world_size()), f'{tp} vs {p}\n{n}:\n\t{tp.shape} vs {p.shape}'


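# Check that each locally available gradient matches the corresponding shard of
# the DDP reference model's gradient.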
def check_grad_equal(model, torch_model, pg: ProcessGroup):
    for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
        if p.grad is not None:
            assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
                                      pg.tp_local_rank(), pg.tp_world_size()), \
                f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'


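# One forward/backward pass; the loss reuses input_ids as labels, and backward
# goes through optimizer.backward() so each wrapper applies its own loss scaling.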
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    optimizer.zero_grad()
    logits = model(input_ids, attn_mask)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


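# Row-wise 1D tensor parallelism: shard every non-LayerNorm weight along dim 0.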
def init_1d_row_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'weight' in n and 'ln' not in n:
            p.set_tensor_spec(*spec)


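# Column-wise 1D tensor parallelism: shard non-LayerNorm weights and biases
# along the last dim.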
def init_1d_col_spec(model, pg: ProcessGroup):
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for n, p in model.named_parameters():
        p.set_process_group(pg)
        if 'ln' not in n and ('weight' in n or 'bias' in n):
            p.set_tensor_spec(*spec)


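# End-to-end check: train a GPT-2 model wrapped in ZeroDDP/Gemini for a few steps
# and compare its logits, gradients and parameters against an apex AMP + DDP reference.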
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # build the test model under ColoInitContext so its parameters are ColoTensors,
    # plus an identical vanilla PyTorch model as the reference
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    # start both models from exactly the same weights
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()

    # with 4 ranks, construct hybrid parallelism (dp = 2, tp = 2);
    # otherwise put all ranks into the tensor-parallel group
    if world_size == 4:
        pg = ProcessGroup(tp_degree=2)
    else:
        pg = ProcessGroup(tp_degree=world_size)

    if tp_init_spec_func:
        tp_init_spec_func(model, pg)

    # wrap the model with Gemini/ZeroDDP; initial_scale=1 matches the apex
    # loss_scale below so gradients stay directly comparable
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=1)

    # set up the reference: apex AMP O2 with a static loss scale of 1, plus DDP
    # over the data-parallel group
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    print(chunk_manager)
    check_param_equal(model, torch_model, pg)

    # eval mode disables dropout so both models produce deterministic outputs
    model.eval()
    torch_model.eval()

    # seed per data-parallel rank so each dp rank draws its own batches
    set_seed(pg.dp_local_rank())
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        # compare a few steps: logits, then gradients, then updated parameters
        if i > 2:
            break
        input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
        logits = run_fwd_bwd(model, criterion, optim, input_ids_colo, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert tensor_equal(logits, torch_logits)
        check_grad_equal(model, torch_model, pg)
        optim.step()
        torch_optim.step()
        check_param_equal(model, torch_model, pg)


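# Per-process entry point: initialize the distributed runtime, then pick the
# tensor-parallel init spec to exercise based on the world size.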
def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if world_size == 4:
        run_gpt(tp_init_spec_func=init_megatron_spec)
    else:
        run_gpt(tp_init_spec_func=init_1d_col_spec)
        run_gpt(tp_init_spec_func=init_1d_row_spec)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(4)