# test_inference.py
1
from functools import partial
2
from typing import Callable
3
4
5
6
7
8
9
10
11
12
13
14
15
16

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
17
18
19
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper
from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_tensor.common_utils import debug_print, set_seed


def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
    """Assert that every tensor in the reference model's state dict matches the ZeRO model.

    ``model.state_dict`` is called with ``only_rank_0=False`` (presumably so every rank
    receives the full state dict). Keys of ``torch_model``'s state dict are expected to
    start with the 7-character ``'module.'`` prefix added by a DistributedDataParallel
    wrapper; that prefix is stripped before the lookup in the ZeRO state dict. Each ZeRO
    tensor is moved to the reference tensor's device and dtype, then compared with
    ``rtol=1e-3, atol=4e-3``.

    Raises:
        AssertionError: if a key is missing from the ZeRO state dict or a tensor differs
            beyond tolerance.
    """
    zero_state = model.state_dict(only_rank_0=False)
    reference_state = torch_model.state_dict()

    # 'module.model.PARAMETER' -> 'model.PARAMETER'
    prefix_len = len('module.')
    for qualified_name, ref_tensor in reference_state.items():
        name = qualified_name[prefix_len:]
        assert name in zero_state, "{} not in ZeRO dictionary.".format(name)
        zero_tensor = zero_state[name].to(device=ref_tensor.device, dtype=ref_tensor.dtype)
        assert_close(ref_tensor, zero_tensor, rtol=1e-3, atol=4e-3)


def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
    """Wrap ``model`` in ``ZeroDDP`` using an explicitly configured ``ChunkManager``.

    A chunk configuration is searched for ``model``. The entry for the current world size
    is then overridden with ``chunk_size=5000`` and ``keep_gathered=False``. Presumably the
    small chunk size forces parameters to be spread over several chunks; confirm against
    the ChunkManager docs.

    Args:
        model: model built under ``ColoInitContext``.
        placement_policy: Gemini placement policy. Chunks are initialised on CPU for any
            policy other than ``'cuda'``; for ``'cuda'`` no init device is passed.

    Returns:
        The ``ZeroDDP``-wrapped model (with pinned memory enabled).
    """
    world_size = dist.get_world_size()
    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    config_dict[world_size]['chunk_size'] = 5000
    config_dict[world_size]['keep_gathered'] = False

    init_device = None if placement_policy == 'cuda' else torch.device('cpu')
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    return ZeroDDP(model, gemini_manager, pin_memory=True)


def single_chunk_init(model: torch.nn.Module, placement_policy: str):
    """Wrap ``model`` with ZeRO stage 3 through the high-level ``zero_model_wrapper``.

    The Gemini config targets the current device, uses the given placement policy and
    enables pinned memory.

    Args:
        model: model built under ``ColoInitContext``.
        placement_policy: Gemini placement policy name.

    Returns:
        Whatever ``zero_model_wrapper`` returns for ``zero_stage=3``.
    """
    return zero_model_wrapper(
        model=model,
        zero_stage=3,
        gemini_config={
            'device': get_current_device(),
            'placement_policy': placement_policy,
            'pin_memory': True,
        },
    )


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', ['gpt2'])
@parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable):
    """Compare a ZeRO/Gemini-wrapped model against an apex-AMP + DDP baseline.

    Both models are built by the same component builder and start from identical weights.
    The test then runs, in order:

    1. a training step, comparing losses and then all parameters (``check_param``);
    2. a no-grad forward pass, comparing losses;
    3. a second training step, so training is also checked after an inference pass.

    Args:
        placement_policy: Gemini placement policy forwarded to ``model_init_func``.
        model_name: name of a component registered in ``non_distributed_component_funcs``.
        model_init_func: callable ``(model, placement_policy) -> wrapped model``, e.g.
            ``single_chunk_init`` or ``multi_chunk_init``.
    """
    set_seed(19360226)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    # Only the builder, the training dataloader and the criterion are used here.
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    # Reference model: apex AMP (O2) wrapped in PyTorch DDP.
    torch_model = model_builder().cuda()
    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    # NOTE(review): the global rank is used as the CUDA device id, which only holds for
    # single-node runs.
    torch_model = DDP(torch_model, device_ids=[dist.get_rank()])

    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    # Start both models from identical weights. zip() would silently stop at the shorter
    # sequence, so make a parameter-count mismatch fail loudly instead.
    torch_params = list(torch_model.parameters())
    zero_params = list(model.parameters())
    assert len(torch_params) == len(zero_params), \
        "parameter count mismatch: {} (torch) vs {} (ZeRO)".format(len(torch_params), len(zero_params))
    for torch_p, p in zip(torch_params, zero_params):
        p.data.copy_(torch_p.data)

    model = model_init_func(model, placement_policy)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)

    # eval() is set even though training steps follow, presumably to disable stochastic
    # layers (e.g. dropout) so both models stay numerically comparable.
    model.eval()
    torch_model.eval()

    # Rank-dependent seed, presumably so each rank draws different data.
    set_seed(dist.get_rank() * 3 + 128)
    train_dataloader = iter(train_dataloader)

    def train_iter():
        """Run one optimisation step on both models; compare losses and parameters."""
        input_ids, label = next(train_dataloader)
        input_ids, label = input_ids.cuda(), label.cuda()
        zero_optim.zero_grad()
        torch_optim.zero_grad()
        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss)
        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model)

    def inference_iter():
        """Run one no-grad forward pass on both models and compare the losses."""
        input_ids, label = next(train_dataloader)
        input_ids, label = input_ids.cuda(), label.cuda()
        with torch.no_grad():
            torch_output = torch_model(input_ids)
            torch_loss = criterion(torch_output.float(), label)
            zero_output = model(input_ids)
            zero_loss = criterion(zero_output.float(), label)
        assert_close(torch_loss, zero_loss)

    train_iter()
    inference_iter()
    train_iter()


def run_dist(rank, world_size, port):
    """Per-process worker: initialise ColossalAI over NCCL on localhost, then run the checks.

    Args:
        rank: this process's rank (supplied by ``mp.spawn``).
        world_size: total number of processes.
        port: rendezvous port on ``localhost``.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    exam_inference()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_inference(world_size):
    """Spawn ``world_size`` worker processes, each executing ``run_dist``.

    The port is chosen once in the parent process so that all workers share the same
    rendezvous address.
    """
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)


if __name__ == '__main__':
    # Manual single-process run.
    test_inference(1)