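"""Unit tests for DeepSpeed Mixture-of-Experts (MoE) and PR-MoE models.

Each test builds a small model, wraps it with deepspeed.initialize under an
fp16 config, and runs a short training loop across different expert-parallel
sizes (ep_size), with and without residual experts.
"""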
import math
from deepspeed.utils import groups
import torch
import torch.distributed as dist
import deepspeed
import argparse
import pytest
import json
import os
from deepspeed.ops.adam import FusedAdam
from .common import distributed_test
from deepspeed.ops.op_builder import CPUAdamBuilder
from .simple_model import SimpleModel, SimplePRMoEModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader
from .util import required_torch_version

try:
    from apex import amp
    _amp_available = True
except ImportError:
    _amp_available = False
amp_available = pytest.mark.skipif(not _amp_available, reason="apex/amp is not installed")


@pytest.mark.parametrize("ep_size, use_residual",
                         [(2, True), (2, False), (4, True), (4, False)])
def test_moe(tmpdir, ep_size, use_residual):
    if not required_torch_version():
        pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

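    # Minimal DeepSpeed config: global train batch size of 8 with fp16 enabled.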
    config_dict = {
        "train_batch_size": 8,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 16

    @distributed_test(world_size=[4])
    def _test_moe(args, hidden_dim, ep_size, use_residual):
        # E+D -- ep_size = 2
        # E only -- ep_size = 4
        model = SimpleMoEModel(hidden_dim, ep_size=ep_size, use_residual=use_residual)
        optimizer = torch.optim.AdamW(params=model.parameters())
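        # Wrap the model and optimizer in a DeepSpeed engine; the @distributed_test
        # decorator has already set up the process group, hence dist_init_required=False.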
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)
        #dist_init_required=False -- parameterize to True/False?

        data_loader = sequence_dataloader(model=model,
                                          total_samples=50,
                                          hidden_dim=hidden_dim,
                                          device=model.device)

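        # Short fp16 training run: forward, backward, and optimizer step through the engine.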
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_moe(args=args,
              hidden_dim=hidden_dim,
              ep_size=ep_size,
              use_residual=use_residual)


@pytest.mark.parametrize("ep_size, use_residual", [(2, True), (2, False)])
def test_pr_moe(tmpdir, ep_size, use_residual):
    if not required_torch_version():
        pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

    config_dict = {
        "train_batch_size": 8,
        "steps_per_print": 1,
        "fp16": {
            "enabled": True
        }
    }
    args = args_from_dict(tmpdir, config_dict)
    hidden_dim = 16

    @distributed_test(world_size=[4])
    def _test_moe(args, hidden_dim, ep_size, use_residual):
        # E+D -- ep_size = 2
        # E only -- ep_size = 4

        model = SimplePRMoEModel(hidden_dim, ep_size=ep_size, use_residual=use_residual)
        optimizer = torch.optim.AdamW(params=model.parameters())
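        # Same engine setup as test_moe, but exercising the PR-MoE (Pyramid-Residual MoE) model.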
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)

        data_loader = sequence_dataloader(model=model,
                                          total_samples=50,
                                          hidden_dim=hidden_dim,
                                          device=model.device)

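        # Short training run to confirm forward/backward/step work for PR-MoE.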
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_moe(args=args,
              hidden_dim=hidden_dim,
              ep_size=ep_size,
              use_residual=use_residual)