"tests/hooks/test_hooks.py" did not exist on "fb420664893956ecba4384fd8af9b375c7023d4d"
test_moe_checkpoint.py 3.73 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

from unit.common import DistributedTest
from unit.simple_model import *
from unit.util import required_torch_version

from unit.checkpoint.common import checkpoint_correctness_verification

import pytest


class TestMoECheckpoint(DistributedTest):
    world_size = 4

    @pytest.mark.parametrize("ep_size", [4])
    def test_checkpoint_moe(self, tmpdir, ep_size):
        if not required_torch_version():
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
        hidden_dim = 16

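        # Two identical MoE models: the verification helper is expected to train one engine,
        # save a checkpoint, reload it into the second engine and compare model/optimizer state
        # (see unit/checkpoint/common.py).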
        models = [SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size) for _ in range(2)]
        optimizers = [torch.optim.AdamW(params=model.parameters()) for model in models]
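        # base_optimizers supplies the client AdamW optimizers (one per model) to the helper;
        # seq_dataloader is assumed to feed both engines the same deterministic batches so the
        # post-load comparison is meaningful.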
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=True,
                                            load_lr_scheduler_states=False,
                                            fp16=config_dict["fp16"]["enabled"],
                                            empty_tag=True,
                                            base_optimizers=optimizers,
                                            seq_dataloader=True)

    @pytest.mark.parametrize("ep_size, load_optim_states", [(4, True), (4, False), (2, True), (2, False)])
    def test_checkpoint_moe_and_zero(self, tmpdir, ep_size, load_optim_states):
        if not required_torch_version():
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {
            "train_batch_size": 8,
            "steps_per_print": 1,
            "optimizer": {
                "type": 'Adam',
                "params": {
                    "lr": 0.00015,
                    "betas": [0.8, 0.999],
                    "eps": 1e-8,
                    "weight_decay": 3e-7
                }
            },
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "zero_optimization": {
                "stage": 2,
            }
        }
        hidden_dim = 16

        models = [SimpleMoEModel(hidden_dim=hidden_dim, num_experts=ep_size, ep_size=ep_size) for _ in range(2)]
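        # Expert (MoE) parameters are split into their own optimizer param groups below so that
        # ZeRO stage 2 can treat expert and non-expert state separately; that is why the plain
        # parameter list is passed through split_params_into_different_moe_groups_for_optimizer.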
        # param group must have a random unique name (for now)
        # TODO: clean-up this requirement, the unique name should not be required here
        param_groups = [{'params': [p for p in model.parameters()], 'name': 'random-unique-name'} for model in models]
        params = [split_params_into_different_moe_groups_for_optimizer(group) for group in param_groups]
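        # Each client AdamW is built over the split param groups and handed to the helper via
        # base_optimizers; load_optim_states controls whether optimizer state is restored (and,
        # presumably, compared) on load.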
        optimizers = [torch.optim.AdamW(params=param) for param in params]
        checkpoint_correctness_verification(config_dict,
                                            models=models,
                                            hidden_dim=hidden_dim,
                                            tmpdir=tmpdir,
                                            load_optimizer_states=load_optim_states,
                                            load_lr_scheduler_states=False,
                                            fp16=config_dict["fp16"]["enabled"],
                                            empty_tag=True,
                                            base_optimizers=optimizers,
                                            seq_dataloader=True)