# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import (
    TrtLlmFp8ExpertsMonolithic,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import set_random_seed

try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # Off ROCm a failed import is a genuine environment problem: surface it
    # instead of falling through to a NameError on the check below.
    raise

# Every test in this module needs FlashInfer's CUTLASS fused MoE and sm >= 90.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Requires FlashInfer CUTLASS fused MoE support and sm >= 90",
        allow_module_level=True,
    )

# Parametrization grids for the MoE accuracy tests.
NUM_EXPERTS = [16]
TOP_KS = [1]

# (m, n, k) = (num tokens, intermediate size per partition, hidden size).
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Single-stage (no pipeline parallelism) config installed around each test body.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))


def quant_fp8_per_tensor_batches(a):
    """Quantize every slice along dim 0 of ``a`` to FP8 with its own scale.

    Args:
        a: Tensor whose first dimension indexes independent batches (for the
            MoE tests: experts). Each ``a[i]`` is quantized separately with
            ``input_to_float8``.

    Returns:
        ``(a_quant, a_scales)``. ``a_quant`` stacks the quantized slices
        along dim 0. ``a_scales`` stacks the per-slice scales. Any
        single-element scale is first reshaped to ``(1, 1)``, so per-tensor
        scales yield ``a_scales`` of shape ``(num_batches, 1, 1)``.
    """
    quantized_slices = []
    slice_scales = []

    for batch in a:  # iterating a tensor walks its first dimension
        batch_fp8, batch_scale = input_to_float8(batch)
        # Give scalar scales a uniform 2-D shape before stacking.
        if batch_scale.numel() == 1:
            batch_scale = batch_scale.view(1, 1)
        quantized_slices.append(batch_fp8)
        slice_scales.append(batch_scale)

    return torch.stack(quantized_slices), torch.stack(slice_scales)


def check_accuracy(ref_output, actual_output, atol=0.1, rtol=0.85, percent=0.925):
    """Assert that a large enough fraction of ``actual_output`` matches the reference.

    Elements are compared with ``torch.isclose``, i.e.
    ``|ref - actual| <= atol + rtol * |actual|``. At least ``percent`` of all
    elements must be close.

    Args:
        ref_output: Reference tensor.
        actual_output: Tensor under test; must have the same shape.
        atol: Absolute tolerance passed to ``torch.isclose``.
        rtol: Relative tolerance passed to ``torch.isclose``.
        percent: Minimum required fraction of close elements, in [0, 1].

    Raises:
        AssertionError: If the shapes differ or the match ratio is below
            ``percent``.
    """
    # torch.isclose broadcasts, so a wrongly shaped output could otherwise pass.
    assert ref_output.shape == actual_output.shape, (
        f"Shape mismatch: ref {tuple(ref_output.shape)} vs "
        f"actual {tuple(actual_output.shape)}"
    )
    close = torch.isclose(ref_output, actual_output, atol=atol, rtol=rtol)
    match_ratio = close.float().mean().item()
    # One check is enough: "mismatch <= 1 - percent" is the same condition.
    assert match_ratio >= percent, (
        f"Match ratio {match_ratio:.4f} is below the threshold {percent:.4f} "
        f"(mismatch {1.0 - match_ratio:.4f} exceeds allowed {1 - percent:.4f})"
    )


