# test_fused_moe.py
import itertools
import unittest

import torch
import torch.nn.functional as F
from tqdm import tqdm

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.utils import is_hip
from sglang.test.test_utils import CustomTestCase
14

15
16
17
# Platform flags, evaluated once at import time and read by the fp8 test path.
# `_is_hip` bypasses the CUDA compute-capability check; `_is_fp8_fnuz` selects
# the e4m3fn -> e4m3fnuz weight normalization (presumably only set on ROCm/HIP
# builds -- confirm against is_fp8_fnuz).
_is_hip = is_hip()
_is_fp8_fnuz = is_fp8_fnuz()

18

19
class TestFusedMOE(CustomTestCase):
    """Check ``fused_moe`` against a plain PyTorch mixture-of-experts reference.

    Every case builds random CUDA inputs, runs both implementations on them
    and compares the results with ``torch.testing.assert_close``.
    """

    # Expert counts and top-k values swept by ``test_various_configurations``.
    NUM_EXPERTS = [8, 64]
    TOP_KS = [2, 6]

    @staticmethod
    def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01):
        """Create a random CUDA tensor.

        Args:
            shape: Tensor shape.
            dtype: Data type.
            mean: Mean of the normal distribution the values are drawn from.
            std: Standard deviation of that distribution.

        Returns:
            torch.Tensor: Tensor on the "cuda" device, filled in place by
            ``normal_(mean, std)``.
        """
        return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std)

    def get_tolerance(self, dtype):
        """Get tolerance values for different data types.

        Args:
            dtype: Data type of the tensors being compared.

        Returns:
            tuple: (relative tolerance, absolute tolerance)
        """
        if dtype == torch.float32:
            return 1e-3, 1e-5
        elif dtype in [torch.float16, torch.bfloat16]:
            return 1e-1, 1e-2
        else:
            return 1e-2, 1e-2  # Default values for other types

    def torch_naive_moe(
        self,
        a,
        w1,
        w2,
        score,
        topk,
        w1_scale=None,
        w2_scale=None,
        a1_scale=None,
        a2_scale=None,
    ):
        """Reference MoE forward pass written with plain PyTorch ops.

        Routing weights are the ``topk`` largest entries of a float32 softmax
        over ``score``. Each routed copy of a token goes through
        ``SiluAndMul(x @ w1[i].T) @ w2[i].T`` for its expert ``i``; the
        per-expert results are then combined with the routing weights.

        Args:
            a: Input activations, 2-D ``(B, D)``.
            w1: Stacked first-layer expert weights, indexed by expert on dim 0.
            w2: Stacked second-layer expert weights, indexed by expert on
                dim 0; ``w2.shape[1]`` is the output width.
            score: Router logits; softmax is taken over the last dim.
            topk: Number of experts selected per token.
            w1_scale: Optional per-expert scale for ``w1`` (broadcast through
                ``view(-1, 1, 1)``). Only used when ``w1`` is a float8 dtype.
            w2_scale: Optional per-expert scale for ``w2``; same handling.
            a1_scale: Optional scale multiplied into the activations. Only
                used when ``w1`` is a float8 dtype.
            a2_scale: Optional second scale multiplied into the activations;
                same handling.

        Returns:
            torch.Tensor: ``(B, w2.shape[1])`` tensor in the dtype of ``a``.
        """
        B, D = a.shape
        # Duplicate every token ``topk`` times so each copy can be routed to
        # one of its selected experts.
        a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
        out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)
        topk_weight = topk_weight.view(-1)
        topk_ids = topk_ids.view(-1)

        if w1.dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]:
            # float8 weights are upcast to the activation dtype before any
            # arithmetic; the optional scales are folded in afterwards.
            w1_compute = w1.to(a.dtype)
            w2_compute = w2.to(a.dtype)

            if w1_scale is not None:
                w1_compute = (w1_compute * w1_scale.view(-1, 1, 1)).to(a.dtype)
            if w2_scale is not None:
                w2_compute = (w2_compute * w2_scale.view(-1, 1, 1)).to(a.dtype)
            # NOTE(review): both activation scales are applied to the input
            # activations up front, before the first matmul.
            if a1_scale is not None:
                a = (a * a1_scale).to(a.dtype)
            if a2_scale is not None:
                a = (a * a2_scale).to(a.dtype)
        else:
            w1_compute = w1
            w2_compute = w2

        # The activation object does not depend on the expert, so build it once.
        activation = SiluAndMul()
        for i in range(w1_compute.shape[0]):
            mask = topk_ids == i
            # Skip experts that no token was routed to.
            if mask.any():
                out[mask] = activation(
                    a[mask] @ w1_compute[i].transpose(0, 1)
                ) @ w2_compute[i].transpose(0, 1)

        # Weight each routed copy and sum the ``topk`` copies of every token.
        return (
            out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
        ).sum(dim=1)

    def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
        """Run one fused-vs-reference comparison on random inputs.

        Args:
            m: Number of tokens; activations are ``(m, k)``.
            n: Intermediate size; ``w1`` is ``(e, 2 * n, k)`` and ``w2`` is
                ``(e, k, n)``.
            k: Hidden size.
            e: Number of experts.
            topk: Number of experts selected per token.
            dtype: Dtype the random tensors are generated in.
            use_fp8_w8a8: When True, cast the weights to float8 and pass
                random scales to both implementations. The case returns
                silently when the GPU cannot run fp8.

        Raises:
            AssertionError: If the two outputs differ beyond the tolerances
                returned by ``get_tolerance``.
        """
        rtol, atol = self.get_tolerance(dtype)

        if use_fp8_w8a8:
            # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
            capability = torch.cuda.get_device_capability()
            if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)):
                return

            a = self.create_random_cuda_tensor((m, k), dtype)
            w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
            w2 = self.create_random_cuda_tensor((e, k, n), dtype)
            w1 = w1.to(torch.float8_e4m3fn)
            w2 = w2.to(torch.float8_e4m3fn)
            score = self.create_random_cuda_tensor((m, e), dtype)
            w1_scale = self.create_random_cuda_tensor(e, torch.float32)
            w2_scale = self.create_random_cuda_tensor(e, torch.float32)
            a1_scale = self.create_random_cuda_tensor(1, torch.float32)
            a2_scale = self.create_random_cuda_tensor(1, torch.float32)

            # Handle HIP case: normalize float8 weights so fused kernel doesn't break
            # on ROCm.
            if _is_fp8_fnuz:
                # Normalize to e4m3fnuz on HIP
                w1, w1_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=w1,
                    weight_scale=w1_scale,
                    input_scale=a1_scale,
                )
                w2, w2_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=w2,
                    weight_scale=w2_scale,
                    input_scale=a2_scale,
                )

            topk_output = select_experts(
                hidden_states=a,
                router_logits=score,
                topk_config=TopKConfig(top_k=topk, renormalize=False),
            )

            torch_output = self.torch_naive_moe(
                a,
                w1,
                w2,
                score,
                topk,
                w1_scale,
                w2_scale,
                a1_scale,
                a2_scale,
            )

            sglang_output = fused_moe(
                a,
                w1,
                w2,
                topk_output,
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            )
            torch.testing.assert_close(
                sglang_output, torch_output, rtol=rtol, atol=atol
            )
        else:
            a = self.create_random_cuda_tensor((m, k), dtype)
            w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
            w2 = self.create_random_cuda_tensor((e, k, n), dtype)
            score = self.create_random_cuda_tensor((m, e), dtype)

            topk_output = select_experts(
                hidden_states=a,
                router_logits=score,
                topk_config=TopKConfig(top_k=topk, renormalize=False),
            )

            triton_output = fused_moe(a, w1, w2, topk_output)
            torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
            torch.testing.assert_close(
                triton_output, torch_output, rtol=rtol, atol=atol
            )

    def test_various_configurations(self):
        """Sweep the comparison over a grid of shapes, dtypes and fp8 modes.

        Each combination runs in its own ``subTest`` so a failure is reported
        with its parameters without stopping the rest of the sweep.
        """
        m_values = [1, 33, 64, 222]
        n_values = [128, 1024]
        k_values = [128, 511, 1024]
        dtypes = [torch.float16, torch.bfloat16]
        fp8_modes = [False, True]

        # Full Cartesian grid; the last factor (fp8 mode) varies fastest.
        configurations = list(
            itertools.product(
                m_values,
                n_values,
                k_values,
                self.NUM_EXPERTS,
                self.TOP_KS,
                dtypes,
                fp8_modes,
            )
        )

        # Progress bar sized to the total number of combinations.
        with tqdm(total=len(configurations), desc="Running MoE tests") as pbar:
            for m, n, k, e, topk, dtype, use_fp8_w8a8 in configurations:
                with self.subTest(
                    m=m,
                    n=n,
                    k=k,
                    e=e,
                    topk=topk,
                    dtype=dtype,
                    fp8=use_fp8_w8a8,
                ):
                    self._test_case(
                        m,
                        n,
                        k,
                        e,
                        topk,
                        dtype,
                        use_fp8_w8a8=use_fp8_w8a8,
                    )
                    # Release cached GPU memory between cases.
                    torch.cuda.empty_cache()
                pbar.update(1)
233
234
235
236


# Run the tests through unittest's command-line runner when this file is
# executed directly.
if __name__ == "__main__":
    unittest.main()