test_flashinfer.py 11.5 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
10
from vllm.model_executor.layers.fused_moe.config import (
11
12
    FusedMoEConfig,
    FusedMoEParallelConfig,
13
    FusedMoEQuantConfig,
14
    RoutingMethodType,
15
16
    fp8_w8a8_moe_quant_config,
)
17
18
19
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
20
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
21
22
23
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
24
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
25
    apply_fi_trtllm_fp8_per_tensor_moe,
26
    register_scales_for_trtllm_fp8_per_tensor_moe,
27
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
28
29
30
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
31
32
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
33
from vllm.utils.torch_utils import set_random_seed
34
35
36
37
38
39
40
41

# Module-level gating: skip the whole file when FlashInfer's CUTLASS fused-MoE
# path cannot be exercised on this platform.
try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # On any other platform the import is expected to work. Re-raise so the
    # real failure is reported instead of a NameError on the check below.
    raise

if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )
50
51
52
53
54
55
56
57
58
59
60
61
62

# Parametrization grids consumed by the @pytest.mark.parametrize decorators below.
# Expert counts, bound to the "e" test parameter.
NUM_EXPERTS = [16]
# Experts selected per token, bound to the "topk" test parameter.
TOP_KS = [1]

# Problem sizes; each tuple is unpacked as the "m,n,k" test parameters
# (m rows of hidden states, n intermediate size per partition, k hidden dim).
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

63
# Shared config that the kernel tests install with set_current_vllm_config();
# pipeline parallelism is pinned to 1.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94


def quant_fp8_per_tensor_batches(a):
    """Quantize every slice of ``a`` along dim 0 independently to FP8.

    Each slice ``a[i]`` is passed to ``input_to_float8``; the quantized
    tensor is kept as-is and the accompanying scale is replaced by its
    reciprocal.

    Returns:
        A ``(quantized, scales)`` pair where both entries are the per-slice
        results stacked along a new leading dimension.
    """
    per_slice = [input_to_float8(a[idx]) for idx in range(a.size(0))]

    stacked_quant = torch.stack([fp8 for fp8, _ in per_slice])
    # The helper's scale is inverted before being handed back to callers.
    stacked_scales = torch.stack([1.0 / scale for _, scale in per_slice])

    return stacked_quant, stacked_scales


