# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import math
from collections import namedtuple
from typing import Dict, List, NamedTuple, Set, Tuple
import pytest
import deepspeed.comm as dist
import torch
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn.modules.container import ModuleList
from torch.nn.modules.loss import L1Loss
from torch.nn.parameter import Parameter

from unit.common import DistributedTest
from unit.simple_model import SimpleModel, random_dataloader

import deepspeed
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
from deepspeed.runtime.zero.utils import ZeRORuntimeException
from deepspeed.accelerator import get_accelerator


def run_unbalanced_gradients(model, data_loader):
    """Train ``model`` while alternating which half of its parameters gets gradients.

    Before each step, parameters at even or odd positions of ``model.parameters()``
    (depending on the parity of the 1-based step index) get ``requires_grad=False``,
    so successive backward passes touch different parameter subsets. All parameters
    are re-enabled after every step.

    Args:
        model: engine-like object that is callable as ``model(x, y)`` returning a
            loss and exposes ``backward(loss)``, ``step()`` and ``parameters()``.
        data_loader: iterable of batches; ``batch[0]`` and ``batch[1]`` are passed
            to ``model``.
    """

    def drop_some_gradients(model, step_idx):
        # parity of the step decides which half of the parameters stays trainable
        odd_iteration = step_idx % 2
        for i, p in enumerate(model.parameters()):
            p.requires_grad = (i % 2) == odd_iteration

    def enable_grads(model):
        for p in model.parameters():
            p.requires_grad = True

    for step_idx, batch in enumerate(data_loader, start=1):
        drop_some_gradients(model, step_idx)
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()
        enable_grads(model)


def dump_state_dict(model):
    """Print each named parameter of ``model`` with its data; only rank 0 prints."""
    if dist.get_rank() != 0:
        return
    print("state_dict:")
    for param_name, tensor in model.named_parameters():
        print(f"{param_name} {tensor.data}")


@pytest.mark.parametrize('zero_stage', [1, 2, 3])
class TestZeroUnbalancedGradients(DistributedTest):
    """Train with alternating frozen parameters under each ZeRO stage."""
    world_size = 1

    def test(self, zero_stage):
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": zero_stage
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            }
        }
        hidden_dim = 4

        model = SimpleModel(hidden_dim=hidden_dim)
        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)

        run_unbalanced_gradients(model, data_loader)


# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227
class TestZero3RepeatForwardLoop(DistributedTest):
    """Apply one ZeRO-3 partitioned layer several times within a single forward pass."""
    world_size = 1

    def test(self, zero_stage=3):
        # force all params to be partitioned by forcing threshold=0
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": zero_stage,
                "stage3_param_persistence_threshold": 0
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            }
        }
        hidden_dim = 4

        class AlbertLikeModel(torch.nn.Module):

            def __init__(self, hidden_dim):
                super().__init__()
                self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
                self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

            def forward(self, x, y):
                # run the same layer multiple times in a loop - to test a stack of forwards, followed by a stack of backwards
                hidden = x
                for _ in range(3):
                    hidden = hidden + self.linear(hidden)
                return self.cross_entropy_loss(hidden, y)

        model = AlbertLikeModel(hidden_dim=hidden_dim)
        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
        data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)

        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()


# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227
# also reproduces the https://github.com/microsoft/DeepSpeed/pull/1372
@pytest.mark.parametrize('zero_stage', [2, 3])
@pytest.mark.parametrize('freeze_params', [True, False])
class TestZeroToFP32(DistributedTest):
    """Check that ``load_state_dict_from_zero_checkpoint`` reproduces the trained weights."""
    world_size = 2

    @staticmethod
    def _train_and_verify_fp32_checkpoint(model, tmpdir, zero_stage, hidden_dim):
        """Train ``model`` briefly, save a ZeRO checkpoint and compare the fp32 reconstruction.

        Shared by both tests of this class.
        1. Runs a few training steps.
        2. Saves a checkpoint into ``tmpdir``.
        3. Collects the live parameters, gathering them first under ZeRO stage 3.
        4. On rank 0, asserts that the state dict rebuilt from the checkpoint matches them.
        """
        # Flush zero stage 3 cache
        model.empty_partition_cache()

        data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

        model.empty_partition_cache()
        model.save_checkpoint(tmpdir)

        # make sure all sides saved it
        dist.barrier()

        orig_state_dict = {}
        for name, param in model.module.named_parameters():
            if zero_stage == 3:
                with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
                    orig_state_dict[name] = param.detach().cpu()
            else:
                orig_state_dict[name] = param.detach().cpu()

        if zero_stage == 3:
            with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=None):
                fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
                fp32_state_dict = fp32_model.state_dict()
        else:
            fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
            fp32_state_dict = fp32_model.state_dict()

        if dist.get_rank() == 0:
            for name, orig_param in orig_state_dict.items():
                # float() workaround for torch<1.6
                assert torch.allclose(orig_param.float(), fp32_state_dict[name].float())

    def test_1_param_group(self, tmpdir, zero_stage, freeze_params):
        # force all params to be partitioned by forcing threshold=0
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": zero_stage,
                "stage3_param_persistence_threshold": 0
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            }
        }

        class MyModel(torch.nn.Module):

            def __init__(self, hidden_dim, n_layers, freeze_params):
                super().__init__()
                # to reproduce https://github.com/microsoft/DeepSpeed/pull/1372 it is important that
                # the number of total elements is uneven:
                # (1) 4 layers of 3*(3+1)=12 elements each, 48 in total
                self.ll = torch.nn.ModuleList(torch.nn.Linear(hidden_dim, hidden_dim) for i in range(n_layers))
                # (2) the following adds 4+1=5 elements
                self.classifier = torch.nn.Linear(4, 1)
                # total 48+5=53 (uneven as desired) elements
                self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
                if freeze_params:
                    self.ll[0].weight.requires_grad = False
                    self.ll[0].bias.requires_grad = False

            def forward(self, x, y):
                hidden = x
                for layer in self.ll:
                    hidden = layer(hidden)
                return self.cross_entropy_loss(hidden, y)

        hidden_dim = 3  # do not change

        world_size = dist.get_world_size()
        # we want at least 2x layers as there are gpus to trigger round_robin_fp16_groups reshuffle in zero2
        n_layers = world_size * 2
        model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers, freeze_params=freeze_params)

        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
        self._train_and_verify_fp32_checkpoint(model, tmpdir, zero_stage, hidden_dim)

    def test_2_param_groups(self, tmpdir, zero_stage, freeze_params):
        # weights and biases go into two separate optimizer param groups;
        # force all params to be partitioned by forcing threshold=0
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "zero_allow_untested_optimizer": 1,
            "zero_optimization": {
                "stage": zero_stage,
                "stage3_param_persistence_threshold": 0
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            }
        }

        class MyModel(torch.nn.Module):

            def __init__(self, hidden_dim, n_layers, freeze_params):
                super().__init__()
                self.ll = torch.nn.ModuleList(torch.nn.Linear(hidden_dim, hidden_dim) for i in range(n_layers))
                self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
                if freeze_params:
                    self.ll[0].weight.requires_grad = False
                    self.ll[0].bias.requires_grad = False

            def forward(self, x, y):
                hidden = x
                for layer in self.ll:
                    hidden = layer(hidden)
                return self.cross_entropy_loss(hidden, y)

        hidden_dim = 3

        world_size = dist.get_world_size()
        n_layers = world_size * 2
        model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers, freeze_params=freeze_params)

        optim_groups = [
            {
                "params": [layer.weight for layer in model.ll],
                "weight_decay": 0.01,
            },
            {
                "params": [layer.bias for layer in model.ll],
                "weight_decay": 0.0
            },
        ]
        optim = torch.optim.SGD(optim_groups, lr=0.1)

        model, _, _, _ = deepspeed.initialize(model=model,
                                              model_parameters=model.parameters(),
                                              optimizer=optim,
                                              config=config_dict)
        self._train_and_verify_fp32_checkpoint(model, tmpdir, zero_stage, hidden_dim)


@pytest.mark.parametrize("allgather_bucket_size", [1000, 1001])
class TestIncorectAllgatherBucketSize(DistributedTest):
    """ZeRO-2 must reject an ``allgather_bucket_size`` that breaks the NCCL start alignment."""
    world_size = 1

    def test(self, allgather_bucket_size, zero_stage=2):
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": zero_stage,
                "allgather_bucket_size": allgather_bucket_size
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            }
        }
        hidden_dim = 4

        model = SimpleModel(hidden_dim=hidden_dim)
        if allgather_bucket_size % 2 == 0:
            model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
        else:
            with pytest.raises(AssertionError) as assertinfo:
                deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
            # check the raised exception's own message: str(ExceptionInfo) is only a
            # (possibly truncated) repr and is not guaranteed to contain it verbatim
            assert "allgather_bucket_size must be a multiple of nccl_start_alignment_factor" in str(assertinfo.value)


class TestPartitionNcclAlignment(DistributedTest):
    """Every ZeRO-2 data-parallel partition must start at an NCCL-aligned address."""
    world_size = 4

    def test(self, zero_stage=2):
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": zero_stage
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            }
        }
        hidden_dim = 4

        model = SimpleModel(hidden_dim=hidden_dim)
        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())

        # get nccl all-gather send buffers alignment factor (expressed in elements)
        nccl_start_alignment_factor = model.optimizer.nccl_start_alignment_factor

        parallel_partitioned_bit16_groups = (model.optimizer.parallel_partitioned_bit16_groups
                                             if zero_stage == 2 else model.optimizer.parallel_partitioned_fp16_groups)
        for data_parallel_partitions in parallel_partitioned_bit16_groups:
            for partitioned_data in data_parallel_partitions:
                # verify the partition start is aligned to alignment_factor elements,
                # i.e. element_size * factor bytes (e.g. 4 bytes for fp16 with factor 2)
                alignment_bytes = partitioned_data.element_size() * nccl_start_alignment_factor
                assert partitioned_data.data_ptr() % alignment_bytes == 0


def _ds_initialize_for_param_partitioning_testing(model: Module, cfg: dict) -> DeepSpeedEngine:
    """Wrap ``model`` in a DeepSpeed engine configured by ``cfg`` and return only the engine."""
    ds_engine, _, _, _ = deepspeed.initialize(config=cfg, model=model, model_parameters=model.parameters())

    return ds_engine


def _assert_partition_status(model: Module, valid_statuses: Set[ZeroParamStatus]) -> None:
    """Assert that every parameter of ``model`` has a ``ds_status`` in ``valid_statuses``.

    The offending parameter's ``ds_summary()`` is used as the assertion message.
    """
    for _, param in model.named_parameters():
        assert param.ds_status in valid_statuses, param.ds_summary()


def _assert_fully_available(model: Module) -> None:
    """Assert that each parameter of ``model`` currently has status AVAILABLE."""
    for param in model.parameters():
        assert param.ds_status == ZeroParamStatus.AVAILABLE


