from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn

import colossalai
from colossalai.context import MOE_CONTEXT
from colossalai.logging import get_dist_logger
from colossalai.nn import CheckpointModule
from colossalai.nn.layer import MoeModule
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from tests.test_zero.common import CONFIG


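# A small MoE model for exercising ZeroInitContext: a linear embedding feeds a
# checkpointed submodule wrapping a MoeModule (8 experts with a residual path),
# followed by a linear projection.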
class MoeModel(nn.Module):

    def __init__(self, checkpoint: bool = False):

        class TestSubModule(CheckpointModule):

            def __init__(self):
                super().__init__(checkpoint)
                expert_cls = nn.Linear
                expert_args_dict = dict(in_features=16, out_features=16)
                self.moe = MoeModule(dim_model=16,
                                     num_experts=8,
                                     use_residual=True,
                                     expert_cls=expert_cls,
                                     **expert_args_dict)
                self.proj = nn.Linear(16, 4)

            def _forward(self, x):
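                # the MoE layer returns the transformed tokens together with an
                # auxiliary loss, which forward() below feeds into MOE_CONTEXT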
                x, y = self.moe(x)
                x = self.proj(x)
                return x, y

        super().__init__()
        self.test_embed = nn.Linear(4, 16)
        self.test_transform = TestSubModule()

    def forward(self, x):
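        # clear the auxiliary loss accumulated during any previous forward pass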
        MOE_CONTEXT.reset_loss()

        x = self.test_embed(x)
        x, y = self.test_transform(x)

        MOE_CONTEXT.add_loss(y)
        return x


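# Build the model inside ZeroInitContext and check, for every parameter, that it
# is sharded, replicated, and placed as expected: expert, gate, and residual
# parameters stay unsharded, expert parameters are not replicated, and each data
# payload lives on the correct device.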
@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_moe_zero_init(init_device_type, shard_strategy_class):
    logger = get_dist_logger("test_moe_zero_init")

    if init_device_type == 'cuda':
        init_device = get_current_device()
    elif init_device_type == 'cpu':
        init_device = torch.device("cpu")
    else:
        raise NotImplementedError(f"Unknown device type: {init_device_type}")

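    # ZeroInitContext shards parameters as the model is built and accumulates
    # the total parameter count into model_numel_tensor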
    model_numel_tensor = torch.zeros(1, dtype=torch.int)
    with ZeroInitContext(target_device=init_device,
                         shard_strategy=shard_strategy_class(),
                         shard_param=True,
                         model_numel_tensor=model_numel_tensor):
        model = MoeModel(checkpoint=True)

    for name, param in model.named_parameters():
        assert hasattr(param, 'colo_attr')

        # parameters of the MoE experts, the gate, and the residual combiner should not be sharded
        if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
            assert not param.colo_attr.sharded_data_tensor.is_sharded, f'`{name}` should not be sharded'
        else:
            assert param.colo_attr.sharded_data_tensor.is_sharded, f'`{name}` should be sharded'

        # parameters of the MoE experts are not replicated across ranks
        if 'experts' in name:
            assert not param.colo_attr.is_replicated, f'`{name}` should not be replicated'
        else:
            assert param.colo_attr.is_replicated, f'`{name}` should be replicated'

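        # sharded payloads must remain on the requested init device, while
        # unsharded (expert) payloads always end up on CUDA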
        if param.colo_attr.param_is_sharded:
            assert param.colo_attr.data_payload.device.type == init_device.type, \
                f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
        else:
            assert param.colo_attr.data_payload.device.type == 'cuda'


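# Per-process entry point: launch the distributed environment, seed the MoE
# parallel context, then run the init checks.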
def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    run_moe_zero_init()


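# Spawn world_size worker processes on a free localhost port; rerun the test
# automatically if the chosen address is already in use.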
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
@rerun_if_address_is_in_use()
def test_moe_zero_init(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


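# allow running the test directly without pytest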
if __name__ == '__main__':
    test_moe_zero_init(world_size=2)