test_flashinfer.py 13.3 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
12
13
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
14
from vllm.model_executor.layers.fused_moe.config import (
15
16
    FusedMoEConfig,
    FusedMoEParallelConfig,
17
    FusedMoEQuantConfig,
18
    RoutingMethodType,
19
20
    fp8_w8a8_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
22
    TrtLlmFp8ExpertsMonolithic,
23
)
24
25
26
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
27
28
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
29
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
30
31
32
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
33
34
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
35
from vllm.utils.torch_utils import set_random_seed
36
37
38
39
40
41
42
43

try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # Off ROCm a failed import is a real error. Re-raise it so the original
    # ImportError is reported, instead of falling through to a confusing
    # NameError on has_flashinfer_cutlass_fused_moe below.
    raise

# Every test in this module needs the FlashInfer CUTLASS fused-MoE kernels and
# a device with capability >= 90; otherwise skip the whole module.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )
52
53
54
55
56
57
58
59
60
61
62
63
64

# Parametrization grids shared by the MoE tests in this module.
NUM_EXPERTS: list[int] = [16]
TOP_KS: list[int] = [1]

# (m, n, k) problem sizes: m = number of tokens, n = intermediate size per
# partition, k = hidden size -- the order in which the tests pass them to
# TestData.make_moe_tensors_8bit(m, k, n, ...).
MNK_FACTORS: list[tuple[int, int, int]] = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Config installed via set_current_vllm_config() inside each test.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
66
67
68
69
70
71
72
73
74


def quant_fp8_per_tensor_batches(a):
    """Quantize every slice along dim 0 of ``a`` to FP8 with its own scale.

    Each ``a[i]`` is passed separately to ``input_to_float8``. A scale with a
    single element is reshaped to ``(1, 1)`` before stacking.

    Returns:
        ``(quantized, scales)``: the per-slice FP8 tensors and the per-slice
        scales, each stacked along a new leading dimension.
    """

    def _quantize_slice(x):
        x_fp8, x_scale = input_to_float8(x)
        # Give single-element scales an explicit (1, 1) shape so they stack
        # uniformly.
        if x_scale.numel() == 1:
            x_scale = x_scale.view(1, 1)
        return x_fp8, x_scale

    pairs = [_quantize_slice(a[idx]) for idx in range(a.size(0))]
    quantized = torch.stack([q for q, _ in pairs])
    scales = torch.stack([s for _, s in pairs])
    return quantized, scales


86
87
88
89
90
91
92
93
94
95
96
97
98
99
def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
    """Assert that enough elements of ``actual_output`` match ``ref_output``.

    An element matches when ``torch.isclose(ref_output, actual_output,
    atol=atol, rtol=rtol)`` holds for it. The fraction of matching elements
    must be at least ``percent``.

    Raises:
        AssertionError: if the two tensors have different shapes, or if the
            match ratio is below ``percent``.
    """
    # torch.isclose broadcasts its inputs, so a wrong-shaped output could
    # otherwise pass silently; require identical shapes.
    assert ref_output.shape == actual_output.shape, (
        f"Shape mismatch: ref {tuple(ref_output.shape)} vs "
        f"actual {tuple(actual_output.shape)}"
    )
    close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol)
    # Average in float64 so the threshold comparison is not skewed by
    # float32 rounding of the ratio.
    match_ratio = close.double().mean().item()
    assert match_ratio >= percent, (
        f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f}"
    )


