import pytest
import torch
import torch.distributed as dist
from apex import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd

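# Gemini placement policies under test; shard_param_frac controls what fraction
# of each parameter is sharded across ranks (0.0 behaves like ZeRO-2, 1.0 like ZeRO-3).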
PLACEMENT_CONFIGS = [
    {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
    {"placement_policy": "static", "shard_param_frac": 1.0},  # zero3
    {"placement_policy": "static", "shard_param_frac": 0.5},  # zero3-half
    {"placement_policy": "auto"},
]


def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
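    """Compare the Gemini model's chunked gradients against the torch reference model."""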
    chunk_manager = model.chunk_manager
    grad_chunk_list = []
    device_list = []

    # Access gradient chunks.
    for p in model.parameters():
        grad_chunk = chunk_manager.get_chunk(p).grad_chunk
        if grad_chunk not in grad_chunk_list:
            chunk_manager.access_chunk(grad_chunk)
            grad_chunk_list.append(grad_chunk)
            device_list.append(model.grads_device[p])

    # Compare gradients.
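    # With the grad chunks accessed above, each Gemini parameter tensor exposes
    # its gradient payload, so p0 itself is compared against p1.grad.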
    for p0, p1 in zip(model.parameters(), torch_model.parameters()):
        assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2)

    # Release gradient chunks and move them to gradient device.
    for grad_chunk, device in zip(grad_chunk_list, device_list):
        chunk_manager.release_chunk(grad_chunk)
        chunk_manager.move_chunk(grad_chunk, device, force_copy=True)


@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
    init_device = get_current_device()
    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
        iter(model_zoo.get_sub_registry(model_name).values())
    )

    set_seed(42)
    gemini_model = model_builder()

    set_seed(42)
    torch_model = model_builder().cuda()

    for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
        torch_p.data.copy_(p.data)

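    # Search for a chunk configuration, then pin a small fixed chunk size so
    # several parameters share one chunk and chunk-level accumulation is exercised.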
    world_size = dist.get_world_size()
    config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
    config_dict[world_size]["chunk_size"] = 5000
    config_dict[world_size]["keep_gathered"] = keep_gathered
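    # enable_gradient_accumulation keeps gradients in dedicated grad chunks so
    # successive backward passes accumulate into them instead of overwriting.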
    gemini_model = GeminiDDP(
        gemini_model,
        config_dict,
        init_device,
        pin_memory=True,
        enable_gradient_accumulation=True,
        master_weights=master_weights,
        **placement_config,
    )
    optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
    gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
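    # initial_scale=1 effectively disables Gemini's loss scaling, mirroring the
    # fixed loss_scale=1 used for the apex baseline below.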

    rank = dist.get_rank()

    # setting master_weights to False will cause overflow after optimizer.step()
    amp_config = dict(
        opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True
    )
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config)
    torch_model = DDP(torch_model, device_ids=[rank])

    set_seed(rank)
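    # Divide each micro-batch loss by accum_iter so the gradients accumulated
    # over one window match those of a single equivalent large batch.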
    accum_iter = 4
    train_dataloader = DummyDataloader(data_gen_fn)
    for i, data in enumerate(train_dataloader):
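        # apex leaves gradients scaled while delay_unscale=True; unscale only on
        # the last micro-batch of each accumulation window, just before the step.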
        delay_unscale = (i + 1) % accum_iter != 0
        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}

        set_seed(42 + rank)
        torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn)
        torch_loss = torch_loss / accum_iter
        with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
            scaled_loss.backward()

        set_seed(42 + rank)
        gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn)
        gemini_loss = gemini_loss / accum_iter
        gemini_optim.backward(gemini_loss)

        assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5)

        check_grad(gemini_model, torch_model)

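        # Step both optimizers once a full accumulation window has completed.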
        if (i + 1) % accum_iter == 0:
            torch_optim.step()
            gemini_optim.step()
            torch_optim.zero_grad()

            # Check updated parameters; state_dict(only_rank_0=False) gathers the
            # full state dict on every rank so each rank can run the comparison.
            torch_dict = torch_model.state_dict()
            gemini_dict = gemini_model.state_dict(only_rank_0=False)

            for key, value in gemini_dict.items():
                torch_key = "module." + key
                torch_value = torch_dict[torch_key].to(value.device).to(value.dtype)
                assert_close(value, torch_value, rtol=1e-3, atol=2e-3)

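        # Run one extra accumulation micro-batch after the optimizer step, then stop.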
        if i == accum_iter:
            break


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    exam_gemini_grad_acc()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_grad_accumulation():
    spawn(run_dist, 2)


if __name__ == "__main__":
    test_grad_accumulation()