test_flashinfer.py 13.4 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
10
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
11
12
13
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
14
from vllm.model_executor.layers.fused_moe.config import (
15
16
    FusedMoEConfig,
    FusedMoEParallelConfig,
17
    FusedMoEQuantConfig,
18
    RoutingMethodType,
19
20
    fp8_w8a8_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
22
    TrtLlmFp8ExpertsMonolithic,
23
)
24
25
26
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
27
28
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
29
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
30
31
32
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
33
34
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
35
from vllm.utils.torch_utils import set_random_seed
36
37
38
39
40
41
42
43

# FlashInfer is an optional dependency. On ROCm its absence is expected, so
# the whole module is skipped. On any other platform the original ImportError
# is re-raised so the failure names the missing import instead of surfacing
# later as a NameError on `has_flashinfer_cutlass_fused_moe`.
try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    raise

# Skip the module unless the FlashInfer CUTLASS fused-MoE path is available
# and the device reports capability 90 or newer.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )

# Parametrization values shared by the kernel tests below.
NUM_EXPERTS = [16]
TOP_KS = [1]

# Each entry is unpacked by the tests as (m, n, k); `make_moe_tensors_8bit`
# uses m as the number of hidden-state rows, k as the hidden dimension and n
# as the per-partition intermediate size.
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Config installed via `set_current_vllm_config` inside each kernel test.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
66
67
68
69
70
71
72
73
74


def quant_fp8_per_tensor_batches(a):
    """Quantize each slice along dim 0 of ``a`` with ``input_to_float8``.

    Every slice ``a[i]`` is quantized independently, so each one gets its own
    scale.

    Args:
        a: Tensor whose leading dimension enumerates the slices to quantize
            (the callers in this file pass per-expert weight tensors).

    Returns:
        A ``(quantized, scales)`` tuple where both tensors are built with
        ``torch.stack`` over the per-slice results, in slice order.
    """
    a_quant = []
    a_scales = []

    for batch in a.unbind(0):
        a_fp8, a_global_sf = input_to_float8(batch)
        # Normalise a single-element scale to shape (1, 1) before stacking.
        if a_global_sf.numel() == 1:
            a_global_sf = a_global_sf.view(1, 1)
        a_quant.append(a_fp8)
        a_scales.append(a_global_sf)

    result_a_quant = torch.stack(a_quant)
    result_a_scales = torch.stack(a_scales)

    return result_a_quant, result_a_scales


86
87
88
89
90
91
92
93
94
95
96
97
98
99
def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
    """Assert that enough elements of ``actual_output`` match ``ref_output``.

    Elements are compared with ``torch.isclose(ref_output, actual_output)``
    and the fraction of matching elements must be at least ``percent``.

    Args:
        ref_output: Reference tensor.
        actual_output: Tensor under test; must have the same shape as
            ``ref_output``.
        atol: Absolute tolerance forwarded to ``torch.isclose``.
        rtol: Relative tolerance forwarded to ``torch.isclose``.
        percent: Minimum fraction of elements that must be close.

    Raises:
        AssertionError: If the shapes differ or too few elements match.
    """
    # torch.isclose broadcasts its operands, so without this guard tensors of
    # different shapes would be compared silently and could pass.
    assert ref_output.shape == actual_output.shape, (
        f"Shape mismatch: reference {tuple(ref_output.shape)} vs "
        f"actual {tuple(actual_output.shape)}"
    )

    close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol)
    match_ratio = close.float().mean()
    assert match_ratio >= percent, (
        f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f}"
    )

    mismatch_percent = 1.0 - match_ratio.item()
    assert mismatch_percent <= 1 - percent, (
        f"Mismatch percentage {mismatch_percent:.4f} is above the threshold "
        f"{1 - percent:.4f}"
    )


