# test_fused_moe.py: tests for the Triton fused MoE kernel against a naive PyTorch reference.
import itertools
import unittest

import torch
import torch.nn.functional as F
from tqdm import tqdm

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.test.test_utils import CustomTestCase


class TestFusedMOE(CustomTestCase):
    """Compare the Triton ``fused_moe`` kernel against a naive PyTorch MoE reference."""

    NUM_EXPERTS = [8, 64]
    TOP_KS = [2, 6]

    @staticmethod
    def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01):
        """Create a random CUDA tensor drawn from a normal distribution.

        Args:
            shape: Tensor shape (an int or a tuple of ints).
            dtype: Data type.
            mean: Mean of the normal distribution.
            std: Standard deviation of the normal distribution.

        Returns:
            torch.Tensor: Randomly initialized CUDA tensor.
        """
        return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std)

    @staticmethod
    def _fp8_w8a8_supported():
        """Return True if the current CUDA device can run the fp8 (e4m3) kernel path.

        Triton fails with "fp8e4nv data type is not supported on CUDA arch < 89",
        so only compute capability 8.9 and 9.x or newer qualify.
        """
        major, minor = torch.cuda.get_device_capability()
        return major >= 9 or (major, minor) == (8, 9)

    def get_tolerance(self, dtype):
        """Get comparison tolerances for a data type.

        Args:
            dtype: Data type of the compared tensors.

        Returns:
            tuple: (relative tolerance, absolute tolerance)
        """
        if dtype == torch.float32:
            return 1e-3, 1e-5
        elif dtype in [torch.float16, torch.bfloat16]:
            return 1e-1, 1e-2
        else:
            return 1e-2, 1e-2  # Default values for other types

    def torch_naive_moe(
        self,
        a,
        w1,
        w2,
        score,
        topk,
        w1_scale=None,
        w2_scale=None,
        a1_scale=None,
        a2_scale=None,
    ):
        """Reference MoE forward pass computed expert-by-expert in plain PyTorch.

        Each token is routed to its ``topk`` highest-probability experts
        (softmax over ``score``, no renormalization of the selected weights).
        For each selected expert ``i`` the token is passed through
        ``SiluAndMul(x @ w1[i].T) @ w2[i].T``. The per-expert outputs are then
        summed, weighted by the routing probabilities.

        Args:
            a: Input activations, shape (B, D).
            w1: Per-expert first-layer weights, applied as ``x @ w1[i].T``;
                presumably (E, 2 * N, D) as built by ``_test_case``.
            w2: Per-expert second-layer weights; ``w2.shape[1]`` is the output width.
            score: Router logits, presumably (B, E).
            topk: Number of experts selected per token.
            w1_scale, w2_scale: Optional per-expert scales (one value per leading
                index) multiplied into the dequantized weights. Only used when
                ``w1`` is ``torch.float8_e4m3fn``.
            a1_scale, a2_scale: Optional scales multiplied into ``a``. Only used
                when ``w1`` is ``torch.float8_e4m3fn``.

        Returns:
            torch.Tensor: Shape (B, w2.shape[1]) with the dtype of ``a``.
        """
        B, D = a.shape
        # Repeat every token topk times: row b * topk + j pairs token b with
        # its j-th selected expert, matching the flattened topk_ids below.
        a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
        out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)
        topk_weight = topk_weight.view(-1)
        topk_ids = topk_ids.view(-1)

        if w1.dtype == torch.float8_e4m3fn:
            # Dequantize the fp8 weights to the activation dtype for the matmuls.
            w1_compute = w1.to(a.dtype)
            w2_compute = w2.to(a.dtype)

            if w1_scale is not None:
                w1_compute = (w1_compute * w1_scale.view(-1, 1, 1)).to(a.dtype)
            if w2_scale is not None:
                w2_compute = (w2_compute * w2_scale.view(-1, 1, 1)).to(a.dtype)
            # NOTE(review): both activation scales are multiplied into the
            # *input* here. a2_scale would normally apply to the intermediate
            # activation fed to w2, and a static a1_scale is normally a
            # quantization scale, not a multiplier. Confirm against fused_moe's
            # fp8 semantics before trusting this reference for fp8 accuracy.
            if a1_scale is not None:
                a = (a * a1_scale).to(a.dtype)
            if a2_scale is not None:
                a = (a * a2_scale).to(a.dtype)
        else:
            w1_compute = w1
            w2_compute = w2

        act = SiluAndMul()
        for expert in range(w1_compute.shape[0]):
            mask = topk_ids == expert
            if mask.any():
                hidden = act(a[mask] @ w1_compute[expert].transpose(0, 1))
                out[mask] = hidden @ w2_compute[expert].transpose(0, 1)

        return (
            out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
        ).sum(dim=1)

    def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
        """Run one fused_moe configuration and compare it with ``torch_naive_moe``.

        Args:
            m: Number of tokens.
            n: Intermediate size (w1 has 2 * n output rows per expert).
            k: Hidden size.
            e: Number of experts.
            topk: Experts selected per token.
            dtype: Activation dtype (and weight dtype before any fp8 cast).
            use_fp8_w8a8: Cast weights to float8_e4m3fn and pass random scales.

        Raises:
            unittest.SkipTest: if fp8 is requested on a device that cannot run it.
        """
        if use_fp8_w8a8 and not self._fp8_w8a8_supported():
            # Previously this returned silently, so the case counted as a pass.
            self.skipTest("fp8 w8a8 requires CUDA compute capability 8.9 or >= 9.0")

        rtol, atol = self.get_tolerance(dtype)

        # Draw order (a, w1, w2, score, then fp8 scales) is kept identical
        # for both modes, so a seeded run produces the same inputs as before.
        a = self.create_random_cuda_tensor((m, k), dtype)
        w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
        w2 = self.create_random_cuda_tensor((e, k, n), dtype)
        score = self.create_random_cuda_tensor((m, e), dtype)

        if use_fp8_w8a8:
            w1 = w1.to(torch.float8_e4m3fn)
            w2 = w2.to(torch.float8_e4m3fn)

            # NOTE(review): scales come from N(0, 0.01), so they are tiny and
            # may be negative. The resulting outputs are small relative to
            # atol, which makes this comparison weak; consider positive scales.
            w1_scale = self.create_random_cuda_tensor(e, torch.float32)
            w2_scale = self.create_random_cuda_tensor(e, torch.float32)
            a1_scale = self.create_random_cuda_tensor(1, torch.float32)
            a2_scale = self.create_random_cuda_tensor(1, torch.float32)

            kernel_output = fused_moe(
                a,
                w1,
                w2,
                score,
                topk,
                renormalize=False,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            )
            torch_output = self.torch_naive_moe(
                a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
            )
        else:
            kernel_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
            torch_output = self.torch_naive_moe(a, w1, w2, score, topk)

        torch.testing.assert_close(kernel_output, torch_output, rtol=rtol, atol=atol)

    def test_various_configurations(self):
        """Sweep fused_moe over token counts, sizes, expert counts, top-k and dtypes."""
        m_values = [1, 33, 64, 222]
        n_values = [128, 1024]
        k_values = [128, 511, 1024]
        dtypes = [torch.float16, torch.bfloat16]
        fp8_modes = [False, True]

        if not self._fp8_w8a8_supported():
            # Enumerate only configurations this device can run. This keeps the
            # progress total accurate and stops fp8 cases from passing vacuously.
            fp8_modes = [False]
            tqdm.write(
                "fp8 w8a8 is not supported on this device "
                "(needs compute capability 8.9 or >= 9.0); skipping fp8 configurations."
            )

        # itertools.product varies the last iterable fastest, which matches
        # the iteration order of the equivalent nested loops.
        configs = list(
            itertools.product(
                m_values,
                n_values,
                k_values,
                self.NUM_EXPERTS,
                self.TOP_KS,
                dtypes,
                fp8_modes,
            )
        )

        with tqdm(total=len(configs), desc="Running MoE tests") as pbar:
            for m, n, k, e, topk, dtype, use_fp8_w8a8 in configs:
                with self.subTest(
                    m=m,
                    n=n,
                    k=k,
                    e=e,
                    topk=topk,
                    dtype=dtype,
                    fp8=use_fp8_w8a8,
                ):
                    self._test_case(
                        m,
                        n,
                        k,
                        e,
                        topk,
                        dtype,
                        use_fp8_w8a8=use_fp8_w8a8,
                    )
                    torch.cuda.empty_cache()
                pbar.update(1)


if __name__ == "__main__":
    # Run every TestCase defined in this module with the stdlib unittest runner.
    unittest.main()