@dataclass
class TestData:
    """FP8-quantized MoE inputs plus a mock layer for FlashInfer kernel tests.

    ``hidden_states`` and the quantized weights in their original row order
    feed the reference ``fused_experts`` path. ``layer`` carries cloned
    weights, reordered for FlashInfer, plus the attributes and
    ``FusedMoEConfig`` the kernel setup reads.
    """

    # Despite the ``Test`` prefix this is not a test class. Without this flag
    # pytest tries to collect it and warns because dataclasses define __init__.
    __test__ = False

    hidden_states: torch.Tensor  # (m, k) bf16 activations
    w13_quantized: torch.Tensor  # per-expert FP8 w13, original row order
    w2_quantized: torch.Tensor  # per-expert FP8 w2
    a1_scale: torch.Tensor  # activation scale computed from hidden_states
    a2_scale: torch.Tensor  # activation scale for the second GEMM (fixed 1.0)
    w13_weight_scale: torch.Tensor  # per-expert scales for w13_quantized
    w2_weight_scale: torch.Tensor  # per-expert scales for w2_quantized
    layer: torch.nn.Module  # mock MoE layer holding FlashInfer-ordered weights

    @staticmethod
    def make_moe_tensors_8bit(
        m: int,
        k: int,
        n: int,
        e: int,
        is_trtllm: bool,
        activation: MoEActivation = MoEActivation.SILU,
        topk: int = 1,
    ) -> "TestData":
        """Create random bf16 MoE tensors, quantize them to FP8 and build a layer.

        Args:
            m: Number of tokens.
            k: Hidden size.
            n: Intermediate size per partition.
            e: Number of experts.
            is_trtllm: If True, also rotate the layer's weights with
                ``rotate_weights_for_fi_trtllm_fp8_per_tensor_moe`` for the
                TRT-LLM per-tensor kernel.
            activation: MoE activation. Gated activations give w13 ``2 * n``
                rows and have the layer copy's halves swapped (w13 -> w31).
            topk: Experts per token recorded in the ``FusedMoEConfig``.

        Returns:
            A populated ``TestData``.
        """
        is_gated = activation.is_gated
        w13_rows = 2 * n if is_gated else n

        # Draw order (activations, w13, w2) is kept stable for reproducibility.
        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn((e, w13_rows, k), device="cuda", dtype=torch.bfloat16) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10

        # Scale to fp8: a1 derived from the activations, a2 fixed to 1.0.
        _, a1_scale = input_to_float8(hidden_states)
        a2_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # Clones, so the FlashInfer reorderings below leave the reference
        # weights returned in TestData untouched.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        layer.activation = activation
        # Setup dummy config.
        layer.moe_parallel_config = FusedMoEParallelConfig.make_no_parallel()

        # flashinfer expects swapped rows for w13
        if is_gated:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            # Return value unused; presumably rotates in place — confirm.
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight, is_gated
            )

        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        layer.moe = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            moe_parallel_config=layer.moe_parallel_config,
            in_dtype=hidden_states.dtype,
            is_act_and_mul=is_gated,
            routing_method=layer.routing_method_type,
            activation=activation,
            device=w13_quantized.device,
            max_num_tokens=next_power_of_2(m),
        )

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
):
    """Compare FlashInfer's TRT-LLM monolithic FP8 MoE against ``fused_experts``.

    Both paths use the same FP8 weights and per-tensor scales. The reference
    consumes Llama4 routing results computed here. The monolithic kernel is
    handed the raw router logits (``score``) instead of precomputed top-k ids.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")
    set_random_seed(7)
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=True, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: generic fused_experts on the un-reordered FP8 weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # Kernel under test: monolithic TRT-LLM experts on the layer's
        # swapped/rotated weights.
        kernel = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=td.layer.moe,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=True,
            ),
            TrtLlmFp8ExpertsMonolithic(
                moe_config=td.layer.moe,
                quant_config=quant_config,
            ),
        )

        flashinfer_output = kernel.apply_monolithic(
            hidden_states=td.hidden_states,
            w1=td.layer.w13_weight,
            w2=td.layer.w2_weight,
            router_logits=score,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            routed_scaling_factor=1.0,
        )

        check_accuracy(
            ref_output=output,
            actual_output=flashinfer_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
    workspace_init,
):
    """Compare FlashInfer's CUTLASS FP8 MoE experts against ``fused_experts``.

    Routing is computed once with the Llama4 routing function. The same
    ``topk_weights``/``topk_ids`` feed both the reference and the modular
    FlashInfer CUTLASS kernel.
    """
    set_random_seed(7)
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        # Besides the per-tensor scales, this config also passes the fused
        # GEMM alphas (weight scale * input scale) and global activation scales.
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: generic fused_experts on the un-reordered FP8 weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # NOTE(review): nothing below reads these layer attributes (only the
        # layer's weights are used); likely leftovers — confirm before removing.
        td.layer.dp_size = 1

        def get_fused_moe_quant_config(module: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=torch.bfloat16,
            is_act_and_mul=activation.is_gated,
            routing_method=RoutingMethodType.TopK,
            max_num_tokens=next_power_of_2(m),
        )

        # Kernel under test: modular (non-monolithic) FlashInfer CUTLASS experts
        # on the layer's w13 -> w31 swapped weights.
        kernel = mk.FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            FlashInferExperts(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=False,
        )

        flashinfer_cutlass_output = kernel.apply(
            td.hidden_states,
            td.layer.w13_weight,
            td.layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        check_accuracy(
            ref_output=output,
            actual_output=flashinfer_cutlass_output,
            atol=0.1,
            rtol=0.85,
            percent=0.925,
        )


@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts: int, intermediate: int, hidden: int
):
    """Check structural invariants of the TRT-LLM block-layout weight conversion.

    Random bf16 weights are converted: w13 of shape
    ``(num_experts, 2 * intermediate, hidden)`` and w2 of shape
    ``(num_experts, hidden, intermediate)``. The test asserts that both
    outputs are 4-D, keep their element count and bf16 dtype, and keep
    experts on dim 0. Element values are not compared.
    """
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout,
    )

    w13 = torch.randn(
        (num_experts, 2 * intermediate, hidden), dtype=torch.bfloat16, device="cuda"
    )
    w2 = torch.randn(
        (num_experts, hidden, intermediate), dtype=torch.bfloat16, device="cuda"
    )

    # Fresh, empty cache passed to the converter. NOTE(review): keyed by
    # torch.Size per the annotation; its exact use is internal to the helper.
    cache: dict[torch.Size, torch.Tensor] = {}
    w13_converted, w2_converted = convert_moe_weights_to_flashinfer_trtllm_block_layout(
        cache, w13, w2
    )

    # Block layout adds a dimension: both outputs must be 4-D.
    assert w13_converted.ndim == 4, (
        f"Expected 4D tensor, got shape {w13_converted.shape}"
    )
    assert w2_converted.ndim == 4, f"Expected 4D tensor, got shape {w2_converted.shape}"

    # The conversion only rearranges data; no elements are added or dropped.
    assert w13_converted.numel() == w13.numel(), "W13 element count should be preserved"
    assert w2_converted.numel() == w2.numel(), "W2 element count should be preserved"

    assert w13_converted.dtype == torch.bfloat16
    assert w2_converted.dtype == torch.bfloat16

    # Experts stay on the leading dimension.
    assert w13_converted.shape[0] == num_experts
    assert w2_converted.shape[0] == num_experts