100
101
102
103
104
105
106
107
108
109
110
111
@dataclass
class TestData:
    """Random FP8 MoE inputs plus a stub layer for the FlashInfer MoE tests.

    ``w13_quantized`` / ``w2_quantized`` keep the FP8 weights in their
    original layout. ``layer`` holds clones of them: w13 rows are swapped
    for gated activations, and both weights are rotated when ``is_trtllm``
    is requested.
    """

    # Not a test class: tell pytest not to try collecting it. Otherwise the
    # "Test" prefix plus the dataclass __init__ triggers a
    # PytestCollectionWarning.
    __test__ = False

    hidden_states: torch.Tensor  # (m, k) bf16 activations
    w13_quantized: torch.Tensor  # FP8 w13, original (unswapped) layout
    w2_quantized: torch.Tensor  # FP8 w2, original layout
    a1_scale: torch.Tensor  # scale from input_to_float8(hidden_states)
    a2_scale: torch.Tensor  # fixed 1.0 float32 scale
    w13_weight_scale: torch.Tensor  # stacked per-expert w13 scales
    w2_weight_scale: torch.Tensor  # stacked per-expert w2 scales
    layer: torch.nn.Module  # stub layer carrying re-laid-out weights and config

    @staticmethod
    def make_moe_tensors_8bit(
        m: int,
        k: int,
        n: int,
        e: int,
        is_trtllm: bool,
        activation: MoEActivation = MoEActivation.SILU,
        topk: int = 1,
    ) -> "TestData":
        """Build random bf16 MoE tensors, quantize them to FP8, and wrap a layer.

        Args:
            m: number of tokens (rows of ``hidden_states``).
            k: hidden size.
            n: intermediate size per partition.
            e: number of experts.
            is_trtllm: if True, rotate the layer's weights with
                ``rotate_weights_for_fi_trtllm_fp8_per_tensor_moe``.
            activation: MoE activation. Gated activations use a ``2 * n``-row
                w13 and get their w13 rows swapped for FlashInfer.
            topk: experts per token recorded in the layer's ``FusedMoEConfig``.

        Returns:
            A populated ``TestData`` instance.
        """
        is_gated = activation.is_gated

        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13_rows = (2 * n) if is_gated else n
        w13 = torch.randn((e, w13_rows, k), device="cuda", dtype=torch.bfloat16) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10

        # Scale to fp8. Only the activation scale is needed here.
        _, a1_scale = input_to_float8(hidden_states)
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # Clone so that re-laying out the layer's weights below leaves
        # w13_quantized / w2_quantized in their original layout.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        layer.activation = activation
        # Setup dummy config.
        layer.moe_parallel_config = FusedMoEParallelConfig.make_no_parallel()

        # flashinfer expects swapped rows for w13
        if is_gated:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight, is_gated
            )
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        layer.moe = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            moe_parallel_config=layer.moe_parallel_config,
            in_dtype=hidden_states.dtype,
            is_act_and_mul=is_gated,
            routing_method=layer.routing_method_type,
            activation=activation,
            device=w13_quantized.device,
        )

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
):
    """Check the TRT-LLM FP8 per-tensor monolithic MoE kernel against the
    ``fused_experts`` reference on the same FP8 weights and scales."""
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")

    set_random_seed(7)
    with set_current_vllm_config(vllm_config):
        data = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=True, activation=activation
        )

        router_logits = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        # Routing for the reference path. The monolithic kernel is handed the
        # raw logits below instead.
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=data.hidden_states,
            gating_output=router_logits,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=data.w13_weight_scale,
            w2_scale=data.w2_weight_scale,
            a1_scale=data.a1_scale,
            a2_scale=data.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: fused_experts on the original (un-rotated) FP8 weights.
        ref_output = fused_experts(
            data.hidden_states,
            data.w13_quantized,
            data.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        prepare_finalize = maybe_make_prepare_finalize(
            moe=data.layer.moe,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=True,
        )
        experts = TrtLlmFp8ExpertsMonolithic(
            moe_config=data.layer.moe,
            quant_config=quant_config,
        )
        kernel = mk.FusedMoEKernel(prepare_finalize, experts)

        # The FlashInfer path consumes the layer's re-laid-out weights.
        flashinfer_output = kernel.apply_monolithic(
            hidden_states=data.hidden_states,
            w1=data.layer.w13_weight,
            w2=data.layer.w2_weight,
            router_logits=router_logits,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            routed_scaling_factor=1.0,
        )

        check_accuracy(
            ref_output=ref_output,
            actual_output=flashinfer_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )
274
275
276
277
278


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
    workspace_init,
):
    """Check FlashInfer CUTLASS FP8 per-tensor experts against ``fused_experts``.

    Routing is computed once with the Llama4 routing function. The same
    ``topk_weights`` / ``topk_ids`` are fed to both the reference and the
    modular (non-monolithic) FlashInfer kernel.
    """
    set_random_seed(7)
    with set_current_vllm_config(vllm_config):
        data = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation
        )

        router_logits = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=data.hidden_states,
            gating_output=router_logits,
            topk=topk,
            renormalize=False,
        )

        # Besides the plain FP8 scales, pass:
        # - g1/g2 alphas: products of the weight and activation scales;
        # - a1/a2 gscales: the activation scale (a2 as its reciprocal).
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=data.w13_weight_scale,
            g1_alphas=(data.w13_weight_scale * data.a1_scale).squeeze(),
            w2_scale=data.w2_weight_scale,
            g2_alphas=(data.w2_weight_scale * data.a2_scale).squeeze(),
            a1_scale=data.a1_scale,
            a1_gscale=data.a1_scale,
            a2_scale=data.a2_scale,
            a2_gscale=1.0 / data.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: fused_experts on the original-layout FP8 weights.
        ref_output = fused_experts(
            data.hidden_states,
            data.w13_quantized,
            data.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # NOTE(review): the stub-layer attributes set below (dp_size,
        # get_fused_moe_quant_config, quant_method) are not read by anything
        # later in this test; they look like leftovers -- confirm and remove.
        data.layer.dp_size = 1

        def get_fused_moe_quant_config(_layer: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        data.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        data.layer.quant_method = data.layer

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=torch.bfloat16,
            is_act_and_mul=activation.is_gated,
            routing_method=RoutingMethodType.TopK,
        )

        prepare_finalize = maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=False,
        )
        experts = FlashInferExperts(
            moe_config=moe_config,
            quant_config=quant_config,
        )
        kernel = mk.FusedMoEKernel(prepare_finalize, experts, inplace=False)

        # The FlashInfer path consumes the layer's weights (w13 rows swapped
        # for gated activations) with the precomputed routing.
        flashinfer_cutlass_output = kernel.apply(
            data.hidden_states,
            data.layer.w13_weight,
            data.layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        check_accuracy(
            ref_output=ref_output,
            actual_output=flashinfer_cutlass_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426


@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout,
    )

    w13 = torch.randn(
        (num_experts, 2 * intermediate, hidden), dtype=torch.bfloat16, device="cuda"
    )
    w2 = torch.randn(
        (num_experts, hidden, intermediate), dtype=torch.bfloat16, device="cuda"
    )

    cache: dict[torch.Size, torch.Tensor] = {}
    w13_converted, w2_converted = convert_moe_weights_to_flashinfer_trtllm_block_layout(
        cache, w13, w2
    )

    assert w13_converted.ndim == 4, (
        f"Expected 4D tensor, got shape {w13_converted.shape}"
    )
    assert w2_converted.ndim == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"

    assert w13_converted.numel() == w13.numel(), "W13 element count should be preserved"
    assert w2_converted.numel() == w2.numel(), "W2 element count should be preserved"

    assert w13_converted.dtype == torch.bfloat16
    assert w2_converted.dtype == torch.bfloat16

    assert w13_converted.shape[0] == num_experts
    assert w2_converted.shape[0] == num_experts