# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

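# Checkpoint tests for Megatron-style model parallelism: a GPT-2 model is
# saved at one model-parallel (mp) degree and reloaded at the same or a
# different degree, and the forward-pass outputs are compared.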
import os
import torch
import deepspeed
import pytest
import random
import numpy as np
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest, DistributedFixture
from unit.megatron_model import get_gpt2_model, get_megatron_version
from unit.util import required_minimum_torch_version, required_maximum_torch_version

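# Megatron-LM only supports a bounded range of PyTorch releases, so the whole
# module is skipped outside that range.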
pytestmark = [
    pytest.mark.skipif(not required_minimum_torch_version(major_version=1, minor_version=5),
                       reason='Megatron-LM package requires Pytorch version 1.5 or above'),
    pytest.mark.skipif(not required_maximum_torch_version(major_version=1, minor_version=13),
                       reason='Megatron-LM package requires Pytorch version 1.13 or below')
]


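# Wrap a Megatron GPT-2 model in a DeepSpeed engine. Megatron's mpu module is
# passed to deepspeed.initialize() so the engine knows about the
# model-parallel process groups when saving and loading checkpoints.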
def get_deepspeed_model(model):
    ds_config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
    }

    from megatron import mpu
    model, _, _, _ = deepspeed.initialize(model=model,
                                          mpu=mpu,
                                          model_parameters=model.parameters(),
                                          config=ds_config_dict)
    return model


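# Base class for the tests below: seeds every RNG for reproducibility and
# provides a small batch of random GPT-2 inputs.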
class ConfigurableMP(DistributedTest):

    @pytest.fixture(autouse=True)
    def reset_random(self, seed=1234):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        get_accelerator().manual_seed_all(seed)

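    # A single random batch: token ids, position ids, and a boolean attention
    # mask, sized for the small test model.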
    @pytest.fixture
    def inputs(self, bs=1, seq_len=20):
        input_ids = torch.randint(low=0, high=1000, size=(bs, seq_len))
        position_ids = torch.randint(low=0, high=2, size=(bs, seq_len))
        attention_mask = torch.randint(low=0, high=2, size=(bs, seq_len), dtype=torch.bool)
        return [input_ids, position_ids, attention_mask]


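# Save a checkpoint and reload it at the same model-parallel degree; the
# reloaded model should reproduce the pre-checkpoint output.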
class TestConfigurableMP(ConfigurableMP):

    @pytest.mark.world_size(1)
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_gpt2_basic(self, tmpdir, inputs):
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        model = get_gpt2_model(args_defaults)
        model = get_deepspeed_model(model)

        model.eval()
        device_name = get_accelerator().device_name()
        baseline = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))

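        # Round-trip through save_checkpoint/load_checkpoint, then compare a
        # fresh forward pass against the baseline computed above.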
        tag = 'mp_1'
        state_dict = {}
        state_dict['checkpoint_version'] = get_megatron_version()
        model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)
        dist.barrier()
        model.load_checkpoint(tmpdir, tag=tag, load_optimizer_states=False, load_lr_scheduler_states=False)

        test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
        assert torch.allclose(baseline, test,
                              atol=1e-07), f"Baseline output {baseline} is not equal to save-then-load output {test}"

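    # Same round-trip check, but with the model split across two
    # model-parallel ranks.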
    @pytest.mark.world_size(2)
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_gpt2_mp2_no_resize(self, tmpdir, inputs):
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        model = get_gpt2_model(args_defaults, mp_size=2)
        model = get_deepspeed_model(model)

        model.eval()

        device_name = get_accelerator().device_name()
        baseline = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))

        tag = 'mp_2'
        state_dict = {}
        state_dict['checkpoint_version'] = get_megatron_version()
        model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)
        dist.barrier()
        model.load_checkpoint(tmpdir, tag=tag, load_optimizer_states=False, load_lr_scheduler_states=False)

        test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
        assert torch.allclose(baseline, test, rtol=1.0,
                              atol=1e-07), f"Baseline output {baseline} is not equal to save-then-load output {test}"


# This fixture provides the baseline model with mp=2 to TestConfigurableResizeMP
class baseline_mp2(DistributedFixture):
    world_size = 2

    def run(self, inputs, class_tmpdir):
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        model = get_gpt2_model(args_defaults, mp_size=self.world_size)
        model = get_deepspeed_model(model)

        model.eval()

        with torch.no_grad():
            device_name = get_accelerator().device_name()
            baseline = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
            if dist.get_rank() == 0:
                save_path = os.path.join(class_tmpdir, "output.pt")
                torch.save(baseline.cpu(), save_path)

            state_dict = {}
            state_dict['checkpoint_version'] = get_megatron_version()
            model.save_checkpoint(class_tmpdir, client_state=state_dict)


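# Load the mp=2 checkpoint produced by baseline_mp2 into engines with a
# different model-parallel degree (1 or 4) and check that the output still
# matches the saved baseline.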
class TestConfigurableResizeMP(ConfigurableMP):
    world_size = [1, 4]

    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test(self, baseline_mp2, inputs, class_tmpdir):
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

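        # The target model-parallel degree equals the number of launched
        # ranks, read from the WORLD_SIZE environment variable set by the
        # test launcher.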
        world_size = int(os.environ["WORLD_SIZE"])
        model = get_gpt2_model(args_defaults, mp_size=world_size)
        model = get_deepspeed_model(model)

        model.eval()

        with torch.no_grad():
            model.load_checkpoint(class_tmpdir, load_optimizer_states=False, load_lr_scheduler_states=False)
            device_name = get_accelerator().device_name()
            test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
            if dist.get_rank() == 0:
                load_path = os.path.join(class_tmpdir, "output.pt")
                baseline = torch.load(load_path)
                test = test.cpu()
                assert torch.allclose(
                    baseline, test,
                    atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"