# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

import pytest
import torch

from tests.kernels.moe.utils import (
    batched_moe,
    make_quantized_test_activations,
    make_test_weights,
    naive_batched_moe,
)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
from vllm.triton_utils import tl

# (m, n, k) problem sizes swept by test_fused_moe_batched_experts:
# m is the number of activation rows, k their hidden size, and n is
# forwarded to make_test_weights.
MNK_FACTORS = [
    (1, 128, 128),
    (1, 128, 2048),
    (1, 512, 512),
    (1, 1024, 128),
    (1, 1024, 2048),
    (32, 128, 128),
    (32, 512, 512),
    (32, 1024, 2048),
    (45, 128, 128),
    (45, 128, 2048),
    (45, 512, 512),
    (45, 1024, 128),
    (45, 1024, 2048),
    (64, 512, 512),
    (64, 1024, 2048),
    (222, 128, 128),
    (222, 128, 2048),
    (222, 1024, 128),
    (222, 1024, 2048),
]
# Expert counts and top-k values swept by test_fused_moe_batched_experts;
# combinations with topk > e are skipped inside that test.
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]

# Shared config installed via set_current_vllm_config() around the MoE
# calls in test_fused_moe_batched_experts.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


@dataclass
class BatchedMMConfig:
bnellnm's avatar
bnellnm committed
57
58
59
    in_dtype: torch.dtype
    quant_dtype: Optional[torch.dtype]
    out_dtype: torch.dtype
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    """Input/output tensors for one batched (per-expert) matmul problem."""

    # Activations, shape [E, max_tokens_per_expert, K].
    A: torch.Tensor
    # Weights, allocated as [E, N, K]; i.e. a column-major view of an
    # [E, K, N] operand.
    B: torch.Tensor
    # Output buffer, shape [E, max_tokens_per_expert, N].
    C: torch.Tensor
    # Per-expert token counts, shape [E]. Presumably the number of valid
    # rows of A / C for each expert -- confirm against the consuming kernel.
    num_expert_tokens: torch.Tensor

    @staticmethod
    def make_tensors(
        config: "BatchedMMConfig", device: str = "cuda"
    ) -> "BatchedMMTensors":
        """Allocate random inputs and a zeroed output buffer for ``config``.

        Args:
            config: problem sizes and dtypes (only attributes are read).
            device: device for every allocated tensor; defaults to "cuda".

        Returns:
            A ``BatchedMMTensors`` whose ``A`` is standard-normal scaled by
            1/10, ``B`` is standard-normal, ``C`` is zero-filled, and
            ``num_expert_tokens`` holds int32 counts drawn uniformly from
            ``[0, config.max_tokens_per_expert)``.
        """
        num_experts = config.num_experts
        max_tokens = config.max_tokens_per_expert

        A = (
            torch.randn(
                (num_experts, max_tokens, config.K),
                device=device,
                dtype=config.in_dtype,
            )
            / 10
        )
        B = torch.randn(
            (num_experts, config.N, config.K),
            device=device,
            dtype=config.in_dtype,
        )
        C = torch.zeros(
            (num_experts, max_tokens, config.N),
            device=device,
            dtype=config.out_dtype,
        )
        # torch.randint's ``high`` is exclusive, so a fully occupied expert
        # (count == max_tokens) is never generated.
        num_expert_tokens = torch.randint(
            low=0,
            high=max_tokens,
            size=(num_experts,),
            device=device,
            dtype=torch.int32,
        )
        return BatchedMMTensors(A, B, C, num_expert_tokens)


