test_flashinfer.py 10.2 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
10
from vllm.model_executor.layers.fused_moe.config import (
11
12
    FusedMoEConfig,
    FusedMoEParallelConfig,
13
    FusedMoEQuantConfig,
14
    RoutingMethodType,
15
16
    fp8_w8a8_moe_quant_config,
)
17
18
19
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
20
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
21
22
23
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
24
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
25
    apply_fi_trtllm_fp8_per_tensor_moe,
26
    register_scales_for_trtllm_fp8_per_tensor_moe,
27
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
28
29
30
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
31
32
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
33
from vllm.utils.torch_utils import set_random_seed
34
35
36
37
38
39
40
41

# FlashInfer is not supported for vLLM on ROCm: skip the whole module there.
try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # On any other platform a failing import is a genuine error. Re-raise it
    # here instead of falling through to a NameError on the check below.
    raise

# Skip the module unless has_flashinfer_cutlass_fused_moe() reports support
# and the device has capability >= 90.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )

# Parametrization grids shared by the tests in this module.
NUM_EXPERTS = [16]
TOP_KS = [1]

# (m, n, k) triples: m tokens, intermediate size n, hidden size k
# (matching the shapes built in TestData.make_moe_tensors_8bit).
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Minimal config made current via set_current_vllm_config in each test.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))


def quant_fp8_per_tensor_batches(a):
    """Quantize each slice ``a[i]`` along dim 0 to FP8 with its own scale.

    Returns:
        A pair ``(quantized, scales)``. ``quantized`` holds the per-slice FP8
        tensors from ``input_to_float8``, stacked along a new dim 0.
        ``scales`` holds the reciprocal of each per-slice scale that
        ``input_to_float8`` returned, stacked the same way.
    """
    per_slice = [input_to_float8(a[idx]) for idx in range(a.size(0))]
    quantized = torch.stack([fp8 for fp8, _ in per_slice])
    # Store the reciprocal of the scale reported by input_to_float8.
    scales = torch.stack([1.0 / scale for _, scale in per_slice])
    return quantized, scales


@dataclass
class TestData:
    """FP8-quantized MoE inputs plus a mock layer prepared for FlashInfer.

    The ``*_quantized`` tensors keep the plain layout used by the reference
    ``fused_experts`` path. ``layer`` holds copies rearranged for the
    FlashInfer kernels.
    """

    # The class name starts with "Test" only by coincidence. Opt out of
    # pytest collection, which would otherwise warn about the dataclass
    # __init__.
    __test__ = False

    hidden_states: torch.Tensor
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
    ) -> "TestData":
        """Build random bf16 MoE tensors, quantize them to FP8, and wire a layer.

        Args:
            m: Number of tokens (rows of ``hidden_states``).
            k: Hidden size.
            n: Intermediate size per expert.
            e: Number of experts.
            is_trtllm: If True, also rotate the layer weights and register
                scales for the FlashInfer TRT-LLM per-tensor path.
            activation: ``"relu2_no_mul"`` is treated as non-gated, so w13
                has ``n`` rows. Anything else is treated as gated, so w13
                has ``2 * n`` rows.

        Returns:
            A populated ``TestData``.
        """
        is_gated = activation != "relu2_no_mul"
        w13_rows = 2 * n if is_gated else n

        # Keep the original RNG draw order: hidden_states, w13, then w2.
        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn((e, w13_rows, k), device="cuda", dtype=torch.bfloat16)
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Scale to fp8. Only the activation scale is needed; like
        # quant_fp8_per_tensor_batches, keep its reciprocal.
        _, a1_scale = input_to_float8(hidden_states)
        a1_scale = 1.0 / a1_scale
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # Clone so that the layer-side layout changes below (the rotate_*
        # helper is called for its side effect, presumably in place) cannot
        # touch the reference tensors returned in TestData.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        # Setup dummy config.
        layer.moe_parallel_config = FusedMoEParallelConfig.make_no_parallel()

        # flashinfer expects swapped rows for a gated w13. A non-gated w13 has
        # no gate/up halves to swap; swapping it would presumably just permute
        # its intermediate rows relative to w2. NOTE(review): confirm against
        # swap_w13_to_w31.
        if is_gated:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight
            )
            register_scales_for_trtllm_fp8_per_tensor_moe(
                layer,
                layer.w13_weight_scale,
                layer.w13_input_scale,
                layer.w2_weight_scale,
                layer.w2_input_scale,
            )
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.routing_method_type = RoutingMethodType.Llama4
        layer.renormalize = False
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Compare the FlashInfer TRT-LLM per-tensor FP8 MoE path to fused_experts.

    Both paths get the same FP8 weights and scales and the same Llama4
    routing, with router weights applied on the input. Requires SM >= 100.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: unswapped/unrotated quantized weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # Candidate: routes from raw logits using the layer prepared for TRT-LLM.
        flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: str,
    monkeypatch,
    workspace_init,
):
    """Compare FlashInferExperts (CUTLASS FP8 MoE) against fused_experts.

    The candidate runs through a FusedMoEModularKernel built from
    MoEPrepareAndFinalizeNoEP and FlashInferExperts, using the layer's
    FlashInfer-layout weights. The reference uses the plain quantized weights.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")

    assert activation in ["silu", "relu2_no_mul"]
    # "silu" is gated (act-and-mul, w13 has 2 * n rows); "relu2_no_mul" is not.
    # This mirrors is_gated in TestData.make_moe_tensors_8bit.
    is_act_and_mul = activation != "relu2_no_mul"
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: plain-layout quantized weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        td.layer.dp_size = 1

        # Let the mock layer act as its own quant method, always reporting
        # the quant config built above.
        def get_fused_moe_quant_config(layer: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=torch.bfloat16,
            is_act_and_mul=is_act_and_mul,
            routing_method=RoutingMethodType.TopK,
        )

        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            FlashInferExperts(
                moe_config=moe_config,
                quant_config=quant_config,
            ),
        )

        # Candidate: FlashInfer-layout weights held on the mock layer.
        flashinfer_cutlass_output = kernel(
            td.hidden_states,
            td.layer.w13_weight,
            td.layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(
            output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
        )