class EltwiseMultiplicationModule(Module):
    """Multiply the input elementwise by a single weight parameter."""

    def __init__(self, weight: Parameter) -> None:
        super().__init__()
        self.weight = weight

    def forward(self, x: Tensor) -> Tensor:
        # the weight must already be gathered by the time forward runs
        _assert_fully_available(self)
        result = self.weight * x

        return result


class EltwiseMultiplicationTestNetwork_Dict(Module):
    """Test network: three chained elementwise-multiplication layers plus an unreduced L1 loss.

    ``forward`` asserts the ZeRO-3 partition status of the parameters before and
    after each layer. It returns the intermediate activations and the loss in a
    dict. Subclasses repackage that dict into other container types, and each
    class provides a matching ``to_dict``.
    """

    def __init__(
        self,
        weight1: Parameter,
        weight2: Parameter,
        weight3: Parameter,
    ) -> None:
        super().__init__()
        self.__layer1 = EltwiseMultiplicationModule(weight1)
        self.__layer2 = EltwiseMultiplicationModule(weight2)
        self.__layer3 = EltwiseMultiplicationModule(weight3)

        self.loss = L1Loss(reduction="none")

    def forward(self, x: Tensor, y: Tensor, use_module_trace: bool, param_prefetching: bool) -> Dict[str, Tensor]:
        # statuses allowed for the whole model at entry/exit: anything once a module
        # trace exists, otherwise everything must be partitioned
        model_expected_states = ({ZeroParamStatus.NOT_AVAILABLE, ZeroParamStatus.INFLIGHT, ZeroParamStatus.AVAILABLE}
                                 if use_module_trace else {ZeroParamStatus.NOT_AVAILABLE})
        _assert_partition_status(self, model_expected_states)

        pre_layer_expected_states = {
            ZeroParamStatus.INFLIGHT if param_prefetching else ZeroParamStatus.NOT_AVAILABLE,
            ZeroParamStatus.AVAILABLE,
        }

        post_layer_expected_states = {
            ZeroParamStatus.AVAILABLE if param_prefetching else ZeroParamStatus.NOT_AVAILABLE,
        }

        _assert_partition_status(self.__layer1, pre_layer_expected_states)
        hidden1 = self.__layer1(x)
        _assert_partition_status(self.__layer1, post_layer_expected_states)

        _assert_partition_status(self.__layer2, pre_layer_expected_states)
        hidden2 = self.__layer2(hidden1)
        _assert_partition_status(self.__layer2, post_layer_expected_states)

        _assert_partition_status(self.__layer3, pre_layer_expected_states)
        y_hat = self.__layer3(hidden2)
        _assert_partition_status(self.__layer3, post_layer_expected_states)

        loss = self.loss(y_hat, y)

        _assert_partition_status(self, model_expected_states)

        return {
            "hidden1": hidden1,
            "hidden2": hidden2,
            "y_hat": y_hat,
            "loss": loss,
        }

    @staticmethod
    def to_dict(outputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # outputs are already a dict; identity conversion
        return outputs


class EltwiseMultiplicationNamedTuple(NamedTuple):
    """Typed record of the eltwise test network's activations and loss."""
    hidden1: Tensor  # output of the first layer
    hidden2: Tensor  # output of the second layer
    y_hat: Tensor  # output of the third layer
    loss: Tensor  # loss tensor


class EltwiseMultiplicationTestNetwork_NamedTuple(EltwiseMultiplicationTestNetwork_Dict):
    """Variant of the eltwise test network whose forward returns a typed NamedTuple."""

    def forward(self, *args, **kwargs) -> EltwiseMultiplicationNamedTuple:
        as_dict = super().forward(*args, **kwargs)
        hidden1, hidden2, y_hat, loss = (as_dict[key] for key in ('hidden1', 'hidden2', 'y_hat', 'loss'))
        return EltwiseMultiplicationNamedTuple(hidden1=hidden1, hidden2=hidden2, y_hat=y_hat, loss=loss)

    @staticmethod
    def to_dict(outputs: EltwiseMultiplicationNamedTuple) -> Dict[str, Tensor]:
        # read the fields by attribute name, in declaration order
        return {key: getattr(outputs, key) for key in ("hidden1", "hidden2", "y_hat", "loss")}


# Untyped collections.namedtuple with the same fields as EltwiseMultiplicationNamedTuple.
EltwiseMultiplication_namedtuple = namedtuple('EltwiseMultiplication_namedtuple',
                                              ('hidden1', 'hidden2', 'y_hat', 'loss'))


class EltwiseMultiplicationTestNetwork_namedtuple(EltwiseMultiplicationTestNetwork_Dict):
    """Variant of the eltwise test network whose forward returns a ``collections.namedtuple``."""

    def forward(self, *args, **kwargs) -> EltwiseMultiplication_namedtuple:
        outputs_dicts = super().forward(*args, **kwargs)
        return EltwiseMultiplication_namedtuple(hidden1=outputs_dicts['hidden1'],
                                                hidden2=outputs_dicts['hidden2'],
                                                y_hat=outputs_dicts['y_hat'],
                                                loss=outputs_dicts['loss'])

    @staticmethod
    def to_dict(outputs: EltwiseMultiplication_namedtuple) -> Dict[str, Tensor]:
        # annotation matches what forward actually returns (the untyped namedtuple)
        return {
            "hidden1": outputs.hidden1,
            "hidden2": outputs.hidden2,
            "y_hat": outputs.y_hat,
            "loss": outputs.loss,
        }


class EltwiseMultiplicationTestNetwork_Tuple(EltwiseMultiplicationTestNetwork_Dict):
    """Variant of the eltwise test network whose forward returns a plain 4-tuple."""

    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        as_dict = super().forward(*args, **kwargs)
        return tuple(as_dict[key] for key in ('hidden1', 'hidden2', 'y_hat', 'loss'))

    @staticmethod
    def to_dict(outputs: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Dict[str, Tensor]:
        # position i of the tuple maps to the i-th key
        return {key: outputs[pos] for pos, key in enumerate(("hidden1", "hidden2", "y_hat", "loss"))}


class EltwiseMultiplicationTestNetwork_List(EltwiseMultiplicationTestNetwork_Dict):
    """Variant of the eltwise test network whose forward returns a 4-element list."""

    def forward(self, *args, **kwargs) -> List[Tensor]:
        as_dict = super().forward(*args, **kwargs)
        return [as_dict[key] for key in ('hidden1', 'hidden2', 'y_hat', 'loss')]

    @staticmethod
    def to_dict(outputs: List[Tensor]) -> Dict[str, Tensor]:
        keys = ("hidden1", "hidden2", "y_hat", "loss")
        result = {}
        for pos, key in enumerate(keys):
            result[key] = outputs[pos]
        return result


@pytest.mark.parametrize("param_persistence_threshold", [0, 10])
@pytest.mark.parametrize("fp16_enabled", [True, False])
@pytest.mark.parametrize("contiguous_gradients", [True, False])
@pytest.mark.parametrize("offload_optimizer", [True, False])
@pytest.mark.parametrize("zero_grad", [True, False])
@pytest.mark.parametrize("prefetching", [True, False])
@pytest.mark.parametrize("model_class", [
    EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple,
    EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple,
    EltwiseMultiplicationTestNetwork_List
])
class TestZero3ParamPartitioningBase(DistributedTest):
    """End-to-end ZeRO-3 partitioning check on the three-layer eltwise test network.

    Over several iterations this verifies three things: the forward activations,
    each rank's fp32 gradient partitions, and that all parameters are partitioned
    after an optimizer step. It runs for every combination of the parametrized
    options and output container types.
    """
    world_size = 2

    def test(
        self,
        param_persistence_threshold: int,
        fp16_enabled: bool,
        contiguous_gradients: bool,
        offload_optimizer: bool,
        zero_grad: bool,
        prefetching: bool,
        model_class: EltwiseMultiplicationTestNetwork_Dict,
    ) -> None:
        if offload_optimizer and not contiguous_gradients:
            # unsupported combination for this test: report it as skipped instead of
            # returning early and being counted as a pass
            pytest.skip("offload_optimizer is only tested with contiguous_gradients")

        m = 3
        n = 5
        weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)]
        model = model_class(*weights)
        prefetch_bucket_size = sum([p.numel() for p in model.parameters(recurse=True)])
        cfg = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": {
                "stage": 3,
                "stage3_max_reuse_distance": 0,
                "stage3_param_persistence_threshold": param_persistence_threshold,
                "contiguous_gradients": contiguous_gradients,
                "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.
                }
            },
            "fp16": {
                "enabled": fp16_enabled,
                "loss_scale": 1.,
            }
        }

        if offload_optimizer:
            cfg["zero_optimization"]["offload_optimizer"] = {
                "device": "cpu",
                "pin_memory": True,
            }

        ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg)
        # each rank fills its local partition of weight i with (i + 1) * (rank + 1)
        for i, weight in enumerate(weights):
            weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, (i + 1) * (1 + dist.get_rank()))

        activation_dtype = torch.float16 if fp16_enabled else torch.float32
        device_name = get_accelerator().device_name()

        def create_tensor(vals, dtype: torch.dtype = None) -> Tensor:
            return torch.as_tensor(vals, dtype=dtype or activation_dtype, device=ds_engine.device)

        expected_hidden1 = create_tensor([
            [1, 1, 1, 1, 1],
            [1, 1, 1, 2, 2],
            [2, 2, 2, 2, 2],
        ])
        expected_hidden2 = create_tensor([
            [2, 2, 2, 2, 2],
            [2, 2, 2, 8, 8],
            [8, 8, 8, 8, 8],
        ])
        expected_yhat = create_tensor([[6, 6, 6, 6, 6], [6, 6, 6, 48, 48], [48, 48, 48, 48, 48]])
        expected_loss = create_tensor([
            [5, 5, 5, 5, 5],
            [5, 5, 5, 47, 47],
            [47, 47, 47, 47, 47],
        ])

        for train_iter in range(3):
            activations = ds_engine(
                x=torch.ones((m, n), dtype=activation_dtype, device=ds_engine.device),
                y=torch.ones((m, n), dtype=activation_dtype, device=ds_engine.device),
                use_module_trace=train_iter > 0,
                param_prefetching=prefetching and train_iter > 0,
            )
            # for ease in testing convert outputs to dict.
            activations = model_class.to_dict(activations)
            assert torch.allclose(activations["hidden1"], expected_hidden1)
            assert torch.allclose(activations["hidden2"], expected_hidden2)
            assert torch.allclose(activations["y_hat"], expected_yhat)
            assert torch.allclose(activations["loss"], expected_loss)

            ds_engine.backward(activations["loss"].sum())

            # check the gradients
            grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions()
            assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}"
            assert set(grad_partitions[0].keys()) == {0, 1, 2}
            dloss_wrt_layer1 = grad_partitions[0][0]
            dloss_wrt_layer2 = grad_partitions[0][1]
            dloss_wrt_layer3 = grad_partitions[0][2]

            assert dloss_wrt_layer1.dtype == torch.float
            assert dloss_wrt_layer2.dtype == torch.float
            assert dloss_wrt_layer3.dtype == torch.float

            # layer1 = [..., 1, 2, ...]
            # layer2 = [..., 2, 4, ...]
            # layer3 = [..., 3, 6, ...]
            # dloss_wrt_layer3 = hidden2
            # dloss_wrt_layer2 = layer3 * hidden1
            # dloss_wrt_layer1 = layer3 * layer2 * x

            # without zero_grad the gradients accumulate across iterations
            grad_multiplier = 1 if zero_grad else (train_iter + 1)
            if dist.get_rank() == 0:
                assert torch.allclose(dloss_wrt_layer3.to(device_name),
                                      grad_multiplier * create_tensor([2] * 8, torch.float))
                assert torch.allclose(dloss_wrt_layer2.to(device_name),
                                      grad_multiplier * create_tensor([3 * 1] * 8, torch.float))
                assert torch.allclose(dloss_wrt_layer1.to(device_name),
                                      grad_multiplier * create_tensor([3 * 2 * 1] * 8, torch.float))
            elif dist.get_rank() == 1:
                # parameters dont split evenly across ranks so rank 1 has a zero-padded
                # partition
                assert torch.allclose(dloss_wrt_layer3.to(device_name),
                                      grad_multiplier * create_tensor(([8] * 7) + [0], torch.float))
                assert torch.allclose(dloss_wrt_layer2.to(device_name),
                                      grad_multiplier * create_tensor(([6 * 2] * 7) + [0], torch.float))
                assert torch.allclose(dloss_wrt_layer1.to(device_name),
                                      grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0], torch.float))
            else:
                raise RuntimeError("test has world size of two")

            if zero_grad:
                ds_engine.optimizer.zero_grad()

        # TODO. add testing for this - for now we just call it to make sure it
        # doesn't throw
        ds_engine.optimizer.step()
        # taking an optimizer step invalidates all parameters, make sure everything
        # has been partitioned afterwards
        _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})
        assert not math.isclose(ds_engine.optimizer._global_grad_norm, 0.0)


