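# Tests ZeroOptimizer with ZeroDDP (Gemini) on a GPT-2 model, comparing parameters,
# gradients and logits against an apex amp (O2) + DDP baseline, optionally with 1D tensor parallelism.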
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager, TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ZeroDDP
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager


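# Compare the fp16 parameters of the ZeroDDP model with the fp32 torch baseline.
# Parameters without local storage (their data is kept in chunks, possibly on other ranks) are skipped;
# tensor_shard_equal also accepts tensor-parallel shards.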
def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        if p.storage().size() > 0:
            assert p.dtype == torch.half
            assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'


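# Compare gradients; parameters whose grad is not materialized locally (p.grad is None) are skipped.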
def check_grad_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        if p.grad is not None:
            assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)


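# A single forward/backward step. Both ZeroOptimizer and the apex amp wrapper expose
# optimizer.backward(loss), which applies loss scaling before running backward.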
def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    optimizer.zero_grad()
    logits = model(input_ids, attn_mask)
    logits = logits.float()
    loss = criterion(logits, input_ids)
    optimizer.backward(loss)
    return logits


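# Shard every non-LayerNorm weight along dim 0 across the 1D tensor-parallel group.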
def init_1d_row_spec(model):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'ln' not in n:
                p.set_spec(spec)


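# Shard every non-LayerNorm weight and bias along the last dim across the 1D tensor-parallel group.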
def init_1d_col_spec(model):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'ln' not in n and ('weight' in n or 'bias' in n):
                p.set_spec(spec)


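# Train GPT-2 with ZeroDDP + ZeroOptimizer and verify that parameters, gradients and logits
# match an apex amp (O2) + DDP baseline after every step.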
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

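    # Build the model with ColoTensor parameters, cast it to fp16, and keep an fp32 torch copy
    # initialized with the same weights as the baseline.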
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda().half()
    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    if tp_init_spec_func:
        tp_init_spec_func(model)

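    # Chunk-based memory management: optionally search a chunk size, distribute chunk storage
    # across processes when use_zero is on, and let GeminiManager place tensors on cuda or cpu.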
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=32)

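    # Baseline: the fp32 model wrapped with apex amp (O2) and PyTorch DDP, using the same static loss scale.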
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))

    print(chunk_manager)
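    # Check the initial parameters, then run a few iterations, comparing logits, gradients
    # and parameters with the baseline after every step.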
    check_param_equal(model, torch_model)
    model.train()
    torch_model.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 2:
            break
        logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert tensor_equal(logits, torch_logits)
        check_grad_equal(model, torch_model)
        optim.step()
        torch_optim.step()
        check_param_equal(model, torch_model)


def run_dist(rank, world_size, port):
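    # With 4 processes, enable 1D tensor parallelism of size 2 and test both row and column sharding specs.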
    config = {}
    if world_size == 4:
        config['parallel'] = {'tensor': {'mode': '1d', 'size': 2}}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if world_size == 4:
        run_gpt(tp_init_spec_func=init_1d_col_spec)
        run_gpt(tp_init_spec_func=init_1d_row_spec)
    else:
        run_gpt()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_gpt(4)