test_inference.py 4.75 KB
Newer Older
1
from typing import Callable
2
3
4
5
6
7
8
9

import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
10
from colossalai.legacy.amp import convert_to_apex_amp
11
from colossalai.nn.optimizer import HybridAdam
12
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
13
from colossalai.utils import set_seed
14
from colossalai.utils.device import get_current_device
15
16
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
17
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
18
19

# Gemini placement settings that exam_inference is parameterized over; each
# entry is forwarded as keyword arguments to GeminiDDP by the *_chunk_init
# helpers. Per the labels, a "static" policy with shard_param_frac 0.0 / 1.0 /
# 0.5 corresponds to ZeRO-2 / ZeRO-3 / half-sharded ZeRO-3; "auto" presumably
# lets Gemini choose placement itself.
PLACEMENT_CONFIGS = [
    dict(placement_policy="static", shard_param_frac=0.0),  # zero2
    dict(placement_policy="static", shard_param_frac=1.0),  # zero3
    dict(placement_policy="static", shard_param_frac=0.5),  # zero3-half
    dict(placement_policy="auto"),
]


def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
    """Assert that every entry of the reference state dict matches Gemini's copy.

    Args:
        model: the Gemini-wrapped model under test. Its state dict is requested
            with ``only_rank_0=False``, presumably so every rank receives the
            full dict and can run the comparison.
        torch_model: the reference model. In this file it is wrapped in
            ``DistributedDataParallel``, so its state-dict keys carry a leading
            ``"module."`` that is stripped before looking them up in Gemini's dict.

    Raises:
        AssertionError: if a reference key is missing from Gemini's state dict,
            or a value differs beyond ``rtol=1e-3`` / ``atol=4e-3``.
    """
    ddp_prefix = "module."
    zero_dict = model.state_dict(only_rank_0=False)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # Strip the DDP wrapper prefix only when it is actually present; slicing
        # a fixed 7 characters off every key would silently mangle any key that
        # lacks it and produce a confusing "not in ZeRO dictionary" failure.
        if key.startswith(ddp_prefix):
            key = key[len(ddp_prefix):]
        assert key in zero_dict, f"{key} not in ZeRO dictionary."
        # Bring Gemini's tensor onto the reference tensor's device and dtype so
        # the comparison only measures numerical differences.
        zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        assert_close(value, zero_value, rtol=1e-3, atol=4e-3)


def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
    """Wrap ``model`` in ``GeminiDDP`` using a searched, then overridden, chunk config.

    The configuration returned by ``search_chunk_configuration`` is patched for
    the current world size with a fixed ``chunk_size`` of 5000 and
    ``keep_gathered=False`` (presumably so the model is split across several
    small chunks, hence the name).

    Args:
        model: the freshly built model to wrap.
        placement_config: keyword arguments forwarded to ``GeminiDDP``
            (one of ``PLACEMENT_CONFIGS``).

    Returns:
        The ``GeminiDDP``-wrapped model.
    """
    n_ranks = dist.get_world_size()
    chunk_configs, *_unused = search_chunk_configuration(model, search_range_m=1, search_interval=100)

    # Override the searched settings for this world size in place.
    rank_config = chunk_configs[n_ranks]
    rank_config["chunk_size"] = 5000
    rank_config["keep_gathered"] = False

    return GeminiDDP(model, chunk_configs, pin_memory=True, **placement_config)


def single_chunk_init(model: torch.nn.Module, placement_config: dict):
    """Wrap ``model`` in ``GeminiDDP`` without an explicit chunk configuration.

    Chunks are initialised on the current device.

    Args:
        model: the freshly built model to wrap.
        placement_config: keyword arguments forwarded to ``GeminiDDP``
            (one of ``PLACEMENT_CONFIGS``).

    Returns:
        The ``GeminiDDP``-wrapped model.
    """
    return GeminiDDP(
        model,
        chunk_init_device=get_current_device(),
        pin_memory=True,
        **placement_config,
    )


@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
    """Check that a Gemini model tracks an apex-AMP + DDP reference through train and inference.

    Both models start from identical weights. The check runs one training step,
    one no-grad forward pass, and a second training step. After each step it
    asserts that the two losses agree; after each training step it also
    compares all parameters via ``check_param``.

    Args:
        placement_config: Gemini placement keyword arguments (one of ``PLACEMENT_CONFIGS``).
        model_name: model-zoo sub-registry whose first entry supplies the model.
        model_init_func: ``single_chunk_init`` or ``multi_chunk_init``.
    """
    set_seed(19360226)
    model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
    init_dev = get_current_device()

    # Reference model: apex AMP (O2) + PyTorch DDP.
    torch_model = model_builder().cuda()
    amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    # DDP must be given this process's own CUDA device. The global rank only
    # coincides with the local device index on a single node.
    torch_model = DDP(torch_model, device_ids=[init_dev])

    # Model under test: a second instance that receives the reference weights
    # before being wrapped by Gemini.
    model = model_builder().to(init_dev)
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        p.data.copy_(torch_p.data)

    model = model_init_func(model, placement_config)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)

    # Both models stay in eval mode even for the training steps, presumably so
    # stochastic layers (e.g. dropout) cannot make the two outputs diverge.
    model.eval()
    torch_model.eval()

    # Rank-dependent seed so each rank draws a different batch.
    set_seed(dist.get_rank() * 3 + 128)
    dataloader = iter(DummyDataloader(data_gen_fn))

    def next_batch():
        """Fetch the next batch and move its tensor values to the current CUDA device."""
        batch = next(dataloader)
        return {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

    def train_iter():
        """Run one training step on both models, comparing losses and then parameters."""
        data = next_batch()
        zero_optim.zero_grad()
        torch_optim.zero_grad()

        torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, optimizer=torch_optim)
        loss = run_fwd_bwd(model, data, output_transform_fn, optimizer=zero_optim)
        assert_close(torch_loss.float(), loss.float(), rtol=1e-5, atol=1e-5)

        zero_optim.step()
        torch_optim.step()
        check_param(model, torch_model)

    def inference_iter():
        """Run one no-grad forward pass on both models and compare the losses."""
        data = next_batch()
        with torch.no_grad():
            torch_loss = run_fwd(torch_model, data, output_transform_fn)
            zero_loss = run_fwd(model, data, output_transform_fn)
        assert_close(torch_loss.float(), zero_loss.float(), rtol=1e-5, atol=1e-5)

    # Training after the inference pass presumably checks that inference does
    # not disturb the state the next training step relies on.
    train_iter()
    inference_iter()
    train_iter()


def run_dist(rank, world_size, port):
    """Per-process entry point: initialise ColossalAI over NCCL, then run the checks.

    Args:
        rank: rank of this process, passed to ``colossalai.launch``.
        world_size: total number of processes.
        port: port used with ``host="localhost"`` by ``colossalai.launch``.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    exam_inference()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 4])
@rerun_if_address_is_in_use()
def test_inference(world_size):
    """Spawn ``world_size`` processes, each executing ``run_dist``.

    ``rerun_if_address_is_in_use`` presumably retries the test when the chosen
    port is already taken; confirm against its implementation.
    """
    spawn(run_dist, world_size)


if __name__ == "__main__":
    # Run the single-process configuration directly, outside pytest.
    test_inference(1)