@dataclass
class TestData:
    """Bundle of FP8 MoE inputs shared by the FlashInfer kernel tests.

    Instances are produced by :meth:`make_moe_tensors_8bit`; the tensor
    fields hold the reference (un-swapped) quantized weights and scales,
    while ``layer`` carries the copies prepared for the FlashInfer kernels.
    """

    # The class name starts with "Test" and the dataclass generates an
    # __init__, so pytest would otherwise try to collect it and warn.
    # Unannotated, so this is a plain class attribute, not a dataclass field.
    __test__ = False

    # Random bf16 activations of shape (m, k).
    hidden_states: torch.Tensor
    # Per-expert FP8 weights as returned by quant_fp8_per_tensor_batches.
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    # Activation scales for the first and second GEMM inputs.
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    # Per-expert weight scales matching the quantized weights above.
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    # Stand-in module exposing the attributes the FlashInfer helpers read.
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
    ) -> "TestData":
        """Create random MoE tensors on CUDA, quantize them to FP8, build a layer.

        Args:
            m: Number of rows in the hidden states.
            k: Hidden dimension.
            n: Intermediate size per partition.
            e: Number of experts.
            is_trtllm: When true, additionally rotate the layer weights and
                register the scales through the TRT-LLM per-tensor helpers.
            activation: Activation name; anything other than
                ``"relu2_no_mul"`` is treated as gated, doubling the first
                weight's output dimension.

        Returns:
            A populated :class:`TestData`.
        """
        is_gated = activation != "relu2_no_mul"

        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn(
            (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
        )
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        a1_scale = 1.0 / a1_scale
        a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # Clones keep the reference weights returned in TestData untouched by
        # the layout changes applied to the layer's copies below.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        # Setup dummy config.
        layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()

        # flashinfer expects swapped rows for w13
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight
            )
            register_scales_for_trtllm_fp8_per_tensor_moe(
                layer,
                layer.w13_weight_scale,
                layer.w13_input_scale,
                layer.w2_weight_scale,
                layer.w2_input_scale,
            )
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Compare the FlashInfer TRT-LLM per-tensor FP8 MoE output with fused_experts."""
    sm100_available = current_platform.has_device_capability(100)
    if not sm100_available:
        pytest.skip("Test is only supported for sm >= 100")

    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

    with set_current_vllm_config(vllm_config):
        data = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)

        router_logits = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        routing_weights, routing_ids = Llama4MoE.custom_routing_function(
            hidden_states=data.hidden_states,
            gating_output=router_logits,
            topk=topk,
            renormalize=False,
        )

        fp8_quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=data.w13_weight_scale,
            w2_scale=data.w2_weight_scale,
            a1_scale=data.a1_scale,
            a2_scale=data.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result, computed from the un-swapped quantized weights.
        reference_out = fused_experts(
            data.hidden_states,
            data.w13_quantized,
            data.w2_quantized,
            topk_weights=routing_weights,
            topk_ids=routing_ids,
            inplace=False,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=fp8_quant_config,
        )

        # Result under test; receives the raw router logits and the prepared layer.
        trtllm_out = apply_fi_trtllm_fp8_per_tensor_moe(
            layer=data.layer,
            hidden_states=data.hidden_states,
            router_logits=router_logits,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        tolerances = {"atol": 5.5e-2, "rtol": 1e-2}
        torch.testing.assert_close(reference_out, trtllm_out, **tolerances)
217
218
219
220
221


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: str,
    monkeypatch,
    workspace_init,
):
    """Compare the FlashInfer CUTLASS FP8 MoE modular kernel with fused_experts.

    Runs once per activation: the gated "silu" variant and the non-gated
    "relu2_no_mul" variant.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    assert activation in ["silu", "relu2_no_mul"]
    # Derived the same way make_moe_tensors_8bit decides whether w13 is gated.
    # The previous comparison against "silu_and_mul" could never be true for
    # the activations allowed above, so "silu" was configured as non-gated.
    is_act_and_mul = activation != "relu2_no_mul"
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result, computed from the un-swapped quantized weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        td.layer.dp_size = 1

        # NOTE: the parameter name shadows the outer int `n`; it is unused and
        # kept unchanged so any caller of this hook sees the same signature.
        def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        # Make the stand-in layer act as its own quant method.
        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=torch.bfloat16,
            is_act_and_mul=is_act_and_mul,
            routing_method=RoutingMethodType.TopK,
        )

        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            FlashInferExperts(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=False,
        )

        # The kernel under test consumes the layer's swapped w13 weights.
        flashinfer_cutlass_output = kernel(
            td.hidden_states,
            td.layer.w13_weight,
            td.layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )
        torch.testing.assert_close(
            output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
        )
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361


@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    """Check rank, element count, dtype and leading dim of the converted weights."""
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout as to_block_layout,
    )

    tensor_opts = {"dtype": torch.bfloat16, "device": "cuda"}
    w13_src = torch.randn((num_experts, 2 * intermediate, hidden), **tensor_opts)
    w2_src = torch.randn((num_experts, hidden, intermediate), **tensor_opts)

    layout_cache: dict[torch.Size, torch.Tensor] = {}
    w13_blk, w2_blk = to_block_layout(layout_cache, w13_src, w2_src)

    # Both outputs must be 4-dimensional.
    assert w13_blk.ndim == 4, f"Expected 4D tensor, got shape {w13_blk.shape}"
    assert w2_blk.ndim == 4, f"Expected 4D tensor, got shape {w2_blk.shape}"

    # The conversion must neither add nor drop elements.
    assert w13_blk.numel() == w13_src.numel(), "W13 element count should be preserved"
    assert w2_blk.numel() == w2_src.numel(), "W2 element count should be preserved"

    # The dtype is carried through unchanged.
    assert w13_blk.dtype == torch.bfloat16
    assert w2_blk.dtype == torch.bfloat16

    # The expert dimension stays in front.
    assert w13_blk.shape[0] == num_experts
    assert w2_blk.shape[0] == num_experts