test_schedules.py
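
"""Unit tests for megatron.core.pipeline_parallel.schedules: schedule selection,
output-tensor deallocation, and the pipelined forward/backward functions with and
without pipeline parallelism and interleaving."""
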
import pytest
import torch
from pytest_mock import mocker

import megatron.core.pipeline_parallel.schedules as schedule
from megatron.core import ModelParallelConfig
from tests.unit_tests.test_utilities import Utils

rank = Utils.rank


def test_get_forward_backward_func():
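    """get_forward_backward_func dispatches on the parallel state: no pipelining
    for a single pipeline stage, the non-interleaved schedule for multiple
    stages, and the interleaved schedule once a virtual pipeline size is set.
    """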
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
    assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)
    assert (
        schedule.get_forward_backward_func()
        == schedule.forward_backward_pipelining_without_interleaving
    )
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=4,
        virtual_pipeline_model_parallel_size=2,
    )
    assert (
        schedule.get_forward_backward_func()
        == schedule.forward_backward_pipelining_with_interleaving
    )
    Utils.destroy_model_parallel()
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
        virtual_pipeline_model_parallel_size=4,
    )
    assert (
        schedule.get_forward_backward_func()
        == schedule.forward_backward_pipelining_with_interleaving
    )
    Utils.destroy_model_parallel()


def test_deallocate_output_tensor():
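    """With the default deallocate_pipeline_outputs=False, deallocate_output_tensor
    is a no-op: the tensor keeps its 6 elements."""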
    out = torch.tensor([[1, 2, 3], [4, 5, 6]])
    schedule.deallocate_output_tensor(out)
    assert out.nelement() == 6


def test_forward_backward_func_without_pipeline_parallel(mocker):
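    """Forward-only run of the no-pipelining schedule: loss_func reports the
    local rank, so all 4 microbatches should reduce to {'loss_reduced': rank}.
    """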
    from megatron.core.pipeline_parallel import get_forward_backward_func

    Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)

    def forward_step_func(data_iterator, model):
        import os

        rank = int(os.environ['LOCAL_RANK'])
        dummy_data = torch.ones(1, 4)

        def loss_func(output_tensor):
            return rank, {'loss_reduced': rank}

        return model(dummy_data), loss_func

    model = torch.nn.Linear(4, 1)
    model.model_type = 'unit-test'

    def set_input_tensor(input_tensor):
        return None

    model.set_input_tensor = set_input_tensor

    forward_backward_func = get_forward_backward_func()
    assert schedule.get_forward_backward_func() == schedule.forward_backward_no_pipelining

    # custom_backward is mocked out: the call below runs with forward_only=True,
    # so only the schedule's forward/loss bookkeeping is exercised.
    mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)
    config = ModelParallelConfig(pipeline_model_parallel_size=1)
    model.config = config

    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=range(0, 100),
        model=[model],
        num_microbatches=4,
        seq_length=None,
        micro_batch_size=None,
        forward_only=True,
    )

    loss_reduced_expected = [
        {'loss_reduced': rank},
        {'loss_reduced': rank},
        {'loss_reduced': rank},
        {'loss_reduced': rank},
    ]

    for i, j in zip(losses_reduced, loss_reduced_expected):
        print(losses_reduced)
        assert i['loss_reduced'] == j['loss_reduced']
    Utils.destroy_model_parallel()


def test_forward_backward_func_with_pipeline_parallel(mocker):
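    """Forward-only run of the non-interleaved (1F1B) schedule over a 4-stage
    pipeline, checking the reduced losses collected per microbatch."""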
    from megatron.core.pipeline_parallel import get_forward_backward_func

    Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=4)

    def forward_step_func(data_iterator, model):
        import os

        rank = int(os.environ['LOCAL_RANK'])

        def loss_func(output_tensor):
            return rank, {'loss_reduced': rank}

        return torch.rand(512, 8, 256).cuda(), loss_func

    model = torch.nn.Linear(4, 1)
    model.model_type = 'unit-test'

    def set_input_tensor(input_tensor):
        return None

    model.set_input_tensor = set_input_tensor

    forward_backward_func = get_forward_backward_func()
    assert (
        schedule.get_forward_backward_func()
        == schedule.forward_backward_pipelining_without_interleaving
    )

    sequence_length = 512
    micro_batch_size = 8
    hidden_size = 256

    config = ModelParallelConfig(
        pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float
    )
    config.hidden_size = hidden_size
    model.config = config

    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=None,
        model=[model],
        num_microbatches=micro_batch_size,
        seq_length=sequence_length,
        micro_batch_size=micro_batch_size,
        forward_only=True,
    )

    loss_reduced_expected = [
        {'loss_reduced': rank},
        {'loss_reduced': rank},
        {'loss_reduced': rank},
        {'loss_reduced': rank},
    ]
    for i, j in zip(losses_reduced, loss_reduced_expected):
        print(losses_reduced)
        assert i['loss_reduced'] == j['loss_reduced']
    Utils.destroy_model_parallel()


def test_forward_backward_func_with_interleaving(mocker):
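    """Interleaved schedule over a 4-stage pipeline with 2 model chunks per rank:
    invalid configurations (encoder-and-decoder model, mismatched decoder_seq_length,
    microbatch count not divisible by the pipeline size) must raise RuntimeError;
    a valid call returns the reduced losses."""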
    from megatron.core.enums import ModelType
    from megatron.core.pipeline_parallel import get_forward_backward_func

    Utils.initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=4,
        virtual_pipeline_model_parallel_size=2,
    )

    def forward_step_func(data_iterator, model):
        import os

        rank = int(os.environ['LOCAL_RANK'])

        def loss_func(output_tensor):
            return rank, {'loss_reduced': rank}

        return torch.rand(512, 8, 256).cuda(), loss_func

    model = torch.nn.Linear(4, 1)

    def set_input_tensor(input_tensor):
        return None

    model.set_input_tensor = set_input_tensor

    forward_backward_func = get_forward_backward_func()
    assert (
        schedule.get_forward_backward_func()
        == schedule.forward_backward_pipelining_with_interleaving
    )

    sequence_length = 512
    micro_batch_size = 8
    hidden_size = 256

    config = ModelParallelConfig(
        pipeline_model_parallel_size=4, sequence_parallel=False, pipeline_dtype=torch.float
    )
    config.hidden_size = hidden_size
    model.config = config

    # Mock out custom_backward; all calls below run with forward_only=True, so
    # no real backward pass is needed.
    mocker.patch("megatron.core.pipeline_parallel.schedules.custom_backward", return_value=2)

    # The interleaved schedule rejects models with both an encoder and a decoder.
    with pytest.raises(RuntimeError):
        model.model_type = ModelType.encoder_and_decoder
        forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=[range(0, 100)],
            model=[model, model],
            num_microbatches=micro_batch_size,
            seq_length=sequence_length,
            micro_batch_size=micro_batch_size,
            decoder_seq_length=sequence_length,
            forward_only=True,
        )

    # A decoder_seq_length (256) different from seq_length (512) is rejected.
    with pytest.raises(RuntimeError):
        model.model_type = ModelType.encoder_or_decoder
        forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=[range(0, 100)],
            model=[model, model],
            num_microbatches=micro_batch_size,
            seq_length=sequence_length,
            micro_batch_size=micro_batch_size,
            decoder_seq_length=256,
            forward_only=True,
        )

    # num_microbatches (7) must be divisible by the pipeline-parallel size (4)
    # when interleaving is enabled.
    with pytest.raises(RuntimeError):
        model.model_type = ModelType.encoder_or_decoder
        forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=[range(0, 100)],
            model=[model, model],
            num_microbatches=7,
            seq_length=sequence_length,
            micro_batch_size=micro_batch_size,
            decoder_seq_length=512,
            forward_only=True,
        )

    model.model_type = ModelType.encoder_or_decoder
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=[range(0, 100), range(0, 100)],
        model=[model, model],
        num_microbatches=micro_batch_size,
        seq_length=sequence_length,
        micro_batch_size=micro_batch_size,
        decoder_seq_length=sequence_length,
        forward_only=True,
    )

    loss_reduced_expected = [
        {'loss_reduced': rank},
        {'loss_reduced': rank},
        {'loss_reduced': rank},
        {'loss_reduced': rank},
    ]
    for i, j in zip(losses_reduced, loss_reduced_expected):
        print(losses_reduced)
        assert i['loss_reduced'] == j['loss_reduced']

    Utils.destroy_model_parallel()