# tests/test_gemini/update/test_optim.py
from functools import partial
from time import time

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed


def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
    """Assert that every reference parameter matches its counterpart in the ZeRO model.

    ``model`` is queried for its full state dict (``only_rank_0=False``). Each key of
    ``torch_model``'s state dict has its first seven characters (the ``'module.'``
    prefix) stripped before being looked up in that dict. The key
    ``'model.lm_head.weight'`` is not compared.

    Every ZeRO tensor is cast to the device and dtype of its reference tensor, then
    compared with ``assert_close(rtol=1e-3, atol=1e-2)``.
    """
    prefix_len = len('module.')
    gathered = model.state_dict(only_rank_0=False)
    reference = torch_model.state_dict()

    for full_name, ref_tensor in reference.items():
        name = full_name[prefix_len:]
        if name == 'model.lm_head.weight':
            # NOTE(review): presumably skipped because this weight is tied to another parameter — confirm
            continue
        assert name in gathered, "{} not in ZeRO dictionary.".format(name)
        candidate = gathered[name].to(device=ref_tensor.device, dtype=ref_tensor.dtype)
        assert_close(ref_tensor, candidate, rtol=1e-3, atol=1e-2)


# Component names (looked up in ``non_distributed_component_funcs``) run by ``exam_model_step``.
TEST_MODELS = ['gpt2', 'bert']
# Component names run by ``exam_tiny_example``.
EXAMPLE_MODELS = ['simple_net']


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', TEST_MODELS)
def exam_model_step(placement_policy, model_name: str):
    """Compare ZeRO/Gemini training against an apex-AMP + DDP reference, step by step.

    For each of the first three training batches, the two models must report
    matching losses. After each optimizer step, their parameters must agree
    (checked with ``check_param``).

    Args:
        placement_policy: Gemini placement policy handed to ``GeminiManager``.
        model_name: component name looked up in ``non_distributed_component_funcs``.
    """
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    # The test dataloader and optimizer class provided by the component are not used here.
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    # Reference model: apex AMP O2 wrapped in DDP. loss_scale=128 matches the
    # ZeroOptimizer initial_scale below.
    torch_model = model_builder().cuda()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    # Model under test: built under ColoInitContext, then initialised from the reference weights.
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    # Search for a chunk configuration, then force this world size's entry to a
    # fixed chunk size with keep_gathered disabled.
    world_size = dist.get_world_size()
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    # Every policy except 'cuda' initialises chunks on the CPU.
    init_device = None if placement_policy == 'cuda' else torch.device('cpu')
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)

    # NOTE(review): eval mode presumably disables dropout so both runs stay comparable — confirm
    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    for i, (input_ids, label) in enumerate(train_dataloader):
        # Only the first three batches are compared.
        if i > 2:
            break
        input_ids, label = input_ids.cuda(), label.cuda()

        zero_optim.zero_grad()
        torch_optim.zero_grad()

        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss)

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model)


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', EXAMPLE_MODELS)
def exam_tiny_example(placement_policy, model_name: str):
    """Compare ZeRO/Gemini training of a small example model against an apex-AMP + DDP reference.

    Unlike ``exam_model_step``, the chunk manager is built with ``init_chunk_manager``
    on the current device, and the loss scale is 2. For each of the first three
    training batches, losses must match. After each optimizer step, parameters must
    agree (checked with ``check_param``).

    Args:
        placement_policy: Gemini placement policy handed to ``GeminiManager``.
        model_name: component name looked up in ``non_distributed_component_funcs``.
    """
    set_seed(2008)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    # The test dataloader and optimizer class provided by the component are not used here.
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    # Reference model: apex AMP O2 wrapped in DDP. loss_scale=2 matches the
    # ZeroOptimizer initial_scale below.
    torch_model = model_builder().cuda()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    # Model under test: built under ColoInitContext, then initialised from the reference weights.
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)

    # NOTE(review): eval mode presumably disables dropout so both runs stay comparable — confirm
    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    for i, (input_ids, label) in enumerate(train_dataloader):
        # Only the first three batches are compared.
        if i > 2:
            break
        input_ids, label = input_ids.cuda(), label.cuda()

        zero_optim.zero_grad()
        torch_optim.zero_grad()

        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss)

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model)


def run_dist(rank, world_size, port):
    """Per-process worker: launch ColossalAI over NCCL on localhost, then run each check in turn."""
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    for check in (exam_model_step, exam_tiny_example):
        check()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_optim(world_size):
    """Spawn ``world_size`` processes running ``run_dist``, all sharing one port from ``free_port()``.

    ``mp.spawn`` passes each process index as the first positional argument, which
    becomes ``run_dist``'s ``rank``.
    """
    port = free_port()
    mp.spawn(partial(run_dist, world_size=world_size, port=port), nprocs=world_size)


if __name__ == '__main__':
    # When executed as a script, run only the single-process configuration.
    test_optim(1)