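"""Distributed correctness tests for Gemini (ZeroDDP).

``exam_gpt_fwd_bwd`` runs one forward/backward step and checks that the loss
and gradients match a torch DDP + apex AMP baseline; ``exam_gpt_inference``
checks forward-only parity.
"""
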
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test import run_fwd, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import set_seed


def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
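    """Compare Gemini parameters against the torch model's gradients.

    Every chunk is gathered first via ``access_chunk``. After backward,
    Gemini's parameter chunks hold the gradients (chunk reuse), so each
    Gemini parameter ``p0`` is compared directly to ``p1.grad``.
    """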
    chunk_manager = model.chunk_manager
    param_list = list(model.parameters())
    chunk_list = chunk_manager.get_chunks(param_list)
    for chunk in chunk_list:
        chunk_manager.access_chunk(chunk)

    for p0, p1 in zip(model.parameters(), torch_model.parameters()):
        assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('use_grad_checkpoint', [False, True])
def exam_gpt_fwd_bwd(
    placement_policy,
    keep_gather,
    model_name: str,
    use_grad_checkpoint: bool = False,
):
    init_device = get_current_device()
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    set_seed(42)
    with ColoInitContext(device=init_device):
        model = model_builder(use_grad_checkpoint)

    set_seed(42)
    torch_model = model_builder(use_grad_checkpoint).cuda()
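    # Start both models from identical weights.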
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
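    # Search a chunk configuration, then pin a small fixed chunk size so the
    # test behaves the same regardless of the search result.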
    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
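    # Wrap the model with Gemini: ChunkManager owns the chunked parameters,
    # GeminiManager places them according to `placement_policy`.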
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
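    # ZeroOptimizer adapts the optimizer to Gemini's chunked fp16 parameters;
    # initial_scale=1 keeps the loss unscaled so it can be compared exactly
    # with the apex baseline below.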
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=1)

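    # Torch baseline: DDP + apex AMP O2 with a static loss scale of 1 to match
    # the ZeroOptimizer configuration above.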
    pg = ProcessGroup()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

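    # Seed differs per data-parallel rank, so each rank gets its own random stream.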
    set_seed(pg.dp_local_rank())
    for i, (input_ids, label) in enumerate(train_dataloader):
        # Only one iteration can be tested: after backward, Gemini reuses the
        # parameter chunks to store the gradients (chunk reuse optimization),
        # so the parameters are no longer valid for another forward pass.
        if i > 0:
            break
        input_ids, label = input_ids.cuda(), label.cuda()

        torch_optim.zero_grad()
        zero_optim.zero_grad()

        # Resetting the seed before each forward gives both models identical
        # dropout masks (the determinism that eval() would otherwise provide).
        set_seed(42)
        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        set_seed(42)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)

        assert torch.equal(torch_loss, loss)

        check_grad(model, torch_model)


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('scatter_after_inference', [False, True])
def exam_gpt_inference(
    placement_policy,
    keep_gather,
    model_name: str,
    scatter_after_inference: bool = False,
):
    init_device = get_current_device()
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    set_seed(42)
    with ColoInitContext(device=init_device):
        model = model_builder()

    set_seed(42)
    torch_model = model_builder().cuda()
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p.data)

    world_size = torch.distributed.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = keep_gather
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
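    # `scatter_after_inference` controls whether gathered chunks are scattered
    # back to their shards once the forward pass finishes.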
    model = ZeroDDP(model, gemini_manager, pin_memory=True, scatter_after_inference=scatter_after_inference)

    pg = ProcessGroup()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    set_seed(pg.dp_local_rank())
    model.eval()
    torch_model.eval()
    for i, (input_ids, label) in enumerate(train_dataloader):
        # Forward-only check: a single batch suffices for inference parity.
        if i > 0:
            break
        with torch.no_grad():
            input_ids, label = input_ids.cuda(), label.cuda()

            torch_loss = run_fwd(torch_model, input_ids, label, criterion)
            loss = run_fwd(model, input_ids, label, criterion)

        assert torch.equal(torch_loss, loss)


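# Per-process entry point: initialize the distributed backend, then run both
# parameterized exams on this rank.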
def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_gpt_fwd_bwd()
    exam_gpt_inference()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_gpt(4)