from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import colossalai
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_spec(model, pg: ProcessGroup):
    """Apply a 1D row (dim-0) tensor-parallel shard spec to every non-LayerNorm weight.

    Parameters whose name contains 'weight' but not 'ln' are attached to the
    process group ``pg`` and sharded along dim 0 across its TP ranks.
    """
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    for name, param in model.named_parameters():
        # skip biases and LayerNorm weights — only plain weights are sharded
        if 'weight' not in name or 'ln' in name:
            continue
        param.set_process_group(pg)
        param.set_tensor_spec(*spec)


def exam_search_chunk_size():
    """Check that chunk-size search returns the expected size for the GPT2 test model.

    Builds the registry 'gpt2' component model under ``ColoInitContext``, applies a
    1D row TP shard spec, then runs ``search_chunk_configuration`` and asserts the
    searched chunk size matches the known-good value for the current world size.
    """
    world_size = torch.distributed.get_world_size()
    pg_tp = ProcessGroup(tp_degree=world_size)

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # make sure torch_model and model has the same parameter values
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    init_1d_row_spec(model, pg_tp)

    config_dict, *_ = search_chunk_configuration(model,
                                                 search_range_mb=1,
                                                 search_interval_byte=16,
                                                 min_chunk_size_mb=0,
                                                 filter_exlarge_params=True)

    for key in config_dict:
        chunk_size = config_dict[key]['chunk_size']
        if world_size == 1:
            # unsharded parameters -> larger searched chunk size
            assert chunk_size == 31616
        else:
            # TP-sharded parameters shrink per-rank tensors -> smaller chunk
            assert chunk_size == 1024


def exam_search_strict_ddp():
    """Chunk search must agree between a replicated model and a sharded-DDP model.

    Runs ``search_chunk_configuration`` twice — once on a plainly-initialized model
    (``strict_ddp_flag=False``) and once on a model sharded across the default
    process group (``strict_ddp_flag=True``) — and asserts identical results.
    """
    world_size = torch.distributed.get_world_size()
    default_shard_pg = ProcessGroup(tp_degree=world_size)
    default_shard_spec = ShardSpec([-1], [world_size])

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # configuration searched over the replicated (non-sharded) model
    with ColoInitContext(device=get_current_device()):
        ddp_model = model_builder()
    re_dict, re_total, re_wasted = search_chunk_configuration(
        ddp_model,
        search_range_mb=1,
        search_interval_byte=16,
        min_chunk_size_mb=0,
        filter_exlarge_params=True,
        strict_ddp_flag=False)

    # configuration searched over the sharded ddp model with strict_ddp enabled
    with ColoInitContext(device=get_current_device(),
                         default_pg=default_shard_pg,
                         default_dist_spec=default_shard_spec):
        sharded_ddp_model = model_builder()
    sh_dict, sh_total, sh_wasted = search_chunk_configuration(
        sharded_ddp_model,
        search_range_mb=1,
        search_interval_byte=16,
        min_chunk_size_mb=0,
        filter_exlarge_params=True,
        strict_ddp_flag=True)

    # both searches must yield identical configurations and statistics
    assert re_dict == sh_dict
    for key in re_dict:
        assert re_dict[key] == sh_dict[key]

    assert re_total == sh_total
    assert re_wasted == sh_wasted


def exam_chunk_manager():
    """Verify ``init_chunk_manager`` under strict-DDP on a default-sharded model.

    The resulting manager should hold exactly one DP-degree entry whose chunk
    size matches the known-good searched value (31616).
    """
    world_size = torch.distributed.get_world_size()
    default_shard_pg = ProcessGroup(tp_degree=world_size)
    default_shard_spec = ShardSpec([-1], [world_size])

    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    # initialize the model sharded over the default process group
    with ColoInitContext(device=get_current_device(),
                         default_pg=default_shard_pg,
                         default_dist_spec=default_shard_spec):
        sharded_ddp_model = model_builder()

    chunk_manager = init_chunk_manager(sharded_ddp_model,
                                       get_current_device(),
                                       hidden_dim=16,
                                       search_range_mb=1,
                                       min_chunk_size_mb=0,
                                       filter_exlarge_params=True,
                                       strict_ddp_flag=True)

    size_dict = chunk_manager.dp_degree_chunk_size_dict
    # strict DDP: a single DP group of size world_size with the searched chunk size
    assert len(size_dict) == 1
    assert size_dict[world_size] == 31616


def run_dist(rank, world_size, port):
    """Per-process entry point: set up the distributed environment and run all exams.

    Args:
        rank: rank of this worker process (supplied by ``mp.spawn``).
        world_size: total number of spawned workers.
        port: free TCP port used by the NCCL rendezvous on localhost.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_search_chunk_size()
    exam_search_strict_ddp()
    exam_chunk_manager()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_search(world_size):
    """Spawn ``world_size`` worker processes and run the chunk-search exams on each."""
    spawn_target = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(spawn_target, nprocs=world_size)


if __name__ == '__main__':
    # Run the 4-process variant directly when executed as a script.
    test_search(4)