@pytest.mark.parametrize("num_experts", [8, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(
    num_experts: int,
    max_tokens_per_expert: int,
    K: int,
    N: int,
    dtype: torch.dtype,
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
):
    """Check invoke_moe_batched_triton_kernel against reference matmuls.

    Three outputs are computed for the same per-expert token counts:

    * ``ref_output``   - native_batched_masked_quant_matmul on ``A``/``B``
      without any scales;
    * ``q_ref_output`` - the same reference on ``A_q``/``B_q`` with their
      scales, block shape and per-token flag;
    * ``test_output``  - the Triton kernel on ``A_q``/``B_q``.

    The scaled reference must match the unscaled one, and the kernel must
    match the scaled reference, within dtype-dependent tolerances.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # 1-byte dtypes (fp8) are quantization targets: activations stay bf16 and
    # ``dtype`` is passed as quant_dtype. Wider dtypes are used unquantized.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # torch.randint's ``high`` is exclusive: counts lie in
    # [0, max_tokens_per_expert), so a full expert is never exercised.
    num_expert_tokens = torch.randint(
        low=0,
        high=max_tokens_per_expert,
        size=(num_experts,),
        device="cuda",
        dtype=torch.int32,
    )

    A, A_q, A_scale = make_quantized_test_activations(
        num_experts,
        max_tokens_per_expert,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
    )

    # NOTE(review): N // 2 presumably because make_test_weights returns a
    # first weight with 2 * n output rows; confirm against the helper.
    (B, B_q, B_scale, _), _ = make_test_weights(
        num_experts,
        N // 2,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )

    out_shape = (num_experts, max_tokens_per_expert, N)
    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")

    # Triton compute dtype matching the output buffers.
    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
    }[test_output.dtype]

    assert A_q.dtype == B_q.dtype

    invoke_moe_batched_triton_kernel(
        A_q,
        B_q,
        test_output,
        num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        A_scale,
        B_scale,
        None,  # NOTE(review): presumably B zero-points; unused here
        # Quantization schemes
        use_fp8_w8a8,
        False,  # NOTE(review): presumably use_int8_w8a16 -- confirm
        False,  # NOTE(review): presumably use_int4_w4a16 -- confirm
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32,
        },
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    ref_output = native_batched_masked_quant_matmul(
        A,
        B,
        ref_output,
        num_expert_tokens,
    )

    q_ref_output = native_batched_masked_quant_matmul(
        A_q,
        B_q,
        q_ref_output,
        num_expert_tokens,
        A_scale,
        B_scale,
        block_shape,
        per_act_token_quant,
    )

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    input_scales: bool,
):
    """Compare three MoE implementations on identical routing and weights.

    ``torch_experts`` is the baseline; ``naive_batched_moe`` must match it,
    and ``batched_moe`` must match ``naive_batched_moe``. All three receive
    exactly the same positional and keyword arguments.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if topk > e:
        pytest.skip("topk > e")

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # 1-byte dtypes (fp8) are quantization targets: activations stay bf16 and
    # ``dtype`` is passed as quant_dtype. Wider dtypes are used unquantized.
    # (No random numbers are drawn here, so the RNG sequence is unaffected.)
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # Activations share the weights' in_dtype (act_dtype); gating logits bf16.
    a = torch.randn((m, k), device="cuda", dtype=act_dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # The first element of each weight tuple is not needed by this test.
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
        block_shape=block_shape,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        per_out_ch_quant=per_act_token_quant,
    )

    if input_scales and quant_dtype is not None:
        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    else:
        a1_scale = None
        a2_scale = None

    # Shared keyword arguments so every implementation sees identical inputs.
    moe_kwargs = {
        "w1_scale": w1_s,
        "w2_scale": w2_s,
        "a1_scale": a1_scale,
        "a2_scale": a2_scale,
        "quant_dtype": quant_dtype,
        "per_act_token_quant": per_act_token_quant,
        "block_shape": block_shape,
    }

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        baseline_output = torch_experts(
            a, w1, w2, topk_weight, topk_ids, **moe_kwargs
        )
        batched_output = naive_batched_moe(
            a, w1, w2, topk_weight, topk_ids, **moe_kwargs
        )
        triton_output = batched_moe(
            a, w1, w2, topk_weight, topk_ids, **moe_kwargs
        )

    torch.testing.assert_close(batched_output, baseline_output, atol=3e-2, rtol=2e-2)
    torch.testing.assert_close(triton_output, batched_output, atol=2e-2, rtol=2e-2)