"vscode:/vscode.git/clone" did not exist on "de6e079f15b0d1988033fc8203d172f493d08506"
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs

PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 1.0},  # zero2-offload
    {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.5},  # zero2-offload-half
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
    {
        "placement_policy": "static",
        "shard_param_frac": 1.0,
        "offload_optim_frac": 1.0,
        "offload_param_frac": 1.0,
    },  # zero3-offload-all
    {"placement_policy": "auto"},
]

# this model is large enough to be sliced into chunks
TEST_MODELS = ["gpt2"]
# these models are too small; all of their parameters are compacted into one chunk
EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"]

# bfloat16 cannot represent these parameters exactly
BF16_IGNORED_KEYS = [
    "albert.embeddings.word_embeddings.weight",
    "albert.embeddings.position_embeddings.weight",
    "masked_bias",
]


def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
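    """Assert that every parameter of the Gemini model matches its counterpart in the torch reference model."""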
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # key is 'module.model.PARAMETER', so we truncate it
        key = key[7:]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device)
        if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
            continue
        rtol, atol = 1e-3, 4e-3
        if dtype is torch.bfloat16:
            rtol, atol = 4e-3, 8e-3
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert_close(
            value.float(),
            temp_zero_value.float(),
            rtol=rtol,
            atol=atol,
            msg=lambda s: s + f"\n{key}\n{temp_zero_value.dtype}",
        )


@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", TEST_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
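    """Run a few optimizer steps on a chunked GPT-2 model and check that GeminiDDP with
    GeminiOptimizer matches torch DDP with apex AMP on both loss and updated parameters."""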
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    torch_model = model_builder().cuda()
    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    model = model_builder().cuda()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
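    # override the searched config with a small chunk size so the parameters are spread across several chunks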
    config_dict[world_size]["chunk_size"] = 5000
    config_dict[world_size]["keep_gathered"] = False
    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)

    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    rtol, atol = 1e-4, 1e-5
    for i, (input_ids, label) in enumerate(train_dataloader):
        if i > 2:
            break
        input_ids, label = input_ids.cuda(), label.cuda()
        zero_optim.zero_grad()
        torch_optim.zero_grad()

        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss, rtol=rtol, atol=atol)

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model, mixed_precision)


@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", EXAMPLE_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
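    """Repeat the loss/parameter comparison of exam_model_step on tiny models that fit into a single chunk."""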
    set_seed(2008)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    torch_model = model_builder().cuda()
    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    model = model_builder().cuda()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

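    # no hand-tuned chunk config here; these models' parameters all fit into one chunk (see EXAMPLE_MODELS above)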
    model = GeminiDDP(
        model,
        chunk_init_device=get_current_device(),
        search_range_m=1,
        pin_memory=True,
        mixed_precision=mixed_precision,
        **placement_config,
    )
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2)

    model.eval()
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    rtol, atol = 1.5e-6, 2e-5
    if mixed_precision is torch.bfloat16:
        rtol, atol = 2e-3, 2e-3
    for i, (input_ids, label) in enumerate(train_dataloader):
        if i > 2:
            break

        input_ids = input_ids.cuda()
        label = label.cuda()

        zero_optim.zero_grad()
        torch_optim.zero_grad()

        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss, rtol=rtol, atol=atol)  # atol should be 2e-5 for torch lower than 1.12

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model, mixed_precision)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_model_step()
    exam_tiny_example()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_if_address_is_in_use()
def test_optim(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_optim(1)