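"""Distributed test of gradient clipping with Gemini ZeRO.

A GeminiDDP model optimised by GeminiOptimizer (clipping_norm=1.0) is trained
alongside a torch DDP + apex AMP baseline that clips its fp32 master gradients
with torch.nn.utils.clip_grad_norm_, and the two are compared step by step.
"""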
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration

from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs

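# Placement strategies under test: three static configs emulating ZeRO-2 with
# 0%, 100%, and 50% of the optimizer states offloaded to CPU, plus Gemini's
# automatic placement policy.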
PLACEMENT_CONFIGS = [
    {
        "placement_policy": "static",
        "shard_param_frac": 0.0,
        "offload_optim_frac": 0.0,
        "offload_param_frac": 0.0,
    },  # zero2
    {
        "placement_policy": "static",
        "shard_param_frac": 0.0,
        "offload_optim_frac": 1.0,
        "offload_param_frac": 0.0,
    },  # zero2-offload
    {
        "placement_policy": "static",
        "shard_param_frac": 0.0,
        "offload_optim_frac": 0.5,
        "offload_param_frac": 0.0,
    },  # zero2-offload-half
    {"placement_policy": "auto"},
]


def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
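    """Check that every parameter of the Gemini model matches the DDP baseline."""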
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # DDP prefixes every key with "module.", so strip the first 7 characters
        key = key[7:]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)


@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"])
@parameterize("master_weights", [True, False])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
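    """Train both models for a few steps and verify that clipped updates agree."""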
    set_seed(1912)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    torch_model = model_builder().cuda()
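    # apex O2 baseline; loss_scale matches GeminiOptimizer's initial_scale below
    # so both pipelines see identical gradient scaling.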
    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

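    # Build the Gemini-side copy and initialise it with the baseline's weights.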
    model = model_builder()

    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    world_size = dist.get_world_size()
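    # Search a chunk configuration for the model, then pin the chunk size so
    # every placement config shards parameters the same way.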
    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
    config_dict[world_size]["chunk_size"] = 5000
    config_dict[world_size]["keep_gathered"] = False
    if placement_config["placement_policy"] != "cuda":
        init_device = torch.device("cpu")
    else:
        init_device = None

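    # Wrap the model with GeminiDDP using the searched chunk config and the
    # placement strategy under test.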
    model = GeminiDDP(
        model,
        chunk_config_dict=config_dict,
        chunk_init_device=init_device,
        pin_memory=True,
        master_weights=master_weights,
        **placement_config,
    )

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
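    # clipping_norm=1.0 enables GeminiOptimizer's built-in gradient clipping;
    # the baseline mirrors it with an explicit clip_grad_norm_ call below.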
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)

    model.train()
    torch_model.train()

    # re-seed with a rank-dependent seed so the ranks' RNG streams differ
    set_seed(dist.get_rank() * 3 + 128)
    for i, (data, label) in enumerate(train_dataloader):
        if i > 2:
            break
        data = data.cuda()
        label = label.cuda()

        zero_optim.zero_grad()
        torch_optim.zero_grad()

        torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, data, label, criterion, zero_optim)

        # without fp32 master weights, rounding error accumulates, so the loss is not checked
        if master_weights:
            assert_close(torch_loss, loss)

        # apex.amp is needed to reach the baseline's fp32 master parameters
        import apex.amp as apex_amp

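        # clip the baseline's fp32 master gradients to the same max norm (1.0)
        # that zero_optim applies internally, so both sides take comparable steps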
        torch.nn.utils.clip_grad_norm_(apex_amp.master_params(torch_optim), 1.0)
        torch_optim.step()
        zero_optim.step()

        # parameters only match closely when fp32 master weights are kept
        if master_weights:
            check_param(model, torch_model)


def run_dist(rank, world_size, port):
    config = {}
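    # initialise the distributed runtime for this rank before running the test body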
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_grad_clipping()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@rerun_if_address_is_in_use()
def test_grad_clip(world_size):
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_grad_clip(2)