# test_flashinfer.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
    RoutingMethodType,
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_fi_trtllm_fp8_per_tensor_moe,
    register_scales_for_trtllm_fp8_per_tensor_moe,
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

# Module-level guards: skip the whole file when FlashInfer is unavailable or
# the device does not meet the minimum compute capability.
try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # pytest.skip() raises, so this line is only reached on non-ROCm
    # platforms. Re-raise the original ImportError instead of falling through
    # to a NameError on `has_flashinfer_cutlass_fused_moe` below.
    raise

if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )

# Parametrization tables shared by the FlashInfer MoE tests in this file.
NUM_EXPERTS = [16]
TOP_KS = [1]

# Each entry is unpacked as (m, n, k) by the "m,n,k" parametrize decorators.
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Config installed via `set_current_vllm_config` while each test body runs.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))


def quant_fp8_per_tensor_batches(a):
    """Quantize every slice of ``a`` along dim 0 independently.

    Each ``a[idx]`` is passed to ``input_to_float8``; the first value it
    returns is kept as-is and the second is inverted (``1.0 / value``).

    Returns:
        A pair ``(quantized, scales)`` where both entries are built with
        ``torch.stack`` over the per-slice results, in slice order.
    """
    per_slice = [input_to_float8(a[idx]) for idx in range(a.size(0))]

    fp8_slices = [fp8 for fp8, _ in per_slice]
    # Invert the second return value of input_to_float8 for every slice.
    inverted_sfs = [1.0 / global_sf for _, global_sf in per_slice]

    return torch.stack(fp8_slices), torch.stack(inverted_sfs)


@dataclass
class TestData:
    """Bundle of tensors and a dummy layer shared by the FlashInfer MoE tests."""

    # The name matches pytest's ``Test*`` collection pattern, but this is a
    # plain data container; opt out so pytest does not try to collect it
    # (and warn about the dataclass-generated ``__init__``).
    __test__ = False

    # bf16 activations of shape (m, k).
    hidden_states: torch.Tensor
    # Per-expert quantized weights, as returned by quant_fp8_per_tensor_batches.
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    # Activation scales for the first / second expert GEMM.
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    # Per-expert weight scales, stacked along dim 0.
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    # Bare torch.nn.Module carrying the attributes the FlashInfer paths read.
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int,
        k: int,
        n: int,
        e: int,
        is_trtllm: bool,
        activation: MoEActivation = MoEActivation.SILU,
    ) -> "TestData":
        """Create random FP8 MoE inputs plus a dummy layer on the CUDA device.

        Args:
            m: Number of rows of ``hidden_states``.
            k: Hidden size (last dim of ``hidden_states`` and of ``w13``).
            n: Intermediate size; ``w13`` has ``2 * n`` rows per expert when
                the activation is gated, otherwise ``n``.
            e: Number of experts.
            is_trtllm: When true, additionally rotate the layer weights and
                register the scales through the TRT-LLM per-tensor helpers.
            activation: Activation whose ``is_gated`` flag sizes ``w13``.

        Returns:
            A ``TestData`` holding the unmodified quantized weights and scales
            together with the prepared ``layer``.
        """
        is_gated = activation.is_gated

        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn(
            (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
        )
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        a1_scale = 1.0 / a1_scale
        # 0-dim float32 scale of 1.0, created directly on the device.
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # Clone so in-place layout changes on the layer do not touch the
        # reference weights returned in TestData.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        # Setup dummy config.
        layer.moe_parallel_config = mk.FusedMoEParallelConfig.make_no_parallel()

        # flashinfer expects swapped rows for w13
        # NOTE(review): the swap is applied regardless of `is_gated`; confirm
        # this is intended for non-gated activations.
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight
            )
            register_scales_for_trtllm_fp8_per_tensor_moe(
                layer,
                layer.w13_weight_scale,
                layer.w13_input_scale,
                layer.w2_weight_scale,
                layer.w2_input_scale,
            )
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Check the FlashInfer TRT-LLM per-tensor FP8 MoE against ``fused_experts``.

    Both paths consume the same random inputs and routing scores; their
    outputs must agree within ``atol=5.5e-2`` / ``rtol=1e-2``. Skipped on
    devices below compute capability 100.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        # Note the argument order: the factory takes (m, k, n, e).
        td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result from the generic fused_experts path, using the
        # unmodified quantized weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=MoEActivation.SILU,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # Path under test: it receives the raw router scores and the prepared
        # layer rather than precomputed top-k results.
        flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: MoEActivation,
    monkeypatch,
    workspace_init,
):
    """Check ``FlashInferExperts`` in a modular kernel against ``fused_experts``.

    Both paths consume the same random inputs, routing results and quant
    config; their outputs must agree within ``atol=5.5e-2`` / ``rtol=1e-2``.
    ``workspace_init`` is requested as a fixture and not referenced in the body.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        # Note the argument order: the factory takes (m, k, n, e).
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result from the generic fused_experts path, using the
        # unmodified quantized weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        td.layer.dp_size = 1

        # The parameter is named `layer` so it does not shadow the test's `n`
        # (intermediate size); the argument itself is ignored.
        def get_fused_moe_quant_config(layer: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        # Make the dummy layer act as its own quant method.
        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=torch.bfloat16,
            is_act_and_mul=activation.is_gated,
            routing_method=RoutingMethodType.TopK,
        )

        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            FlashInferExperts(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
            inplace=False,
        )

        # Path under test: uses the layer's weights, which were passed through
        # swap_w13_to_w31 when the test data was built.
        flashinfer_cutlass_output = kernel(
            td.hidden_states,
            td.layer.w13_weight,
            td.layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )
        torch.testing.assert_close(
            output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
        )


@pytest.mark.parametrize(
    "num_experts,intermediate,hidden",
    [
        (8, 2048, 1536),
        (64, 4096, 4096),
    ],
)
def test_convert_moe_weights_to_flashinfer_trtllm_block_layout(
    num_experts, intermediate, hidden
):
    """Sanity-check the outputs of the TRT-LLM block-layout weight conversion.

    For random bf16 ``w13`` / ``w2`` inputs the converted tensors must be 4-D,
    keep the element count and dtype, and keep ``num_experts`` as dim 0.
    """
    from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
        convert_moe_weights_to_flashinfer_trtllm_block_layout,
    )

    w13_shape = (num_experts, 2 * intermediate, hidden)
    w2_shape = (num_experts, hidden, intermediate)
    w13 = torch.randn(w13_shape, dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn(w2_shape, dtype=torch.bfloat16, device="cuda")

    cache: dict[torch.Size, torch.Tensor] = {}
    w13_converted, w2_converted = convert_moe_weights_to_flashinfer_trtllm_block_layout(
        cache, w13, w2
    )

    # label -> (converted tensor, source tensor); insertion order keeps the
    # W13 check ahead of the W2 check for every property below.
    pairs = {"W13": (w13_converted, w13), "W2": (w2_converted, w2)}

    for converted, _ in pairs.values():
        assert converted.ndim == 4, f"Expected 4D tensor, got shape {converted.shape}"

    for label, (converted, source) in pairs.items():
        assert converted.numel() == source.numel(), (
            f"{label} element count should be preserved"
        )

    for converted, _ in pairs.values():
        assert converted.dtype == torch.bfloat16

    for converted, _ in pairs.values():
        assert converted.shape[0] == num_experts