@pytest.mark.parametrize("init_context_manager", [True, False])
class TestZero3ParamPartitioningLargeParam(DistributedTest):
    """ZeRO-3 partitioning of one large parameter whose elements span every rank's partition."""
    world_size = 4

    def test(self, init_context_manager: bool, param_sz: int = 8100) -> None:
        """Check forward activations and accumulated gradients of a single elementwise-scale parameter.

        Rank 0 writes the value ``r`` into the ``r``-th contiguous chunk of the parameter, so an
        all-ones input must reproduce that chunk pattern in the activation on every rank.
        """

        class LargeParamModel(Module):
            """Multiplies its input elementwise by one ``param_sz``-long parameter."""

            def __init__(self):
                super().__init__()
                self.param = Parameter(torch.zeros((param_sz, ), dtype=torch.float32))

                # Only rank 0 writes the chunk pattern; other ranks keep zeros, so the
                # assertions in the test only hold if the weights are taken from rank 0.
                if dist.get_rank() != 0:
                    return
                n_ranks = dist.get_world_size()
                chunk = math.ceil(self.param.numel() / n_ranks)
                with torch.no_grad():
                    for owner in range(n_ranks):
                        begin = owner * chunk
                        self.param[begin:begin + chunk].fill_(owner)

            def forward(self, x: Tensor) -> Tensor:
                return x * self.param

        ds_config = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": {
                "stage": 3,
                "stage3_max_reuse_distance": 0,
                "contiguous_gradients": True,
                "overlap_comm": True,
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 1.,
            },
        }

        with deepspeed.zero.Init(mem_efficient_linear=False, enabled=init_context_manager):
            model = LargeParamModel()
        ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_config)

        chunk = math.ceil(param_sz / self.world_size)
        chunk_starts = range(0, param_sz, chunk)
        # several iterations so that parameter prefetching is exercised after the first pass
        for train_iter in range(3):
            ones = torch.ones(param_sz, dtype=torch.float16, device=ds_engine.device)
            activation: Tensor = ds_engine(ones)

            for expected_value, begin in enumerate(chunk_starts):
                piece = activation[begin:begin + chunk]
                assert torch.allclose(piece, torch.full_like(piece, expected_value))

            ds_engine.backward(activation.sum())
            ds_engine.allreduce_gradients()

            avgd_gradients = ds_engine.optimizer.averaged_gradients
            assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group"
            (weight_gradient, ) = avgd_gradients[0]
            # d(sum(x * w))/dw == x == 1; nothing zeroes gradients between iterations,
            # so the expected value grows by one per iteration.
            assert torch.allclose(weight_gradient, torch.full_like(weight_gradient, train_iter + 1))


