test_flashinfer.py 15.6 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
from vllm.model_executor.layers.fused_moe.config import (
12
13
    FusedMoEConfig,
    FusedMoEParallelConfig,
14
    FusedMoEQuantConfig,
15
    RoutingMethodType,
16
17
    fp8_w8a8_moe_quant_config,
)
18
19
20
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
21
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
22
23
24
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
25
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
26
    apply_fi_trtllm_fp8_per_tensor_moe,
27
    register_scales_for_trtllm_fp8_per_tensor_moe,
28
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
29
30
31
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
32
33
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
34
from vllm.utils.torch_utils import set_random_seed
35
36
37
38
39
40
41
42

# Module-level gating: decide at collection time whether any test in this file
# can run on the current platform.
try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # pytest.skip() raises, so this line is only reached on non-ROCm platforms.
    # Re-raise the original ImportError rather than falling through to a
    # NameError on `has_flashinfer_cutlass_fused_moe` in the check below.
    raise

# Skip the whole module unless the FlashInfer CUTLASS fused-MoE kernel is
# reported as available and the device capability is at least 90.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )
51
52
53
54
55
56
57
58
59
60
61
62
63

# Values fed to the "e" (expert count) and "topk" parametrizations below.
NUM_EXPERTS = [16]
TOP_KS = [1]

# (m, n, k) triples for the "m,n,k" parametrization. In
# TestData.make_moe_tensors_8bit, hidden_states is (m, k), w13 is
# (e, 2 * n or n, k) and w2 is (e, k, n).
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Shared config installed via set_current_vllm_config() inside each test.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
65
66
67
68
69
70
71
72
73


def quant_fp8_per_tensor_batches(a):
    """Quantize each slice along dim 0 of ``a`` independently and stack results.

    Every ``a[i]`` is passed to ``input_to_float8``, whose two return values
    are treated here as (quantized tensor, scale). Presumably this yields one
    per-tensor FP8 scale per slice -- confirm against ``input_to_float8``.

    Args:
        a: Tensor whose leading dimension enumerates the slices to quantize
            (the callers in this file pass per-expert weight tensors).

    Returns:
        A ``(quantized, scales)`` pair, each built with ``torch.stack`` over
        the per-slice results, so both have ``a.size(0)`` as leading dimension.
    """
    num_batches = a.size(0)
    a_quant = []
    a_scales = []

    for i in range(num_batches):
        a_fp8, a_global_sf = input_to_float8(a[i])
        # Give single-element scales an explicit (1, 1) shape so that, when
        # every scale is single-element, the stack below is (num_batches, 1, 1).
        if a_global_sf.numel() == 1:
            a_global_sf = a_global_sf.view(1, 1)
        a_quant.append(a_fp8)
        a_scales.append(a_global_sf)

    result_a_quant = torch.stack(a_quant)
    result_a_scales = torch.stack(a_scales)

    return result_a_quant, result_a_scales


85
86
87
88
89
90
91
92
93
94
95
96
97
98
def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
    """Assert that enough elements of two tensors agree within tolerance.

    Elements are compared with ``torch.isclose(ref_output, actual_output)``
    using ``atol`` and ``rtol``; the fraction of matching elements must be at
    least ``percent``.

    Raises:
        AssertionError: if the match ratio is below ``percent`` or,
            equivalently stated, the mismatch ratio exceeds ``1 - percent``.
    """
    within_tolerance = torch.isclose(
        ref_output, actual_output, atol=atol, rtol=rtol
    )
    match_ratio = within_tolerance.to(torch.float32).mean()
    assert match_ratio >= percent, (
        f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f}"
    )

    # Same criterion expressed as a mismatch budget, evaluated on Python floats.
    mismatch_percent = 1.0 - match_ratio.item()
    mismatch_budget = 1 - percent
    assert mismatch_percent <= mismatch_budget, (
        f"Mismatch percentage {mismatch_percent:.4f} is above the threshold "
        f"{1 - percent:.4f}"
    )


