# test_flashinfer.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
from vllm.model_executor.layers.fused_moe.config import (
12
13
    FusedMoEConfig,
    FusedMoEParallelConfig,
14
    FusedMoEQuantConfig,
15
    RoutingMethodType,
16
17
    fp8_w8a8_moe_quant_config,
)
18
19
20
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
21
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
22
23
24
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
25
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
26
    apply_fi_trtllm_fp8_per_tensor_moe,
27
    register_scales_for_trtllm_fp8_per_tensor_moe,
28
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
29
30
31
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
32
33
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
34
from vllm.utils.torch_utils import set_random_seed
35
36
37
38
39
40
41
42

try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # On any other platform a failed import is a genuine error: re-raise it
    # instead of falling through to a NameError on the capability check below.
    raise

# Skip the whole module unless the FlashInfer CUTLASS fused-MoE backend is
# available and the device reports compute capability >= 90.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )

# Parametrization shared by the MoE comparison tests in this module.
NUM_EXPERTS = [16]
TOP_KS = [1]

# (m, n, k) cases: m rows of hidden_states, n intermediate size, k hidden size.
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Minimal config installed with set_current_vllm_config() by the tests.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))


def quant_fp8_per_tensor_batches(a):
    """Quantize every slice ``a[i]`` to FP8 with its own per-tensor scale.

    Args:
        a: Tensor whose leading dimension indexes independent slices; in this
            module the callers pass per-expert weight stacks ``(e, ..., ...)``.
            Each slice is quantized separately via ``input_to_float8``.

    Returns:
        ``(a_quant, a_scales)``: the quantized slices stacked along a new dim 0,
        and the per-slice scales stacked the same way. Single-element scales are
        reshaped to ``(1, 1)`` first, so scalar scales stack to ``(B, 1, 1)``.
    """
    quantized = []
    scales = []
    # Iterating a tensor yields a[0], a[1], ... along dim 0.
    for batch in a:
        batch_fp8, batch_scale = input_to_float8(batch)
        if batch_scale.numel() == 1:
            batch_scale = batch_scale.view(1, 1)
        quantized.append(batch_fp8)
        scales.append(batch_scale)

    return torch.stack(quantized), torch.stack(scales)


def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
    """Assert that a large enough fraction of ``actual_output`` matches ``ref_output``.

    Closeness is elementwise ``torch.isclose(ref_output, actual_output)``, i.e.
    ``|ref - actual| <= atol + rtol * |actual|``. The check passes when the
    fraction of close elements is at least ``percent``.

    Args:
        ref_output: Reference tensor.
        actual_output: Tensor under test; must have the same shape.
        atol: Absolute tolerance forwarded to ``torch.isclose``.
        rtol: Relative tolerance forwarded to ``torch.isclose``.
        percent: Minimum required fraction of close elements, in ``[0, 1]``.

    Raises:
        AssertionError: If the shapes differ or the match ratio is below
            ``percent``.
    """
    # torch.isclose broadcasts, so mismatched shapes would otherwise be compared
    # silently and could hide a wrong output layout.
    assert ref_output.shape == actual_output.shape, (
        f"Shape mismatch: reference {tuple(ref_output.shape)} vs "
        f"actual {tuple(actual_output.shape)}"
    )

    close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol)
    match_ratio = close.float().mean().item()
    # A single check suffices: "mismatch <= 1 - percent" is the same condition.
    assert match_ratio >= percent, (
        f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f} "
        f"(mismatch {1.0 - match_ratio:.4f} exceeds {1.0 - percent:.4f})"
    )


