test_moe_layer.py 1.11 KB
Newer Older
Xuanlei Zhao's avatar
Xuanlei Zhao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import copy

import torch
from colossal_moe.models.mixtral_layer import MixtralSparseMLP
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock


class Config:
    """Bare-bones configuration object used to build the MoE block in this test.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name.
    """

    def __init__(self, hidden_size, intermediate_size, num_local_experts, num_experts_per_tok, hidden_act):
        # Size-related settings.
        self.hidden_size, self.intermediate_size = hidden_size, intermediate_size
        # Expert-related settings.
        self.num_local_experts, self.num_experts_per_tok = num_local_experts, num_experts_per_tok
        # Activation setting (the test in this file passes "silu").
        self.hidden_act = hidden_act


def test_moe_layer():
    """Compare a native Mixtral MoE block with its MixtralSparseMLP conversion.

    Builds a small ``MixtralSparseMoeBlock`` on the GPU, converts a deep copy
    of it via ``MixtralSparseMLP.from_native_module``, runs both modules on
    the same random tensor and asserts that element ``[0]`` of each result is
    numerically close (``torch.allclose`` with its default tolerances).
    """
    cfg = Config(
        hidden_size=4,
        intermediate_size=8,
        num_local_experts=32,
        num_experts_per_tok=2,
        hidden_act="silu",
    )

    # Reference module, moved to the GPU (a CUDA device is required).
    mistral_moe = MixtralSparseMoeBlock(cfg).cuda()

    # The conversion is given a deep copy rather than the reference module
    # itself, presumably so the conversion cannot modify the reference.
    native_clone = copy.deepcopy(mistral_moe)
    colossal_moe = MixtralSparseMLP.from_native_module(native_clone).cuda()

    # One random 2 x 8 x 4 input shared by both modules; the last dimension
    # equals the configured hidden_size.
    data = torch.randn(2, 8, 4).cuda()

    # Only element [0] of each module's result is compared.
    mistral_result = mistral_moe(data)
    mistral_output = mistral_result[0]
    colossal_result = colossal_moe(data)
    colossal_output = colossal_result[0]

    outputs_match = torch.allclose(mistral_output, colossal_output)
    assert outputs_match, f"mistral_output: {mistral_output}\ncolossal_output: {colossal_output}"


# Allow running this file directly as a script, without a test runner.
if __name__ == "__main__":
    test_moe_layer()