99
100
101
102
103
104
105
106
107
108
109
110
@dataclass
class TestData:
    """Bundle of tensors and a mock layer shared by the FP8 MoE tests."""

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it and warn about the dataclass-generated __init__. This is an
    # unannotated class attribute, hence not a dataclass field.
    __test__ = False

    # Unquantized bf16 activations of shape (m, k).
    hidden_states: torch.Tensor
    # Quantized weights as returned by quant_fp8_per_tensor_batches (these are
    # the originals; the layer holds clones that may be swapped/rotated).
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    # Activation scales: a1 comes from input_to_float8(hidden_states), a2 is 1.0.
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    # Per-expert weight scales returned by quant_fp8_per_tensor_batches.
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    # Bare torch.nn.Module carrying the attributes the kernels under test read.
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int,
        k: int,
        n: int,
        e: int,
        is_trtllm: bool,
        activation: MoEActivation = MoEActivation.SILU,
    ) -> "TestData":
        """Create random MoE inputs/weights, quantize them and build a layer.

        Args:
            m: Number of rows in ``hidden_states``.
            k: Hidden size (last dim of ``hidden_states`` and ``w13``).
            n: Intermediate size (last dim of ``w2``).
            e: Number of experts (leading dim of both weight tensors).
            is_trtllm: When true, additionally rotate the layer weights and
                register scales via the TRT-LLM per-tensor FP8 helpers.
            activation: Activation; ``activation.is_gated`` decides whether
                ``w13`` has ``2 * n`` or ``n`` rows per expert.

        Returns:
            A ``TestData`` holding the original quantized tensors, the scales
            and the prepared layer.
        """
        is_gated = activation.is_gated

        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = (
            torch.randn(
                (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
            )
            / 10
        )
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # Clone so the transformations below leave the tensors returned in
        # TestData (used by the reference path) untouched.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        layer.activation = activation
        # Setup dummy config.
        layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()

        # flashinfer expects swapped rows for w13
        if is_gated:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight, is_gated
            )
            register_scales_for_trtllm_fp8_per_tensor_moe(
                layer,
                layer.w13_weight_scale,
                layer.w13_input_scale,
                layer.w2_weight_scale,
                layer.w2_input_scale,
            )
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
):
    """Compare apply_fi_trtllm_fp8_per_tensor_moe against fused_experts.

    Both paths consume the same random hidden states and router scores; their
    outputs must agree under the ``check_accuracy`` tolerance. The test is
    skipped when the device capability is below 100.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=True, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        # Routing for the reference path, using the same routing function that
        # make_moe_tensors_8bit installs on the layer.
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: runs on the original quantized weights held by TestData,
        # not on the layer's transformed clones.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # Path under test: receives the prepared layer and the raw router
        # scores rather than precomputed top-k weights/ids.
        flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        check_accuracy(
            ref_output=output,
            actual_output=flashinfer_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )
252
253
254
255
256


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
    workspace_init,
):
    """Compare the FlashInferExperts modular kernel against fused_experts.

    A ``FusedMoEModularKernel`` built from ``MoEPrepareAndFinalizeNoEP`` and
    ``FlashInferExperts`` is run on the layer weights prepared by
    ``TestData.make_moe_tensors_8bit`` and its output is checked against the
    ``fused_experts`` reference under the ``check_accuracy`` tolerance.

    ``workspace_init`` is not referenced in the body; presumably it is a
    fixture requested for its setup side effects -- confirm in conftest.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        # One quant config is shared by the reference and the kernel under
        # test; the g*/a*_gscale entries are derived from the same scales.
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: runs on the original quantized weights held by TestData.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        td.layer.dp_size = 1

        # The parameter name `n` shadows the outer `n` only inside this
        # helper; the argument is ignored and the shared config is returned.
        def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        # Make the mock layer act as its own quant method.
        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=torch.bfloat16,
            is_act_and_mul=activation.is_gated,
            routing_method=RoutingMethodType.TopK,
        )

        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            FlashInferExperts(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=False,
        )

        # Kernel under test: uses the layer's weights, which
        # make_moe_tensors_8bit row-swapped for gated activations.
        flashinfer_cutlass_output = kernel(
            td.hidden_states,
            td.layer.w13_weight,
            td.layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        check_accuracy(
            ref_output=output,
            actual_output=flashinfer_cutlass_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400


@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    """Check basic invariants of the TRT-LLM block-layout weight conversion.

    For random bf16 ``w13``/``w2`` weights the converted tensors must be 4-D,
    keep the element count and dtype, and keep ``num_experts`` as dim 0.
    """
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout,
    )

    w13_shape = (num_experts, 2 * intermediate, hidden)
    w2_shape = (num_experts, hidden, intermediate)
    w13 = torch.randn(w13_shape, dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn(w2_shape, dtype=torch.bfloat16, device="cuda")

    # Fresh, empty cache handed to the converter as its first argument.
    cache: dict[torch.Size, torch.Tensor] = {}
    converted_pair = convert_moe_weights_to_flashinfer_trtllm_block_layout(
        cache, w13, w2
    )
    w13_converted, w2_converted = converted_pair

    # Rank
    assert w13_converted.ndim == 4, (
        f"Expected 4D tensor, got shape {w13_converted.shape}"
    )
    assert w2_converted.ndim == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"

    # Element count
    assert w13_converted.numel() == w13.numel(), "W13 element count should be preserved"
    assert w2_converted.numel() == w2.numel(), "W2 element count should be preserved"

    # Dtype
    assert w13_converted.dtype == torch.bfloat16
    assert w2_converted.dtype == torch.bfloat16

    # Expert dimension stays first
    assert w13_converted.shape[0] == num_experts
    assert w2_converted.shape[0] == num_experts
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477


def test_flashinfer_blockscale_fp8_none_expert_group(monkeypatch):
    """Ensure flashinfer_fused_moe_blockscale_fp8 accepts num_expert_group=None.

    Regression test for https://github.com/vllm-project/vllm/issues/34477.
    MiniMax-M2.1 combines sigmoid scoring and e_score_correction_bias without
    grouped top-k, so num_expert_group is None; with DeepSeekV3 routing
    selected this used to crash inside the flashinfer kernel.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test requires SM >= 100 (Blackwell)")

    import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
    from tests.kernels.quant_utils import native_per_token_group_quant_fp8

    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

    num_experts = 16  # must be divisible by 4
    top_k = 6  # top_k > 1 triggers DeepSeekV3 routing with sigmoid
    num_tokens, intermediate, hidden = 10, 4096, 5120
    block_shape = [128, 128]
    _, block_k = block_shape

    with set_current_vllm_config(vllm_config):
        bf16_cuda = {"device": "cuda", "dtype": torch.bfloat16}
        fp32_cuda = {"device": "cuda", "dtype": torch.float32}

        # BF16 activations followed by BF16 master weights.
        x = torch.randn((num_tokens, hidden), **bf16_cuda) / 10
        w13_bf16 = torch.randn((num_experts, 2 * intermediate, hidden), **bf16_cuda) / 10
        w2_bf16 = torch.randn((num_experts, hidden, intermediate), **bf16_cuda) / 10

        # Per-expert block quantization to FP8, w13 then w2 for each expert.
        per_expert = []
        for expert_w13, expert_w2 in zip(w13_bf16, w2_bf16):
            w13_q, w13_s = native_per_token_group_quant_fp8(expert_w13, block_k)
            w2_q, w2_s = native_per_token_group_quant_fp8(expert_w2, block_k)
            per_expert.append((w13_q, w13_s, w2_q, w2_s))

        w13_fp8, w13_scale, w2_fp8, w2_scale = (
            torch.stack(parts) for parts in zip(*per_expert)
        )

        # DeepSeekV3 routing uses float32 logits + optional bias
        routing_logits = torch.randn((num_tokens, num_experts), **fp32_cuda)
        routing_bias = torch.randn(num_experts, **fp32_cuda)

        # The call itself is the regression check: it must not raise when both
        # grouping arguments are None.
        blockscale_moe = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8
        output = blockscale_moe(
            routing_logits=routing_logits,
            routing_bias=routing_bias,
            x=x,
            w13_weight=w13_fp8,
            w13_weight_scale_inv=w13_scale,
            w2_weight=w2_fp8,
            w2_weight_scale_inv=w2_scale,
            global_num_experts=num_experts,
            top_k=top_k,
            num_expert_group=None,
            topk_group=None,
            intermediate_size=intermediate,
            expert_offset=0,
            local_num_experts=num_experts,
            block_shape=block_shape,
            routing_method_type=RoutingMethodType.DeepSeekV3,
            routed_scaling=1.0,
        )

        assert output is not None
        assert output.shape == (num_tokens, hidden)