"vscode:/vscode.git/clone" did not exist on "97933b67107c6b9f2929a7b563d2d8c1530abbda"
test_module_spec.py 8.75 KB
Newer Older
1
from copy import deepcopy
2
3
from functools import partial

4
import pytest
5
6
import torch
import torch.multiprocessing as mp
7
8

import colossalai
9
10
11
12
13
14
15
16
17
18
19
from colossalai.nn.parallel.layers import check_colo_module, init_colo_module
from colossalai.tensor import (
    ColoTensor,
    ColoTensorSpec,
    ComputePattern,
    ComputeSpec,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
    distspec,
)
20
21
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
22
from colossalai.utils.cuda import get_current_device
23
from colossalai.zero import ColoInitContext
24
from tests.components_to_test.registry import non_distributed_component_funcs
25
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
26

27

28
29
def _compute_loss(net, data, label, criterion):
    """Run ``net`` on one batch and return its loss.

    When the component provides a ``criterion`` the loss is ``criterion(net(data), label)``;
    otherwise ``net(data, label)`` is expected to return the loss itself.
    """
    if criterion:
        return criterion(net(data), label)
    return net(data, label)


def _assert_shard_matches(local_param, full_param):
    """Assert that ``local_param`` equals the first chunk of ``full_param``.

    Same-shaped (replicated) parameters are compared directly. Otherwise the sharded dimension
    is taken to be the last dim when it is smaller locally (column shard), else the first dim
    (row shard). Any other shape relationship is reported as an assertion failure instead of
    silently comparing against a stale or undefined chunk.
    """
    if local_param.size() == full_param.size():
        assert torch.allclose(local_param, full_param)
        return

    if local_param.size(-1) < full_param.size(-1):    # col
        dim = -1
    elif local_param.size(0) < full_param.size(0):    # row
        dim = 0
    else:
        raise AssertionError(f'cannot locate the sharded dim: local shape {tuple(local_param.size())}, '
                             f'reference shape {tuple(full_param.size())}')

    num_shards = full_param.size(dim) // local_param.size(dim)
    # NOTE(review): only called on rank 0, which is assumed to hold the first chunk — confirm
    # against the ShardSpec layout if this helper is ever used on other ranks.
    assert torch.allclose(local_param, torch.chunk(full_param, num_shards, dim=dim)[0])


def run_model_with_spec(mode, model_name):
    """Train a 1D tensor-parallel model for a few steps and compare it with a sequential reference.

    The component registered as ``model_name`` is built inside ``ColoInitContext`` and sharded
    with ``init_colo_module`` according to ``mode``. Rank 0 additionally builds an unsharded
    copy with the same initial parameters. For the first five batches (the loop stops after
    ``i == 4``) rank 0 asserts that both losses agree and that every local parameter matches
    the corresponding chunk of the reference parameter.

    Args:
        mode: ``'col'`` or ``'row'``, the 1D tensor-parallel sharding mode.
        model_name: ``'simple_net'`` or ``'bert'``, a component registered in
            ``non_distributed_component_funcs``.

    Raises:
        ValueError: if there is no sharding recipe for ``model_name`` / ``mode``.
        AssertionError: on a numerical mismatch (checked on rank 0 only).
    """
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    rank = pg.rank()
    device = get_current_device()

    set_seed(1)
    with ColoInitContext(device=device):
        model = model_builder(checkpoint=False)

    if rank == 0:
        # Unsharded reference, only needed on the rank that performs the comparisons.
        model_seq = model_builder(checkpoint=False)
        model_seq = model_seq.cuda()

        # Make two models have the same init params (copied before `model` is sharded below).
        for p1, p2 in zip(model.parameters(), model_seq.parameters()):
            p2.data.copy_(p1.data)

    compute_spec = ComputeSpec(ComputePattern.TP1D)

    # Not all layers in Bert can be mod by 4.
    # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
    if 'bert' == model_name:
        if 'col' == mode:
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
        elif 'row' == mode:
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
        else:
            # Without this the model would stay unsharded and the test would compare two replicas.
            raise ValueError(f"unsupported mode {mode!r} for model 'bert'")
    elif 'simple_net' == model_name:
        init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
    else:
        raise ValueError(f'no sharding recipe for model {model_name!r}')

    model = model.cuda()
    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(device)
        label = label.to(device)

        # Use rank 0's batch on every rank of the TP group.
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        loss = _compute_loss(model, data, label, criterion)

        if rank == 0:
            # For reference
            loss_seq = _compute_loss(model_seq, data, label, criterion)
            with torch.no_grad():
                assert torch.allclose(loss, loss_seq, rtol=1e-2)

        loss.backward()

        if rank == 0:
            loss_seq.backward()

            # NOTE(review): no optimizer step runs between iterations, so this re-checks the
            # initially copied parameters; comparing gradients may have been the intent.
            with torch.no_grad():
                for p1, p2 in zip(model.parameters(), model_seq.parameters()):
                    _assert_shard_matches(p1, p2)

        if i > 3:
            break

114

