# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
from typing import Callable
import torch
from torch.optim import Optimizer, Adam, AdamW
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR

from unit.simple_model import SimpleModel, random_dataloader
from unit.common import DistributedTest
from unit.util import required_torch_version, bf16_required_version_check, required_amp_check

import deepspeed
from deepspeed.ops.adam import FusedAdam
from deepspeed.runtime.lr_schedules import WARMUP_LR, WarmupLR
from deepspeed.runtime.config import ADAM_OPTIMIZER
from deepspeed.runtime.utils import see_memory_usage


@pytest.mark.parametrize('zero_stage', [0, 3])
class TestNoOptim(DistributedTest):
    """Build a DeepSpeed engine with no optimizer at all and run forward passes only.

    Runs once with ZeRO disabled (stage 0) and once with ZeRO-3 plus CPU parameter
    offload. Memory usage is reported before initialization, after initialization and
    after the forward passes.
    """
    world_size = 1

    def test(self, zero_stage):
        use_zero3 = zero_stage == 3
        if use_zero3 and not required_torch_version():
            pytest.skip("zero-3 param offload requires at least torch 1.8")

        offload_cfg = {"device": "cpu"}
        ds_config = {
            'train_batch_size': self.world_size,
            'fp16': {
                'enabled': True
            },
            'zero_optimization': {
                "stage": zero_stage,
                "offload_param": offload_cfg
            }
        }

        # A much larger "20B" variant of this test used hidden_dim = 16 * 1024;
        # a tiny width keeps the default run cheap.
        hidden_dim = 4

        # Only partition parameters at construction time when ZeRO-3 is active.
        with deepspeed.zero.Init(enabled=use_zero3, config_dict_or_path=ds_config):
            net = SimpleModel(hidden_dim, nlayers=78)

        see_memory_usage('pre-init', force=True)
        engine, _, _, _ = deepspeed.initialize(model=net, config=ds_config)
        see_memory_usage('post-init', force=True)

        loader = random_dataloader(model=engine,
                                   total_samples=50,
                                   hidden_dim=hidden_dim,
                                   device=engine.device,
                                   dtype=torch.half)
        # Forward only: there is no optimizer, so no backward/step is attempted.
        for sample in loader:
            engine(sample[0], sample[1])
        see_memory_usage('post-fwds', force=True)


@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
class TestClientOptimizer(DistributedTest):
    """Check which optimizer ``deepspeed.initialize`` returns for each way of supplying one.

    ``optimizer_type`` selects how the optimizer is provided:

    * ``None``      -- no client optimizer; the config requests ``ADAM_OPTIMIZER``.
    * ``Optimizer`` -- a ready-made ``Adam`` instance is passed in.
    * ``Callable``  -- a factory that builds an ``AdamW`` from the parameters is passed in.
    """
    world_size = 1

    def test(self, optimizer_type):

        def _optimizer_callable(params) -> Optimizer:
            return AdamW(params=params)

        hidden_dim = 10
        model = SimpleModel(hidden_dim)

        config_dict = {'train_batch_size': 1}
        if optimizer_type is None:
            client_optimizer = None
            config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
        elif optimizer_type is Optimizer:
            client_optimizer = Adam(model.parameters())
        else:
            client_optimizer = _optimizer_callable

        _, ds_optimizer, _, _ = deepspeed.initialize(config=config_dict,
                                                     model=model,
                                                     model_parameters=list(model.parameters()),
                                                     optimizer=client_optimizer)
        if client_optimizer is None:
            # The config-requested Adam is expected to be built as DeepSpeed's FusedAdam.
            assert isinstance(ds_optimizer, FusedAdam)
        elif isinstance(client_optimizer, Optimizer):
            # A client-supplied optimizer instance must be passed through unchanged.
            assert ds_optimizer is client_optimizer
        else:
            # A callable is invoked with the parameters, producing the AdamW it builds.
            assert isinstance(ds_optimizer, AdamW)


@pytest.mark.parametrize('client_parameters', [True, False])
class TestConfigOptimizer(DistributedTest):
    """Check that an Adam optimizer declared only in the config is built as ``FusedAdam``.

    ``client_parameters`` toggles whether ``model_parameters`` is passed explicitly
    (``True``) or left as ``None`` (``False``).
    """
    world_size = 1

    def test(self, client_parameters):
        ds_config = {"train_batch_size": 1, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}}

        hidden_dim = 10
        model = SimpleModel(hidden_dim)

        model_parameters = list(model.parameters()) if client_parameters else None

        _, ds_optimizer, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model_parameters)

        assert isinstance(ds_optimizer, FusedAdam)