@dataclass
class TestData:
    """FP8-quantized MoE inputs plus a dummy layer for the kernels under test.

    ``w13_quantized`` / ``w2_quantized`` keep the original (un-swapped,
    un-rotated) FP8 layout, while ``layer`` holds cloned weights that were
    pre-processed for the FlashInfer paths, together with the attributes those
    paths read.
    """

    # Not a test class: stop pytest from trying to collect it because of the
    # "Test" prefix (it would warn that the class has an __init__).
    __test__ = False

    hidden_states: torch.Tensor
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int,
        k: int,
        n: int,
        e: int,
        is_trtllm: bool,
        activation: MoEActivation = MoEActivation.SILU,
    ) -> "TestData":
        """Create random bf16 MoE tensors, quantize them to FP8, and build a layer.

        Args:
            m: Number of rows (tokens) in ``hidden_states``.
            k: Hidden size.
            n: Intermediate size per partition.
            e: Number of experts.
            is_trtllm: If True, also rotate the layer weights and register the
                scales for the TRT-LLM per-tensor FP8 kernel.
            activation: MoE activation. Gated activations double w13's output
                dimension and require swapping its halves for FlashInfer.

        Returns:
            A populated ``TestData``.
        """
        is_gated = activation.is_gated
        w13_rows = 2 * n if is_gated else n

        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn((e, w13_rows, k), device="cuda", dtype=torch.bfloat16) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10

        # Per-tensor FP8 scales: a1 derived from the activations, a2 fixed to 1.0.
        _, a1_scale = input_to_float8(hidden_states)
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # Clone so the layout changes applied to the layer weights below cannot
        # alias the untouched copies returned in TestData.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        layer.activation = activation
        # Dummy (no-parallelism) config.
        layer.moe_parallel_config = FusedMoEParallelConfig.make_no_parallel()

        # FlashInfer expects the two halves of gated w13 swapped (w13 -> w31).
        if is_gated:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight, is_gated
            )
            register_scales_for_trtllm_fp8_per_tensor_moe(
                layer,
                layer.w13_weight_scale,
                layer.w13_input_scale,
                layer.w2_weight_scale,
                layer.w2_input_scale,
            )

        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
):
    """Compare the FlashInfer TRT-LLM per-tensor FP8 MoE against ``fused_experts``.

    Both paths see the same FP8 weights, per-tensor scales and router logits;
    ``fused_experts`` consumes the untouched weights and precomputed Llama4
    top-k routing, while the TRT-LLM path consumes the pre-processed layer and
    the raw logits.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")
    set_random_seed(7)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=True, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result on the un-swapped, un-rotated FP8 weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # The TRT-LLM kernel does its own routing from the raw logits.
        flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        check_accuracy(
            ref_output=output,
            actual_output=flashinfer_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
    workspace_init,
):
    """Compare the FlashInfer CUTLASS FP8 MoE (modular kernel) against ``fused_experts``.

    Both paths use the same FP8 weights, per-tensor scales and Llama4 top-k
    routing. The FlashInfer path runs through ``FusedMoEModularKernel`` built
    from ``MoEPrepareAndFinalizeNoEP`` and ``FlashInferExperts``.
    """
    set_random_seed(7)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        # Besides the plain scales, the CUTLASS path takes fused factors:
        # g*_alphas = weight_scale * input_scale, a2_gscale = 1 / a2_scale.
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result on the un-swapped FP8 weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        td.layer.dp_size = 1

        # Stub so the dummy layer can serve as its own quant method and hand
        # back the config built above. (Parameter renamed from `n` so it no
        # longer shadows the intermediate size.)
        def get_fused_moe_quant_config(module: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=torch.bfloat16,
            is_act_and_mul=activation.is_gated,
            routing_method=RoutingMethodType.TopK,
        )

        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            FlashInferExperts(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=False,
        )

        # Run on the layer weights, whose w13 halves were swapped for FlashInfer
        # when the activation is gated.
        flashinfer_cutlass_output = kernel(
            td.hidden_states,
            td.layer.w13_weight,
            td.layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        check_accuracy(
            ref_output=output,
            actual_output=flashinfer_cutlass_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )


@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    """Sanity-check converting bf16 MoE weights to the TRT-LLM block layout.

    Converts random bf16 ``w13`` of shape (num_experts, 2 * intermediate,
    hidden) and ``w2`` of shape (num_experts, hidden, intermediate), then checks
    only coarse properties of the results: 4-D rank, preserved element count,
    unchanged bf16 dtype, and the expert count as the leading dimension. The
    exact element placement of the block layout is not asserted here.
    """
    # Imported inside the test so the helper is only resolved when this runs.
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout,
    )

    w13 = torch.randn(
        (num_experts, 2 * intermediate, hidden), dtype=torch.bfloat16, device="cuda"
    )
    w2 = torch.randn(
        (num_experts, hidden, intermediate), dtype=torch.bfloat16, device="cuda"
    )

    # Empty cache passed to the converter; presumably it memoizes per-shape
    # data across calls (keyed by torch.Size) — TODO confirm against helper.
    cache: dict[torch.Size, torch.Tensor] = {}
    w13_converted, w2_converted = convert_moe_weights_to_flashinfer_trtllm_block_layout(
        cache, w13, w2
    )

    # The block layout is 4-D.
    assert w13_converted.ndim == 4, (
        f"Expected 4D tensor, got shape {w13_converted.shape}"
    )
    assert w2_converted.ndim == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"

    # Re-layout only: no elements added or dropped.
    assert w13_converted.numel() == w13.numel(), "W13 element count should be preserved"
    assert w2_converted.numel() == w2.numel(), "W2 element count should be preserved"

    assert w13_converted.dtype == torch.bfloat16
    assert w2_converted.dtype == torch.bfloat16

    # Experts stay on the leading dimension.
    assert w13_converted.shape[0] == num_experts
    assert w2_converted.shape[0] == num_experts