test_batched_moe.py 9.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8

from dataclasses import dataclass

import pytest
import torch

9
10
11
12
13
14
from tests.kernels.moe.utils import (
    batched_moe,
    make_quantized_test_activations,
    make_test_weights,
    naive_batched_moe,
)
bnellnm's avatar
bnellnm committed
15
16
17
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
18
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
19
20
    invoke_moe_batched_triton_kernel,
)
bnellnm's avatar
bnellnm committed
21
22
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
23
from vllm.triton_utils import tl
bnellnm's avatar
bnellnm committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

# (m, n, k) problem sizes used to parametrize test_fused_moe_batched_experts.
# m sizes the activation/score rows there; n and k are forwarded to
# make_test_weights.
MNK_FACTORS = [
    (1, 128, 128),
    (1, 128, 2048),
    (1, 512, 512),
    (1, 1024, 128),
    (1, 1024, 2048),
    (32, 128, 128),
    (32, 512, 512),
    (32, 1024, 2048),
    (45, 128, 128),
    (45, 128, 2048),
    (45, 512, 512),
    (45, 1024, 128),
    (45, 1024, 2048),
    (64, 512, 512),
    (64, 1024, 2048),
    (222, 128, 128),
    (222, 128, 2048),
    (222, 1024, 128),
    (222, 1024, 2048),
]
# Expert counts ("e") and top-k values ("topk") for the fused MoE test.
# Combinations with topk > e are skipped inside the test itself.
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]

# Module-wide vLLM config installed via set_current_vllm_config() around the
# fused MoE calls in test_fused_moe_batched_experts.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
52
53
54
55


@dataclass
class BatchedMMConfig:
bnellnm's avatar
bnellnm committed
56
    in_dtype: torch.dtype
57
    quant_dtype: torch.dtype | None
bnellnm's avatar
bnellnm committed
58
    out_dtype: torch.dtype
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, K, N] - column major
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: BatchedMMConfig):
74
75
76
77
78
79
80
81
82
83
        A = (
            torch.randn(
                (config.num_experts, config.max_tokens_per_expert, config.K),
                device="cuda",
                dtype=config.in_dtype,
            )
            / 10
        )
        B = torch.randn(
            (config.num_experts, config.N, config.K),
84
            device="cuda",
85
86
            dtype=config.in_dtype,
        )
87
88
89
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
90
91
            dtype=config.out_dtype,
        )
bnellnm's avatar
bnellnm committed
92

93
94
95
96
97
98
99
        num_expert_tokens = torch.randint(
            low=0,
            high=config.max_tokens_per_expert,
            size=(config.num_experts,),
            device="cuda",
            dtype=torch.int32,
        )
100

bnellnm's avatar
bnellnm committed
101
        return BatchedMMTensors(A, B, C, num_expert_tokens)
102
103


104
105
106
107
108
@pytest.mark.parametrize("num_experts", [8, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(
    num_experts: int,
    max_tokens_per_expert: int,
    K: int,
    N: int,
    dtype: torch.dtype,
    block_shape: list[int] | None,
    per_act_token_quant: bool,
) -> None:
    """Check ``invoke_moe_batched_triton_kernel`` against native references.

    Two references are computed with ``native_batched_masked_quant_matmul``:
    one from the unquantized operands and one from the quantized operands
    plus their scales. The two references must agree with each other, and the
    Triton kernel output must agree with the quantized reference.

    Skipped when a quantization option (per-token or block) is requested for
    a non-fp8 dtype, and when per-token and block quantization are requested
    together.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # One-byte dtypes are treated as the quantized type; the activations and
    # outputs are then kept in bfloat16.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # `high` is exclusive: each expert gets fewer than max_tokens_per_expert
    # valid rows.
    num_expert_tokens = torch.randint(
        low=0,
        high=max_tokens_per_expert,
        size=(num_experts,),
        device="cuda",
        dtype=torch.int32,
    )

    A, A_q, A_scale = make_quantized_test_activations(
        num_experts,
        max_tokens_per_expert,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
    )

    # Only the first weight set returned by make_test_weights is used here.
    # NOTE(review): N // 2 presumably compensates for the helper doubling the
    # first weight's output dimension -- confirm against make_test_weights.
    (B, B_q, B_scale, _), _ = make_test_weights(
        num_experts,
        N // 2,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )

    out_shape = (num_experts, max_tokens_per_expert, N)
    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")

    # Triton compute dtype matching the torch dtype of the output buffer.
    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
    }[test_output.dtype]

    assert A_q.dtype == B_q.dtype

    invoke_moe_batched_triton_kernel(
        A_q,
        B_q,
        test_output,
        num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        A_scale,
        B_scale,
        None,
        # Quantization schemes
        use_fp8_w8a8,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
        },
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    # Reference from the unquantized operands.
    ref_output = native_batched_masked_quant_matmul(
        A,
        B,
        ref_output,
        num_expert_tokens,
    )

    # Reference from the quantized operands and their scales.
    q_ref_output = native_batched_masked_quant_matmul(
        A_q,
        B_q,
        q_ref_output,
        num_expert_tokens,
        A_scale,
        B_scale,
        block_shape,
        per_act_token_quant,
    )

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: list[int] | None,
    input_scales: bool,
) -> None:
    """Compare three MoE implementations on identical inputs.

    ``torch_experts`` provides the baseline, ``naive_batched_moe`` is checked
    against that baseline, and ``batched_moe`` is checked against
    ``naive_batched_moe``.

    Skipped when ``topk > e``, when a quantization option is requested for a
    non-fp8 dtype, and when per-token and block quantization are requested
    together.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if topk > e:
        pytest.skip("topk > e")

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # One-byte dtypes are treated as the quantized type; the weights are then
    # generated from bfloat16.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # The first and last element of each returned tuple are not needed here.
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
        block_shape=block_shape,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        per_out_ch_quant=per_act_token_quant,
    )

    if input_scales and quant_dtype is not None:
        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    else:
        a1_scale = None
        a2_scale = None

    # Every implementation receives exactly the same keyword arguments.
    quant_kwargs = {
        "w1_scale": w1_s,
        "w2_scale": w2_s,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
        "quant_dtype": quant_dtype,
        "per_act_token_quant": per_act_token_quant,
        "block_shape": block_shape,
    }

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        moe_args = (a, w1, w2, topk_weight, topk_ids)

        baseline_output = torch_experts(*moe_args, **quant_kwargs)
        batched_output = naive_batched_moe(*moe_args, **quant_kwargs)
        triton_output = batched_moe(*moe_args, **quant_kwargs)

    torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2)

    torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2)