@pytest.mark.parametrize('optimizer_extension', ['zero1', 'zero2', 'amp', None])
@pytest.mark.parametrize('model_dtype', ['fp16', 'bf16', 'fp32'])
@pytest.mark.parametrize('grad_accum_dtype', [None, 'fp16', 'bf16', 'fp32'])
class TestOptimizerImplementation(DistributedTest):
    """Sweep optimizer wrapper x model dtype x gradient-accumulation dtype.

    Triples listed in ``SUPPORTED_CONFIGS`` must initialize cleanly. Every other
    combination is expected to make ``deepspeed.initialize`` raise
    ``NotImplementedError``.
    """
    world_size = 1

    # (optimizer_extension, model_dtype, grad_accum_dtype) triples expected to initialize.
    SUPPORTED_CONFIGS = frozenset({
        # ZeRO 1 Wrapper
        ('zero1', 'fp16', None),
        ('zero1', 'fp16', 'fp16'),
        ('zero1', 'bf16', None),
        ('zero1', 'bf16', 'bf16'),
        ('zero1', 'bf16', 'fp32'),
        ('zero1', 'fp32', None),
        ('zero1', 'fp32', 'fp32'),
        # ZeRO 2 Wrapper
        ('zero2', 'fp16', None),
        ('zero2', 'fp16', 'fp16'),
        ('zero2', 'bf16', None),
        ('zero2', 'bf16', 'bf16'),
        ('zero2', 'fp32', None),
        ('zero2', 'fp32', 'fp32'),
        # Amp Wrapper
        ('amp', 'fp32', None),
        ('amp', 'fp32', 'fp32'),
        # FP16 Wrapper
        (None, 'fp16', None),
        (None, 'fp16', 'fp16'),
        # BF16 Wrapper
        (None, 'bf16', 'fp32'),
        (None, 'bf16', None),
        # No Wrapper
        (None, 'fp32', None),
        (None, 'fp32', 'fp32'),
    })

    def test(self, optimizer_extension, model_dtype, grad_accum_dtype):
        # Anything other than an explicit ZeRO extension runs with ZeRO disabled.
        zero_stage = {'zero1': 1, 'zero2': 2}.get(optimizer_extension, 0)
        amp = optimizer_extension == 'amp'
        fp16 = model_dtype == 'fp16'
        bf16 = model_dtype == 'bf16'
        # Skip checks
        if bf16 and not bf16_required_version_check():
            pytest.skip(
                "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
            )
        if amp and not required_amp_check():
            pytest.skip("Amp is not installed can't run amp check")
        # Config declaration
        ds_config = {
            "train_batch_size": 1,
            'fp16': {
                'enabled': fp16
            },
            'bf16': {
                'enabled': bf16
            },
            'amp': {
                'enabled': amp
            },
            'zero_optimization': {
                "stage": zero_stage
            },
            "data_types": {
                "grad_accum_dtype": grad_accum_dtype
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.001
                }
            }
        }

        key = (optimizer_extension, model_dtype, grad_accum_dtype)

        hidden_dim = 10
        model = SimpleModel(hidden_dim)
        model_parameters = list(model.parameters())

        if key in self.SUPPORTED_CONFIGS:
            # Success is simply "initialize does not raise".
            deepspeed.initialize(config=ds_config, model=model, model_parameters=model_parameters)
        else:
            with pytest.raises(NotImplementedError):
                deepspeed.initialize(config=ds_config, model=model, model_parameters=model_parameters)


@pytest.mark.parametrize("scheduler_type", [None, _LRScheduler, Callable])
@pytest.mark.parametrize("optimizer_type", [None, Optimizer, Callable])
class TestClientLrScheduler(DistributedTest):
    """Check which LR scheduler ``deepspeed.initialize`` returns for each way of supplying one.

    ``scheduler_type`` selects the scheduler source:

    * ``None``         -- the config requests ``WARMUP_LR``.
    * ``_LRScheduler`` -- a ready-made ``LambdaLR`` instance is passed in.
    * ``Callable``     -- a factory that builds a ``LambdaLR`` from the optimizer is passed in.

    ``optimizer_type`` selects the optimizer source in the same way. A scheduler
    instance combined with a non-instance optimizer is an invalid combination and
    must be rejected with ``AssertionError``.
    """
    world_size = 1

    def test(self, scheduler_type, optimizer_type):

        def _my_lambda(epoch):
            return epoch // 10

        def _optimizer_callable(params) -> Optimizer:
            return torch.optim.AdamW(params=params)

        def _lr_scheduler_callable(optimizer) -> _LRScheduler:
            return LambdaLR(optimizer, _my_lambda)

        hidden_dim = 10
        model = SimpleModel(hidden_dim)

        config_dict = {'train_batch_size': 1}

        client_optimizer = None
        client_scheduler = None

        if optimizer_type is None:
            config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
        elif optimizer_type is Optimizer:
            client_optimizer = torch.optim.Adam(model.parameters())
        else:
            client_optimizer = _optimizer_callable

        if scheduler_type is None:
            config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}}
        elif scheduler_type is _LRScheduler:
            if isinstance(client_optimizer, Optimizer):
                client_scheduler = LambdaLR(client_optimizer, _my_lambda)
            else:
                # Verify invalid combination is correctly handled: the scheduler is
                # bound to a throwaway optimizer, not one DeepSpeed will use.
                client_scheduler = LambdaLR(torch.optim.Adam(model.parameters()), _my_lambda)
        else:
            client_scheduler = _lr_scheduler_callable

        if isinstance(client_scheduler, _LRScheduler) and not isinstance(client_optimizer, Optimizer):
            with pytest.raises(AssertionError):
                deepspeed.initialize(config=config_dict,
                                     model=model,
                                     model_parameters=list(model.parameters()),
                                     optimizer=client_optimizer,
                                     lr_scheduler=client_scheduler)
        else:
            _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict,
                                                            model=model,
                                                            model_parameters=list(model.parameters()),
                                                            optimizer=client_optimizer,
                                                            lr_scheduler=client_scheduler)
            if client_scheduler is None:
                assert isinstance(ds_lr_scheduler, WarmupLR)
            elif isinstance(client_scheduler, _LRScheduler):
                # A client-supplied scheduler instance must be passed through unchanged.
                assert ds_lr_scheduler is client_scheduler
            else:
                assert isinstance(ds_lr_scheduler, LambdaLR)