115
def run_linear_with_spec(mode):
    """Compare a 1D tensor-parallel ``nn.Linear`` against an unsharded replica.

    A ``Linear(4, 8)`` is built under ``ColoInitContext`` and deep-copied to keep a replicated
    reference. The original is then sharded in place with ``init_colo_module`` using ``mode``.
    Both modules see the same input data and the same upstream gradient. Their outputs must
    satisfy ``tensor_equal``, and the weight/bias gradients must satisfy ``tensor_shard_equal``
    for this rank's tensor-parallel rank and world size.
    """
    with ColoInitContext(device=get_current_device()):
        sharded = torch.nn.Linear(4, 8)
    reference = deepcopy(sharded)

    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    init_colo_module(sharded, ComputeSpec(ComputePattern.TP1D), pg=pg, recursive=True, mode=mode)

    inp = torch.rand(2, 4).cuda()
    colo_inp = ColoTensor.from_torch_tensor(inp, ColoTensorSpec(pg))

    sharded_out = sharded(inp)
    ref_out = reference(colo_inp)
    assert tensor_equal(sharded_out, ref_out)

    upstream_grad = torch.rand_like(sharded_out)
    sharded_out.backward(upstream_grad)
    ref_out.backward(upstream_grad)

    tp_rank, tp_size = pg.tp_local_rank(), pg.tp_world_size()
    for param_name in ('weight', 'bias'):
        full_grad = getattr(reference, param_name).grad
        local_grad = getattr(sharded, param_name).grad
        assert tensor_shard_equal(full_grad, local_grad, tp_rank, tp_size)
138

139

140
def run_check_shared_param():
    """Check that ``check_colo_module`` rejects inconsistent specs on a shared parameter.

    Builds a tiny ``BertForMaskedLM`` under ``ColoInitContext`` and shards it row-wise, which
    must be accepted. It then gives ``cls.predictions.bias`` (shared with
    ``cls.predictions.decoder``) a dim-0 ``ShardSpec`` that is inconsistent with the row-sharded
    decoder weight. Finally it requires ``check_colo_module`` on the decoder to raise an error
    whose message contains ``'incorrectly sharded'``.

    Raises:
        AssertionError: if the check raises nothing, or raises with a different message.
    """
    from transformers import BertConfig, BertForMaskedLM

    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 24

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)

    config = BertConfig(vocab_size=vocab_size,
                        hidden_size=hidden_dim,
                        intermediate_size=hidden_dim * 4,
                        num_attention_heads=num_head,
                        max_position_embeddings=sequence_length,
                        num_hidden_layers=num_layer,
                        hidden_dropout_prob=0.,
                        attention_probs_dropout_prob=0.)
    with ColoInitContext(device=get_current_device()):
        model = BertForMaskedLM(config)

    model = model.cuda()
    compute_spec = ComputeSpec(ComputePattern.TP1D)

    # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
    assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
    # They are all Linear, so both row is allowed. This should pass check.
    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')

    # This should be detected by check because you can not set weight as row while set bias as col.
    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))

    # TODO(jiaruifang) optimize this line
    if not model.cls.predictions.bias.has_initialized:
        model.cls.predictions.bias.pg = pg
        model.cls.predictions.bias.dist_spec = ReplicaSpec()
        model.cls.predictions.bias.has_initialized = True
    model.cls.predictions.bias.set_tensor_spec(*col_spec)

    # The check must actually fail here: a bare try/except would let the test pass silently
    # when no exception is raised at all.
    with pytest.raises(Exception, match='incorrectly sharded'):
        check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)

183

184
def run_dist(rank, world_size, port):
    """Per-process worker: launch colossalai with 1D tensor parallelism, then run the Linear checks
    in ``'col'`` and then ``'row'`` mode."""
    tensor_cfg = dict(mode="1d", size=world_size)
    colossalai.launch(config=dict(parallel=dict(tensor=tensor_cfg)),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    for mode in ('col', 'row'):
        run_linear_with_spec(mode)
189

190

191
192
193
194
195
196
197
def run_dist_model(rank, world_size, port):
    """Per-process worker: launch colossalai with 1D tensor parallelism, then run the model checks.

    Each model (``'simple_net'`` then ``'bert'``) is exercised in ``'col'`` and then ``'row'`` mode.
    """
    tensor_cfg = dict(mode="1d", size=world_size)
    colossalai.launch(config=dict(parallel=dict(tensor=tensor_cfg)),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    for model_name in ('simple_net', 'bert'):
        for mode in ('col', 'row'):
            run_model_with_spec(mode, model_name)

198

199
200
201
202
def run_dist_check(rank, world_size, port):
    """Per-process worker: launch colossalai with 1D tensor parallelism, then run the shared-param check."""
    tensor_cfg = dict(mode="1d", size=world_size)
    colossalai.launch(config=dict(parallel=dict(tensor=tensor_cfg)),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    run_check_shared_param()
203

204

205
206
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_linear_1d(world_size):
    """Spawn ``world_size`` processes that each run ``run_dist`` on a freshly chosen free port."""
    port = free_port()
    worker = partial(run_dist, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)

213

214
215
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_model(world_size):
    """Spawn ``world_size`` processes that each run ``run_dist_model`` on a freshly chosen free port."""
    port = free_port()
    worker = partial(run_dist_model, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)

222

223
224
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use()
def test_module_check(world_size):
    """Spawn ``world_size`` processes that each run ``run_dist_check`` on a freshly chosen free port."""
    port = free_port()
    worker = partial(run_dist_check, world_size=world_size, port=port)
    mp.spawn(worker, nprocs=world_size)

231

232
if __name__ == '__main__':
    # Script entry point: run the Linear test directly with 4 processes.
    test_module_linear_1d(world_size=4)