@pytest.mark.parametrize("param_sz", [100, 1_000, 10_000])
@pytest.mark.parametrize("n_layers", [100, 1_000])
@pytest.mark.parametrize("init_context_manager", [True, False])
class TestZero3ParamPartitioningManyParams(DistributedTest):
    """ZeRO-3 partitioning of a deep stack of small elementwise-multiplication layers."""
    world_size = 4

    def test(self, param_sz: int, n_layers: int, init_context_manager: bool) -> None:
        """Run forward/backward through ``n_layers`` chained layers and validate every activation.

        Layer ``k`` holds the value ``2 * k * p`` in the ``p``-th contiguous chunk of its weight.
        """

        class ManyParamModel(Module):
            """Chain of ``n_layers`` elementwise multiplications; forward returns every activation."""

            def __init__(self) -> None:
                super().__init__()

                self.modulelist = ModuleList(
                    EltwiseMultiplicationModule(weight=Parameter(torch.empty((param_sz, ), dtype=torch.float32)))
                    for _ in range(n_layers))

                for layer_num, module in enumerate(self.modulelist):
                    # modifier_rank=0: rank 0's writes are the ones kept after the gather
                    # (per DeepSpeed GatheredParameters semantics)
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        param: Parameter = module.weight
                        partition_sz = math.ceil(param.numel() / dist.get_world_size())
                        offset = 0
                        for rank in range(dist.get_world_size()):
                            with torch.no_grad():
                                param[offset:offset + partition_sz].fill_(2 * layer_num * rank)
                            offset += partition_sz

            def forward(self, x: Tensor) -> Tensor:
                activations = []

                for module in self.modulelist:
                    x = module(x)
                    activations.append(x)

                return activations

        ds_cfg = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": {
                "stage": 3,
                "stage3_max_reuse_distance": 0,
                "contiguous_gradients": True,
                "overlap_comm": True,
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 1.,
            }
        }

        with deepspeed.zero.Init(config=ds_cfg, mem_efficient_linear=False, enabled=init_context_manager):
            model = ManyParamModel()

        ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg)

        partition_sz = math.ceil(param_sz / self.world_size)
        # Chunk index of every element, matching the weight layout written in
        # ManyParamModel.__init__ (chunk ``p`` of layer ``k`` holds ``2 * k * p``).
        partition_idx = torch.empty(param_sz, dtype=torch.float32, device=ds_engine.device)
        for chunk_idx, start_idx in enumerate(range(0, param_sz, partition_sz)):
            partition_idx[start_idx:start_idx + partition_sz] = chunk_idx

        for _ in range(3):  # test multiple iterations to cover prefetching
            activations: List[Tensor] = ds_engine(
                torch.ones((param_sz, ), dtype=torch.float16, device=ds_engine.device))
            assert len(activations) == n_layers

            # Replay the network on the all-ones input: activation k is the running
            # product of the per-chunk layer weights. Layer 0's weight is all zeros,
            # so every activation is expected to be zero.
            expected_activations = torch.ones(param_sz, dtype=torch.float16, device=ds_engine.device)
            for layer_num, activation in enumerate(activations):
                layer_weight = (2 * layer_num * partition_idx).to(torch.float16)
                expected_activations = expected_activations * layer_weight
                assert torch.allclose(activation, expected_activations)

            # TODO. finish writing this test
            ds_engine.backward(activations[-1].sum())

            avgd_gradients = ds_engine.optimizer.averaged_gradients
            assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group"
            # TODO: validate the per-layer gradient entries once expected values are defined.
            weight_gradients: List[Tensor] = avgd_gradients[0]


