import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs

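# Each entry exercises one Gemini placement strategy. Under the 'static' policy,
# shard_param_frac sets the fraction of each parameter chunk that stays sharded
# across ranks (0.0 behaves like ZeRO-2, 1.0 like ZeRO-3), while
# offload_optim_frac and offload_param_frac set the fraction of optimizer state
# and parameter data kept on the CPU. The 'auto' policy lets Gemini decide
# placement dynamically at runtime.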
PLACEMENT_CONFIGS = [
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.0
    },    # zero2
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 1.0
    },    # zero2-offload
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.0,
        'offload_optim_frac': 0.5
    },    # zero2-offload-half
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0
    },    # zero3
    {
        'placement_policy': 'static',
        'shard_param_frac': 0.5
    },    # zero3-half
    {
        'placement_policy': 'static',
        'shard_param_frac': 1.0,
        'offload_optim_frac': 1.0,
        'offload_param_frac': 1.0
    },    # zero3-offload-all
    {
        'placement_policy': 'auto'
    }
]

# gpt2 is large enough to be split across multiple chunks
TEST_MODELS = ['gpt2']
# these models are so small that all of their parameters are packed into a single chunk
EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers']

# bfloat16 cannot represent the values of these parameters exactly,
# so they are skipped in the bf16 comparison (see check_param below)
BF16_IGNORED_KEYS = [
    'albert.embeddings.word_embeddings.weight',
    'albert.embeddings.position_embeddings.weight',
    'masked_bias',
]


def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
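    """Compare every parameter of the Gemini model against the DDP baseline.

    The Gemini state dict is collected on every rank (only_rank_0=False) and
    cast to `dtype`, so parameters are compared at the training precision.
    """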
    zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # DDP prefixes every key with 'module.', so strip the first 7 characters
        key = key[7:]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device)
        if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
            continue
        rtol, atol = 1e-3, 4e-3
        if dtype is torch.bfloat16:
            rtol, atol = 4e-3, 8e-3
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert_close(value.float(),
                     temp_zero_value.float(),
                     rtol=rtol,
                     atol=atol,
                     msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')


@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', TEST_MODELS)
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
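    """Train GPT-2 for a few steps under Gemini ZeRO and check that the loss and
    the updated parameters match an apex AMP (O2) + DDP baseline."""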
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

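    # mixed-precision baseline: apex AMP at O2 with a fixed loss scale of 128,
    # wrapped in vanilla PyTorch DDP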
    torch_model = model_builder().cuda()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    model = model_builder().cuda()

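    # start the Gemini model from exactly the same weights as the baseline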
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    world_size = dist.get_world_size()
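    # search for a chunk configuration, then force a small chunk size so the
    # GPT-2 parameters are split across several scattered chunks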
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False
    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
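    # GeminiOptimizer handles loss scaling itself; initial_scale matches the
    # apex loss_scale above so the two runs scale their losses identically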
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)

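    # eval mode keeps dropout disabled so both models see identical activations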
    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    rtol, atol = 1e-4, 1e-5
    for i, (input_ids, label) in enumerate(train_dataloader):
        if i > 2:
            break
        input_ids, label = input_ids.cuda(), label.cuda()
        zero_optim.zero_grad()
        torch_optim.zero_grad()

        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss, rtol=rtol, atol=atol)

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model, mixed_precision)


@parameterize('placement_config', PLACEMENT_CONFIGS)
@parameterize('model_name', EXAMPLE_MODELS)
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
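    """Run the same Gemini-vs-baseline check on models that are small enough for
    all of their parameters to land in a single chunk."""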
    set_seed(2008)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    torch_model = model_builder().cuda()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=2)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    model = model_builder().cuda()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

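    # no explicit chunk config here: GeminiDDP searches for one itself, and
    # these models are small enough to be packed into a single chunk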
    model = GeminiDDP(model,
                      chunk_init_device=get_current_device(),
                      search_range_m=1,
                      pin_memory=True,
                      mixed_precision=mixed_precision,
                      **placement_config)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2)

    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
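    # fp16 tracks the baseline tightly; bfloat16 has fewer mantissa bits, so
    # its loss tolerances are looser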
    rtol, atol = 1.5e-6, 2e-5
    if mixed_precision is torch.bfloat16:
        rtol, atol = 2e-3, 2e-3
    for i, (input_ids, label) in enumerate(train_dataloader):
        if i > 2:
            break

        input_ids = input_ids.cuda()
        label = label.cuda()

        zero_optim.zero_grad()
        torch_optim.zero_grad()

        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss, rtol=rtol, atol=atol)    # atol should be 2e-5 for torch versions below 1.12

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model, mixed_precision)


def run_dist(rank, world_size, port):
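    """Per-process entry point: initialize the NCCL process group through
    colossalai.launch, then run both parameterized test bodies."""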
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_model_step()
    exam_tiny_example()


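# cover both the single-GPU degenerate case and a truly sharded 4-GPU run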
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_optim(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_optim(1)