# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import deepspeed
import pytest
import random
import numpy as np
import deepspeed.comm as dist
from unit.common import DistributedTest, DistributedFixture
from unit.megatron_model import get_megatron_version
from unit.megatron_model import MockGPT2ModelPipe as GPT2ModelPipe
from deepspeed.utils import RepeatingLoader
from deepspeed.accelerator import get_accelerator
from unit.util import required_minimum_torch_version, required_maximum_torch_version

# Both version guards have to live in ONE list: assigning ``pytestmark`` twice
# keeps only the last assignment, which silently drops the minimum-version skip.
pytestmark = [
    pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=5),
                       reason='Megatron-LM package requires Pytorch version 1.5 or above'),
    pytest.mark.skipif(not required_maximum_torch_version(major_version=1, minor_version=13),
                       reason='Megatron-LM package requires Pytorch version 1.13 or below'),
]


def get_deepspeed_model(model):
    """Initialize ``model`` through DeepSpeed and place it on the accelerator.

    Args:
        model: Module passed to ``deepspeed.initialize``; its ``parameters()``
            are supplied as ``model_parameters``.

    Returns:
        The first item returned by ``deepspeed.initialize``, after calling
        ``.to()`` with the current accelerator's device name.
    """
    ds_config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
    }

    # Only the first of the four returned values is needed by the tests.
    engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config_dict)
    return engine.to(get_accelerator().device_name())


def get_topology(mp, pp, world_size):
    """Build a pipe/model/data parallel topology for ``world_size`` ranks.

    Args:
        mp: Model-parallel degree.
        pp: Pipeline-parallel degree (number of stages).
        world_size: Total rank count; must be a multiple of ``pp * mp``.

    Returns:
        A ``PipeModelDataParallelTopology`` whose data-parallel degree is
        ``world_size // (pp * mp)``.
    """
    # The data-parallel degree is whatever is left after pp and mp are fixed.
    dp, leftover = divmod(world_size, pp * mp)
    assert leftover == 0

    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
    return PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)