class TestZero3InitForParentWeightInitialization(DistributedTest):
    """A parent module that initializes a child's weights inside ``zero.Init`` must be honored."""
    world_size = 4

    def test(self):

        class ModelWhereParentInitializesChildWeights(Module):
            """Owns one ``Linear(12, 1)`` whose weight is filled by the parent via ``apply``."""

            def __init__(self) -> None:
                super().__init__()
                self.linear = Linear(12, 1)
                self.apply(self._fill_linear_weight)

            def _fill_linear_weight(self, submodule):
                # Each rank writes a different value; the assertions below expect
                # rank 0's value (1) on every rank's shard.
                if not isinstance(submodule, Linear):
                    return
                with torch.no_grad():
                    submodule.weight.fill_(1 + dist.get_rank())

        zero_cfg = {
            "stage": 3,
            "stage3_max_reuse_distance": 0,
            "contiguous_gradients": True,
            "overlap_comm": True,
        }
        ds_cfg = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": zero_cfg,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 1.,
            },
        }

        with deepspeed.zero.Init(config=ds_cfg, mem_efficient_linear=False, enabled=True):
            model = ModelWhereParentInitializesChildWeights()

        shard = model.linear.weight.ds_tensor
        # the 12 weight elements are split evenly-ish across the ranks
        assert shard.numel() == math.ceil(12 / self.world_size)
        assert torch.allclose(shard, torch.full_like(shard, 1))