100
101
102
103
104
105
106
107
108
109
110
111
@dataclass
class TestData:
    """Bundle of tensors and a mock layer shared by the FP8 MoE kernel tests."""

    # The class name matches pytest's default ``Test*`` collection pattern;
    # opt out so pytest does not try to collect it (and warn about its
    # generated ``__init__``). Un-annotated, so it is not a dataclass field.
    __test__ = False

    hidden_states: torch.Tensor
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int,
        k: int,
        n: int,
        e: int,
        is_trtllm: bool,
        activation: MoEActivation = MoEActivation.SILU,
        topk: int = 1,
    ) -> "TestData":
        """Create random bf16 MoE inputs, quantize them and build a mock layer.

        Args:
            m: Number of rows in ``hidden_states``.
            k: Hidden dimension.
            n: Per-partition intermediate size.
            e: Number of experts.
            is_trtllm: When true, additionally apply
                ``rotate_weights_for_fi_trtllm_fp8_per_tensor_moe`` to the
                layer's weights.
            activation: Activation; its ``is_gated`` flag selects whether
                ``w13`` has ``2 * n`` or ``n`` rows per expert.
            topk: Stored as ``experts_per_token`` in the layer's
                ``FusedMoEConfig``.

        Returns:
            A ``TestData`` holding the unmodified quantized weights and scales
            plus ``layer``, whose weight copies have been rearranged for
            FlashInfer.
        """
        is_gated = activation.is_gated

        # The order of the randn calls is kept fixed so a seeded run yields
        # the same tensors.
        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = (
            torch.randn(
                (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
            )
            / 10
        )
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # Clone so the in-place rearrangements below leave the returned
        # ``w13_quantized`` / ``w2_quantized`` tensors untouched.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        layer.activation = activation
        # Setup dummy config.
        layer.moe_parallel_config = FusedMoEParallelConfig.make_no_parallel()

        # flashinfer expects swapped rows for w13
        if is_gated:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight, is_gated
            )
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        layer.moe = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            moe_parallel_config=layer.moe_parallel_config,
            in_dtype=hidden_states.dtype,
            is_act_and_mul=is_gated,
            routing_method=layer.routing_method_type,
            activation=activation,
            device=w13_quantized.device,
        )

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
):
    """Compare the monolithic TRT-LLM FP8 experts kernel with ``fused_experts``.

    The reference output comes from ``fused_experts`` on the unmodified
    quantized weights; the output under test comes from a ``FusedMoEKernel``
    wrapping ``TrtLlmFp8ExpertsMonolithic`` on the layer's rearranged weights.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        # Forward topk so the layer's FusedMoEConfig (used by the monolithic
        # kernel) agrees with the routing used for the reference output.
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=True, activation=activation, topk=topk
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        kernel = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=td.layer.moe,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=True,
            ),
            TrtLlmFp8ExpertsMonolithic(
                moe_config=td.layer.moe,
                quant_config=quant_config,
            ),
        )

        # The monolithic path receives the raw router logits rather than
        # precomputed topk weights/ids.
        flashinfer_output = kernel.apply_monolithic(
            hidden_states=td.hidden_states,
            w1=td.layer.w13_weight,
            w2=td.layer.w2_weight,
            router_logits=score,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            routed_scaling_factor=1.0,
        )

        check_accuracy(
            ref_output=output,
            actual_output=flashinfer_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )
275
276
277
278
279


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
    workspace_init,
):
    """Compare the FlashInfer CUTLASS FP8 experts kernel with ``fused_experts``.

    The reference output comes from ``fused_experts`` on the unmodified
    quantized weights; the output under test comes from a ``FusedMoEKernel``
    wrapping ``FlashInferExperts`` on the layer's rearranged weights.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        # Forward topk so the mock layer's config matches the parametrization.
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation, topk=topk
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        td.layer.dp_size = 1

        # Stub that always hands back the quant config built above. The
        # parameter is unused; it is named `_layer` so it does not shadow the
        # test's own `n`.
        def get_fused_moe_quant_config(_layer: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=torch.bfloat16,
            is_act_and_mul=activation.is_gated,
            routing_method=RoutingMethodType.TopK,
        )

        kernel = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            FlashInferExperts(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=False,
        )

        flashinfer_cutlass_output = kernel.apply(
            td.hidden_states,
            td.layer.w13_weight,
            td.layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        check_accuracy(
            ref_output=output,
            actual_output=flashinfer_cutlass_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428


@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    """Check basic invariants of the TRT-LLM block-layout weight conversion.

    The converted w13/w2 tensors must be 4-D, keep their element count and
    bf16 dtype, and keep the expert count as their leading dimension.
    """
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout as to_block_layout,
    )

    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    # w13 is drawn before w2, matching the original random-draw order.
    w13 = torch.randn((num_experts, 2 * intermediate, hidden), **tensor_opts)
    w2 = torch.randn((num_experts, hidden, intermediate), **tensor_opts)

    layout_cache: dict[torch.Size, torch.Tensor] = {}
    w13_blocked, w2_blocked = to_block_layout(layout_cache, w13, w2)

    # Rank of the converted tensors.
    assert w13_blocked.ndim == 4, f"Expected 4D tensor, got shape {w13_blocked.shape}"
    assert w2_blocked.ndim == 4, f"Expected 4D tensor, got shape {w2_blocked.shape}"

    # The conversion must neither drop nor pad elements.
    assert w13_blocked.numel() == w13.numel(), "W13 element count should be preserved"
    assert w2_blocked.numel() == w2.numel(), "W2 element count should be preserved"

    # dtype is unchanged.
    assert w13_blocked.dtype == torch.bfloat16
    assert w2_blocked.dtype == torch.bfloat16

    # Experts stay on the leading dimension.
    assert w13_blocked.shape[0] == num_experts
    assert w2_blocked.shape[0] == num_experts