test_grad_accum.py 5.65 KB
Newer Older
1
2
3
4
5
6
7
8
import pytest
import torch
import torch.distributed as dist
from apex import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
9
from colossalai.accelerator import get_accelerator
10
from colossalai.nn.optimizer import HybridAdam
11
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
12
13
14
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
15
from tests.kit.model_zoo import model_zoo, run_fwd
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

# Gemini placement policies exercised by the test matrix: three "static"
# variants differing in what fraction of parameters is sharded across ranks
# (0.0 ≈ ZeRO-2, 1.0 ≈ ZeRO-3, 0.5 in between), plus the "auto" policy.
PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
    {"placement_policy": "auto"},
]


def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
    """Compare the Gemini model's accumulated gradients with the reference model's.

    Makes each parameter's grad chunk accessible, compares against the torch
    model's ``.grad`` tensors, then releases the chunks and moves them back to
    their recorded gradient devices so the model's state is left as found.
    """
    chunk_manager = model.chunk_manager
    grad_chunk_list = []
    device_list = []

    # Access gradient chunks. Each chunk covers several parameters, so
    # deduplicate before accessing; remember the original gradient device
    # of each chunk so it can be restored afterwards.
    for p in model.parameters():
        grad_chunk = chunk_manager.get_chunk(p).grad_chunk
        if grad_chunk not in grad_chunk_list:
            chunk_manager.access_chunk(grad_chunk)
            grad_chunk_list.append(grad_chunk)
            device_list.append(model.grads_device[p])

    # Compare gradients. NOTE(review): `p0` itself (not `p0.grad`) is compared
    # with `p1.grad` — presumably, after accessing the grad chunk, the Gemini
    # parameter tensor aliases the accumulated gradient data; confirm against
    # GeminiDDP's grad-chunk semantics.
    for p0, p1 in zip(model.parameters(), torch_model.parameters()):
        assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2)

    # Release gradient chunks and move them to gradient device.
    for grad_chunk, device in zip(grad_chunk_list, device_list):
        chunk_manager.release_chunk(grad_chunk)
        chunk_manager.move_chunk(grad_chunk, device, force_copy=True)


@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
@parameterize("use_grad_checkpoint", [False, True])
def exam_gemini_grad_acc(
    placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
):
    """Check Gemini gradient accumulation against a torch DDP + apex AMP baseline.

    Builds two identical copies of the model, trains both with loss scaled by
    ``accum_iter``, and after every step asserts matching losses and gradients;
    after each full accumulation window, steps both optimizers and asserts
    matching parameters.

    Args:
        placement_config: Gemini placement kwargs (one entry of PLACEMENT_CONFIGS).
        keep_gathered: whether chunks stay gathered instead of being scattered.
        model_name: key into the model zoo sub-registry.
        master_weights: whether Gemini keeps fp32 master weights.
        use_grad_checkpoint: whether to enable activation checkpointing on both models.
    """
    init_device = get_accelerator().get_current_device()
    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
        iter(model_zoo.get_sub_registry(model_name).values())
    )

    # Build both models from the same seed, then hard-copy parameters so the
    # two start bitwise identical.
    set_seed(42)
    gemini_model = model_builder()

    set_seed(42)
    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
        torch_p.data.copy_(p.data)

    if use_grad_checkpoint:
        gemini_model.gradient_checkpointing_enable()
        torch_model.gradient_checkpointing_enable()

    # Wrap the Gemini copy: small fixed chunk size keeps the search deterministic;
    # enable_gradient_accumulation is the feature under test.
    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
    config_dict[world_size]["chunk_size"] = 5000
    config_dict[world_size]["keep_gathered"] = keep_gathered
    gemini_model = GeminiDDP(
        gemini_model,
        config_dict,
        init_device,
        pin_memory=True,
        enable_gradient_accumulation=True,
        master_weights=master_weights,
        **placement_config,
    )
    optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)

    rank = dist.get_rank()

    # setting master_weights to False will cause overflow after optimizer.step()
    amp_config = dict(
        opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True
    )
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config)
    torch_model = DDP(torch_model, device_ids=[rank])

    set_seed(rank)
    accum_iter = 4
    train_dataloader = DummyDataloader(data_gen_fn)
    for i, data in enumerate(train_dataloader):
        # apex must only unscale grads on the step that ends an accumulation
        # window. (Idiom fix: direct boolean instead of `False if ... else True`.)
        delay_unscale = (i + 1) % accum_iter != 0
        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}

        # Re-seed before each forward so any randomness (e.g. dropout) matches
        # between the two models.
        set_seed(42 + rank)
        torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn)
        torch_loss = torch_loss / accum_iter
        with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()

        set_seed(42 + rank)
        gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn)
        gemini_loss = gemini_loss / accum_iter
        gemini_optim.backward(gemini_loss)

        assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5)

        check_grad(gemini_model, torch_model)

        if (i + 1) % accum_iter == 0:
            # End of an accumulation window: clip (Gemini clips via max_norm
            # in GeminiOptimizer), step, and compare updated parameters.
            torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
            torch_optim.step()
            gemini_optim.step()
            torch_optim.zero_grad()

            # check updated param
            torch_dict = torch_model.state_dict()
            gemini_dict = gemini_model.state_dict(only_rank_0=False)

            for key, value in gemini_dict.items():
                torch_key = "module." + key
                torch_value = torch_dict[torch_key].to(value.device).to(value.dtype)
                assert_close(value, torch_value, rtol=1e-3, atol=2e-3)

        # One full window plus one extra iteration is enough coverage.
        if i == accum_iter:
            break


def run_dist(rank, world_size, port):
    """Per-process entry point: join the NCCL group, then run the test body."""
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    exam_gemini_grad_acc()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_grad_accumulation():
    """Spawn a 2-process distributed group and run the Gemini grad-accumulation checks."""
    spawn(run_dist, 2)


if __name__ == "__main__":
    # Allow running this test directly, without going through pytest.
    test_grad_accumulation()