@pytest.mark.skip("not working")
@pytest.mark.parametrize("param_persistence_threshold", [0, 10])
@pytest.mark.parametrize("contiguous_gradients", [True, False])
@pytest.mark.parametrize("offload_optimizer", [True, False])
@pytest.mark.parametrize("zero_grad", [True, False])
@pytest.mark.parametrize("prefetching", [True, False])
@pytest.mark.parametrize("model_class", [
    EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple,
    EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple,
    EltwiseMultiplicationTestNetwork_List
])
class TestZero3ParamPartitioningBaseBF16(DistributedTest):
    """bf16 variant of the ZeRO-3 parameter-partitioning test on a 3-layer elementwise network."""
    world_size = 2

    def test(self, param_persistence_threshold: int, contiguous_gradients: bool, offload_optimizer: bool,
             zero_grad: bool, prefetching: bool, model_class: EltwiseMultiplicationTestNetwork_Dict) -> None:
        """Check activations, per-rank gradient partitions and partition status over 3 iterations.

        Layer ``i`` (1-based) is set to ``i * (1 + rank)`` on each rank's shard. The 15 elements
        of each (3, 5) weight split 8/7 across the two ranks, with rank 1 zero-padded.
        """
        # this parameter combination is not exercised
        if offload_optimizer and not contiguous_gradients:
            return

        m, n = 3, 5
        weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)]
        model = model_class(*weights)

        total_numel = sum([p.numel() for p in model.parameters(recurse=True)])
        zero_section = {
            "stage": 3,
            "stage3_max_reuse_distance": 0,
            "stage3_param_persistence_threshold": param_persistence_threshold,
            "contiguous_gradients": contiguous_gradients,
            "stage3_prefetch_bucket_size": total_numel if prefetching else 0,
        }
        if offload_optimizer:
            zero_section["offload_optimizer"] = {
                "device": "cpu",
                "pin_memory": True,
            }
        cfg = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": zero_section,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.
                }
            },
            "bf16": {
                "enabled": True,
                "loss_scale": 1.,
            },
        }

        ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg)
        rank_scale = 1 + dist.get_rank()
        for layer_no, weight in enumerate(weights, start=1):
            weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, layer_no * rank_scale)

        def as_bf16(vals):
            return torch.as_tensor(vals, dtype=torch.bfloat16, device=ds_engine.device)

        # expected outputs for x = y = ones, checked in this order every iteration
        expected_outputs = {
            "hidden1": as_bf16([[1, 1, 1, 1, 1], [1, 1, 1, 2, 2], [2, 2, 2, 2, 2]]),
            "hidden2": as_bf16([[2, 2, 2, 2, 2], [2, 2, 2, 8, 8], [8, 8, 8, 8, 8]]),
            "y_hat": as_bf16([[6, 6, 6, 6, 6], [6, 6, 6, 48, 48], [48, 48, 48, 48, 48]]),
            "loss": as_bf16([[5, 5, 5, 5, 5], [5, 5, 5, 47, 47], [47, 47, 47, 47, 47]]),
        }

        for train_iter in range(3):
            _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})
            traced = train_iter > 0
            outputs = ds_engine(
                x=torch.ones((m, n), dtype=torch.bfloat16, device=ds_engine.device),
                y=torch.ones((m, n), dtype=torch.bfloat16, device=ds_engine.device),
                use_module_trace=traced,
                param_prefetching=prefetching and traced,
            )
            # normalize the model-specific output container to a dict
            activations = model_class.to_dict(outputs)
            for key, expected in expected_outputs.items():
                assert torch.allclose(activations[key], expected)

            ds_engine.backward(activations["loss"].sum())
            _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})

            # check the gradients
            grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions()
            assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}"
            assert set(grad_partitions[0].keys()) == {0, 1, 2}
            dloss_wrt_layer1, dloss_wrt_layer2, dloss_wrt_layer3 = (grad_partitions[0][k] for k in (0, 1, 2))

            # layer1 = [..., 1, 2, ...]; layer2 = [..., 2, 4, ...]; layer3 = [..., 3, 6, ...]
            # dloss_wrt_layer3 = hidden2
            # dloss_wrt_layer2 = layer3 * hidden1
            # dloss_wrt_layer1 = layer3 * layer2 * x
            grad_dtype = torch.float32 if offload_optimizer else torch.bfloat16
            # without zero_grad the gradients accumulate across iterations
            scale = 1 if zero_grad else (train_iter + 1)

            rank = dist.get_rank()
            if rank == 0:
                expected_grad_vals = ([2] * 8, [3 * 1] * 8, [3 * 2 * 1] * 8)
            elif rank == 1:
                # parameters don't split evenly across ranks, so rank 1 has a zero-padded partition
                expected_grad_vals = (([8] * 7) + [0], ([6 * 2] * 7) + [0], ([6 * 4 * 1] * 7) + [0])
            else:
                raise RuntimeError("test has world size of two")

            device_name = get_accelerator().device_name()
            for actual, vals in zip((dloss_wrt_layer3, dloss_wrt_layer2, dloss_wrt_layer1), expected_grad_vals):
                assert torch.allclose(actual.to(device_name), scale * as_bf16(vals).to(grad_dtype))

            if zero_grad:
                ds_engine.optimizer.zero_grad()

        # TODO. add testing for this - for now we just call it to make sure it
        # doesn't throw
        ds_engine.optimizer.step()
        _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE})


class TestZeroOffloadStage1(DistributedTest):
    """ZeRO stage 1 with CPU optimizer offload trains without error under gradient accumulation."""
    world_size = 2

    def test(self):
        zero_section = {
            "stage": 1,
            "offload_optimizer": {
                "device": "cpu"
            },
        }
        config_dict = {
            "train_batch_size": 4,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-4
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": zero_section,
        }
        hidden_dim = 10

        base_model = SimpleModel(hidden_dim)
        engine, _, _, _ = deepspeed.initialize(model=base_model,
                                               model_parameters=base_model.parameters(),
                                               config=config_dict)
        data_loader = random_dataloader(model=engine, total_samples=50, hidden_dim=hidden_dim, device=engine.device)

        dist.barrier()
        for batch in data_loader:
            loss = engine(batch[0], batch[1])
            engine.backward(loss)
            engine.step()