class ConfigurablePP(DistributedTest):
    """Common fixtures shared by the configurable pipeline-parallel test classes."""

    @pytest.fixture(autouse=True)
    def reset_random(self, seed=1234):
        """Seed the Python, NumPy, torch and accelerator RNGs with ``seed``."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        get_accelerator().manual_seed_all(seed)

    @pytest.fixture
    def inputs(self, bs=1, seq_len=1, hidden_size=128):
        """Return a ``(hidden_states, attention_mask)`` tuple of random tensors.

        ``hidden_states`` has shape ``(bs, seq_len, hidden_size)`` and
        ``attention_mask`` is a boolean tensor of shape ``(bs, seq_len)``.
        """
        hidden_states = torch.randn(bs, seq_len, hidden_size)
        attention_mask = torch.randint(low=0, high=2, size=(bs, seq_len), dtype=torch.bool)
        return (hidden_states, attention_mask)


class TestConfigurablePP(ConfigurablePP):
    """Checkpoint save/load round trip with ``mp_size=2`` and ``pp_size=2``."""
    mp_size = 2
    pp_size = 2
    world_size = 4  # mp_size * pp_size

    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_pp_basic(self, inputs, tmpdir):
        """Verify that eval output is unchanged after saving and reloading a checkpoint."""
        # basic test case, mp_size=2, pp_size=2, verify ckpt saving/loading.
        args_defaults = {
            'num_layers': 8,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }
        mp_size = self.mp_size
        pp_size = self.pp_size
        world_size = self.world_size

        topo = get_topology(mp_size, pp_size, world_size)
        gpt2_pipe_model = GPT2ModelPipe(num_layers=8,
                                        num_stages=pp_size,
                                        mp_size=mp_size,
                                        args_others=args_defaults,
                                        topo=topo)
        model = get_deepspeed_model(gpt2_pipe_model)

        tag = 'pp_basic'
        state_dict = {'checkpoint_version': get_megatron_version()}
        model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)

        # Only the first and last pipeline stages are handed a data iterator.
        if model.is_first_stage() or model.is_last_stage():
            loader = RepeatingLoader([(inputs[0], 0)])
            data_iter = iter(loader)
        else:
            data_iter = None

        baseline = model.eval_batch(data_iter=data_iter, compute_loss=False, reduce_output=None)

        # Barriers bracket the reload so no rank evaluates before all have loaded.
        dist.barrier()
        model.load_checkpoint(tmpdir, tag=tag, load_optimizer_states=False, load_lr_scheduler_states=False)
        dist.barrier()

        test = model.eval_batch(data_iter=data_iter, compute_loss=False, reduce_output=None)

        if test is not None:
            assert len(baseline) == len(test)
            # Compare outputs of each microbatch
            for baseline_mb, test_mb in zip(baseline, test):
                for b, t in zip(baseline_mb, test_mb):
                    if b.is_floating_point():  # don't compare masks
                        assert torch.allclose(
                            b, t,
                            atol=1e-07), f"Baseline output {baseline} is not equal to save-then-load output {test}"


# Fixture for defining the checkpoint path since all tests in
# TestConfigurableResizePP will use the same tmpdir
@pytest.fixture
def checkpoint_tag(mp_size, pp_size, mp_resize, pp_resize):
    """Return a tag of the form ``"<mp_size>-<pp_size>-<mp_resize>-<pp_resize>"``."""
    layout = (mp_size, pp_size, mp_resize, pp_resize)
    return "-".join(str(part) for part in layout)


# Base class for creating / saving model output for baseline models. This is
# not meant to be used directly as a fixture to any classes
class _baseline(DistributedFixture):
    """Run a baseline model once, saving its eval output and a checkpoint.

    Subclasses set ``world_size``; ``run`` writes ``output-<checkpoint_tag>.pt``
    and a checkpoint tagged ``checkpoint_tag`` into ``class_tmpdir``.
    """
    world_size = None

    def run(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size):
        """Build the baseline model, evaluate one batch, then save output and checkpoint."""
        expected_world_size = pp_size * mp_size
        assert int(os.environ["WORLD_SIZE"]) == expected_world_size, \
            "world size does not match provided pp_size and mp_size"
        args_defaults = {
            'num_layers': 8,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        topo = get_topology(mp_size, pp_size, expected_world_size)
        gpt2_pipe_model = GPT2ModelPipe(num_layers=8,
                                        num_stages=pp_size,
                                        mp_size=mp_size,
                                        args_others=args_defaults,
                                        topo=topo)
        model = get_deepspeed_model(gpt2_pipe_model)

        with torch.no_grad():
            inputs = [x.to(get_accelerator().device_name()) for x in inputs]
            # Only the first and last pipeline stages are handed a data iterator.
            if model.is_first_stage() or model.is_last_stage():
                loader = RepeatingLoader([(inputs[0], 0)])
                data_iter = iter(loader)
            else:
                data_iter = None

            baseline = model.eval_batch(data_iter=data_iter, compute_loss=False, reduce_output=None)

            if baseline is not None:
                # Expect exactly one microbatch holding exactly one tensor.
                assert len(baseline) == 1
                assert len(baseline[0]) == 1
                assert torch.is_tensor(baseline[0][0])
                save_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
                torch.save(baseline[0][0].cpu(), save_path)

            state_dict = {'checkpoint_version': get_megatron_version()}
            model.save_checkpoint(class_tmpdir, tag=checkpoint_tag, client_state=state_dict)


# This may look odd, but there is a limitation with DistributedFixture that
# doesn't allow us to reuse a fixture with different worldsizes. This could be
# implemented in conftest.py::pytest_fixture_setup and common.py::DistributedFixture
class baseline_ws1(_baseline):
    """Baseline fixture variant with ``world_size`` set to 1."""
    world_size = 1


class baseline_ws2(_baseline):
    """Baseline fixture variant with ``world_size`` set to 2."""
    world_size = 2


class baseline_ws4(_baseline):
    """Baseline fixture variant with ``world_size`` set to 4."""
    world_size = 4


class TestConfigurableResizePP(ConfigurablePP):
    """Load a baseline checkpoint under a different (mp, pp) layout and compare outputs.

    Each test requests one of the ``baseline_ws*`` fixtures, which writes the
    checkpoint and reference output that ``_test`` then reads from ``class_tmpdir``.
    """

    def _test(self, inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize):
        """Rebuild the model with the resized layout, load the checkpoint and compare.

        ``mp_size`` and ``pp_size`` are accepted for signature parity with the
        test methods; only ``mp_resize`` and ``pp_resize`` shape the model here.
        """
        args_defaults = {
            'num_layers': 8,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        topo = get_topology(mp_resize, pp_resize, mp_resize * pp_resize)
        gpt2_pipe_model = GPT2ModelPipe(num_layers=8,
                                        num_stages=pp_resize,
                                        mp_size=mp_resize,
                                        args_others=args_defaults,
                                        topo=topo)
        model = get_deepspeed_model(gpt2_pipe_model)

        with torch.no_grad():
            model.load_checkpoint(class_tmpdir,
                                  tag=checkpoint_tag,
                                  load_optimizer_states=False,
                                  load_lr_scheduler_states=False)
            inputs = [x.to(get_accelerator().device_name()) for x in inputs]
            # Only the first and last pipeline stages are handed a data iterator.
            if model.is_first_stage() or model.is_last_stage():
                loader = RepeatingLoader([(inputs[0], 0)])
                data_iter = iter(loader)
            else:
                data_iter = None

            test = model.eval_batch(data_iter=data_iter, compute_loss=False, reduce_output=None)

            if test is not None:
                # Expect exactly one microbatch holding exactly one tensor.
                assert len(test) == 1
                assert len(test[0]) == 1
                assert torch.is_tensor(test[0][0])
                test = test[0][0].cpu()
                load_path = os.path.join(class_tmpdir, f"output-{checkpoint_tag}.pt")
                baseline = torch.load(load_path)
                assert torch.allclose(
                    baseline, test,
                    atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"

    # These tests are divided by baseline model worldsize and test model worldsize
    @pytest.mark.world_size(1)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 2, 1, 1)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_2to1(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws2, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline saved on 2 ranks, reloaded on 1 rank."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)

    @pytest.mark.world_size(1)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(2, 2, 1, 1)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_4to1(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws4, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline saved on 4 ranks, reloaded on 1 rank."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)

    @pytest.mark.world_size(2)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(2, 2, 2, 1)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_4to2(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws4, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline saved on 4 ranks, reloaded on 2 ranks."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)

    @pytest.mark.world_size(4)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 1, 2, 2)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_1to4(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws1, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline saved on 1 rank, reloaded on 4 ranks."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)

    @pytest.mark.world_size(4)
    @pytest.mark.parametrize("mp_size, pp_size, mp_resize, pp_resize", [(1, 2, 1, 4), (2, 1, 2, 2)])
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_world_size_2to4(self, inputs, class_tmpdir, checkpoint_tag, baseline_ws2, mp_size, pp_size, mp_resize,
                             pp_resize):
        """Baseline saved on 2 ranks, reloaded on 4 ranks."""
        self._test(inputs, class_tmpdir, checkpoint_tag, mp_size, pp_size, mp_resize, pp_resize)