"tutorials/rnn/rnn.py" did not exist on "405348050cff4d2d329669850da2a96a84946b08"
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
import pytest
from unit.common import DistributedTest
from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
from unit.util import required_torch_version


@pytest.mark.parametrize("ep_size", [2, 4])
@pytest.mark.parametrize("use_residual", [True, False])
class TestMoE(DistributedTest):
    """Smoke test: train ``SimpleMoEModel`` through a DeepSpeed engine with fp16 enabled.

    The test is parametrized over every combination of ``ep_size`` (2, 4) and
    ``use_residual`` (True, False). It makes no numerical assertions; it passes
    when the whole training loop runs without raising.
    """
    # NOTE(review): presumably the number of ranks DistributedTest launches -- confirm in unit.common.
    world_size = 4

    def test(self, ep_size, use_residual):
        """Run forward/backward/step over every batch of a small synthetic dataset.

        Args:
            ep_size: forwarded to ``SimpleMoEModel`` as ``ep_size``.
            use_residual: forwarded to ``SimpleMoEModel`` as ``use_residual``.
        """
        if not required_torch_version():
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
        hidden_dim = 16

        # E+D -- ep_size = 2
        # E only -- ep_size = 4
        model = SimpleMoEModel(hidden_dim, ep_size=ep_size, use_residual=use_residual)
        optimizer = torch.optim.AdamW(params=model.parameters())
        # Only the engine (first element of the returned 4-tuple) is needed here.
        # TODO: consider parametrizing dist_init_required over True/False.
        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)

        data_loader = sequence_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)

        # The batch index is not needed, so iterate the loader directly.
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()


@pytest.mark.parametrize("ep_size, use_residual", [(2, True), (2, False)])
class TestPRMoE(DistributedTest):
    """Smoke test: train ``SimplePRMoEModel`` through a DeepSpeed engine with fp16 enabled.

    The test is parametrized with ``ep_size=2`` and ``use_residual`` both True
    and False. It makes no numerical assertions; it passes when the whole
    training loop runs without raising.
    """
    # NOTE(review): presumably the number of ranks DistributedTest launches -- confirm in unit.common.
    world_size = 4

    def test(self, ep_size, use_residual):
        """Run forward/backward/step over every batch of a small synthetic dataset.

        Args:
            ep_size: forwarded to ``SimplePRMoEModel`` as ``ep_size``.
            use_residual: forwarded to ``SimplePRMoEModel`` as ``use_residual``.
        """
        if not required_torch_version():
            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")

        config_dict = {"train_batch_size": 8, "steps_per_print": 1, "fp16": {"enabled": True}}
        hidden_dim = 16

        # E+D -- ep_size = 2 (the only ep_size exercised by this class's parametrization)
        model = SimplePRMoEModel(hidden_dim, ep_size=ep_size, use_residual=use_residual)
        optimizer = torch.optim.AdamW(params=model.parameters())
        # Only the engine (first element of the returned 4-tuple) is needed here.
        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)

        data_loader = sequence_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)

        # The batch index is not needed, so iterate the loader directly.
        for batch in data_loader:
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()