@pytest.mark.parametrize('return_type', [tuple, list, dict])
class TestZero3DictFwd(DistributedTest):
    """A ZeRO-3 engine must accept forward outputs packed as a tuple, list or dict."""
    world_size = 1

    def test(self, return_type):
        config_dict = {
            "train_batch_size": 4,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-4
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 3
            },
        }
        hidden_dim = 10

        class MyModel(torch.nn.Module):
            """One linear layer plus cross-entropy; forward packs (logits, loss) as ``return_type``."""

            def __init__(self, hidden_dim):
                super(MyModel, self).__init__()
                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
                self.cel = torch.nn.CrossEntropyLoss()

            def forward(self, x, y):
                logits = self.l1(x)
                loss = self.cel(logits, y)
                if return_type == list:
                    return [logits, loss]
                if return_type == tuple:
                    return (logits, loss)
                if return_type == dict:
                    # presumably the non-tensor entries check the engine tolerates them
                    return {'a': logits, 'loss': loss, 'b': 1, 'c': None}
                raise NotImplementedError

        with deepspeed.zero.Init():
            model = MyModel(hidden_dim)

        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)

        dist.barrier()
        for batch in data_loader:
            outputs = model(batch[0], batch[1])
            loss = outputs['loss'] if return_type == dict else outputs[1]
            model.backward(loss)
            model.step()


@pytest.mark.parametrize('zero_stage', [1, 2, 3])
class TestZeroAdamOptimizerStepCount(DistributedTest):
    """Every Adam state tracked by the ZeRO optimizer must report the same step count."""
    world_size = 1

    def test(self, zero_stage):
        # force all params to be partitioned by forcing threshold=0
        zero_section = {
            "stage": zero_stage,
            "stage3_param_persistence_threshold": 0,
            "sub_group_size": 4,
        }
        config_dict = {
            "train_micro_batch_size_per_gpu": 2,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "zero_optimization": zero_section,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
        }
        hidden_dim = 4

        base_model = SimpleModel(hidden_dim=hidden_dim, nlayers=12)
        model, optimizer, _, _ = deepspeed.initialize(config=config_dict,
                                                      model=base_model,
                                                      model_parameters=base_model.parameters())
        data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device)

        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

            inner_state = optimizer.optimizer.state
            if zero_stage == 3:
                # one flat fp32 partition per sub-group; read the Adam state keyed by it
                step_counts = [
                    inner_state[optimizer.fp32_partitioned_groups_flat[sub_group_id]]['step']
                    for sub_group_id, _ in enumerate(optimizer.fp16_groups)
                ]
            elif zero_stage in (1, 2):
                step_counts = [
                    inner_state[param]['step'] for param_group in optimizer.optimizer.param_groups
                    for param in param_group['params']
                ]
            else:
                step_counts = []
            assert all(step == step_counts[0] for step in step_counts)


class TestZeroFrozenWeights(DistributedTest):
    """ZeRO-3 training must work when some parameters have ``requires_grad=False``."""
    world_size = 1

    def test(self):
        config_dict = {
            "train_batch_size": 4,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-4
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 3
            },
        }
        hidden_dim = 10

        class MyModel(torch.nn.Module):
            """Two linear layers with ReLU between them; the second layer is frozen."""

            def __init__(self, hidden_dim):
                super(MyModel, self).__init__()
                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
                self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
                self.act = torch.nn.ReLU()
                self.cel = torch.nn.CrossEntropyLoss()

                # freeze the second fully-connected layer (weight, then bias)
                for frozen in (self.l2.weight, self.l2.bias):
                    frozen.requires_grad = False

            def forward(self, x, y):
                hidden = self.act(self.l1(x))
                logits = self.l2(hidden)
                return logits, self.cel(logits, y)

        with deepspeed.zero.Init(config_dict_or_path=config_dict):
            model = MyModel(hidden_dim)

        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
        data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)

        dist.barrier()
        for batch in data_loader:
            loss = model(batch[0], batch[1])[1]
            model.backward(loss)
            model.step()
1358


@pytest.mark.parametrize('force_ds_optim', [True, False])
class TestZeroOffloadOptim(DistributedTest):
    """A client torch optimizer with ZeRO offload is rejected only when the DS CPU optimizer is forced."""
    world_size = 1

    def test(self, force_ds_optim):
        config_dict = {
            "train_batch_size": 4,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 1,
                "offload_optimizer": {
                    "device": "cpu"
                }
            },
            "zero_force_ds_cpu_optimizer": force_ds_optim,
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim)
        optimizer = torch.optim.Adam(model.parameters())

        def build_engine():
            engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=config_dict)
            return engine

        if not force_ds_optim:
            build_engine()
            return

        with pytest.raises(ZeRORuntimeException):
            build_engine()


@pytest.mark.parametrize('training', [True, False])
class TestZeroPartitionCache(DistributedTest):
    """Persistent ZeRO-3 parameters are released by ``empty_partition_cache``."""
    world_size = 1

    def test_training_partition_cache(self, training):
        hidden_dim = 10
        config_dict = {
            "train_batch_size": 2,
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "zero_optimization": {
                "stage": 3,
                "stage3_param_persistence_threshold": hidden_dim
            },
        }
        if training:
            config_dict["optimizer"] = {"type": "Adam"}

        with deepspeed.zero.Init(config_dict_or_path=config_dict):
            model = SimpleModel(hidden_dim, empty_grad=False)

        model, _, _, _ = deepspeed.initialize(model=model, config=config_dict)

        data_loader = random_dataloader(model=model,
                                        total_samples=6,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.half)

        for batch in data_loader:
            loss = model(batch[0], batch[1])
            if not training:
                continue
            model.backward(loss)
            model.step()

        # with threshold == hidden_dim every parameter is expected to be persistent
        persist_numel = sum(p.numel() for p in model.parameters() if p.ds_persist)
        all_numel = sum(p.numel() for p in model.parameters())
        assert persist_numel >= all_numel

        model.empty_partition_cache()
        # after emptying the cache, every parameter is expected to report zero elements
        assert sum(p.numel() for p in model.parameters()) == 0