test_flashinfer.py 10.1 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

8
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
9
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
10
11
12
13
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
14
15
16
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    FlashInferExperts,
)
17
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
18
19
20
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
21
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
22
    apply_fi_trtllm_fp8_per_tensor_moe,
23
    register_scales_for_trtllm_fp8_per_tensor_moe,
24
    rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
25
26
27
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
28
29
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
30
from vllm.utils.torch_utils import set_random_seed
31
32
33
34
35
36
37
38

# The FlashInfer helper is imported defensively: on ROCm a failed import skips
# the whole module.
try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # On any other platform re-raise, so the failure is reported as the real
    # ImportError instead of a NameError on the check below.
    raise

# Module-level gate: every test in this file is skipped unless the FlashInfer
# CUTLASS fused-MoE check passes and the device reports capability 90.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )
47
48
49
50
51
52
53
54
55
56
57
58
59

# Values fed to the "e" and "topk" parametrizations of the tests below.
NUM_EXPERTS = [16]
TOP_KS = [1]

# (m, n, k) triples fed to the "m,n,k" parametrization of the tests below.
MNK_FACTORS = [
    (256, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Config installed with set_current_vllm_config() around each test body.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91


def quant_fp8_per_tensor_batches(a):
    """Quantize every slice along dim 0 of ``a`` with ``input_to_float8``.

    Each ``a[i]`` is passed to ``input_to_float8`` on its own, and the second
    value it returns is inverted (``1.0 / value``) before being collected.

    Returns:
        A ``(quantized, scales)`` pair; both are built with ``torch.stack``
        over the per-slice results, in slice order.
    """
    per_slice = [input_to_float8(a[idx]) for idx in range(a.size(0))]

    stacked_quant = torch.stack([fp8 for fp8, _ in per_slice])
    stacked_scales = torch.stack([1.0 / global_sf for _, global_sf in per_slice])

    return stacked_quant, stacked_scales


@dataclass
class TestData:
    """Tensors and a stand-in layer shared by the FlashInfer FP8 MoE tests.

    Instances are normally built with :meth:`make_moe_tensors_8bit`, which
    creates random inputs/weights on CUDA, quantizes the weights with
    ``quant_fp8_per_tensor_batches`` and populates a bare ``torch.nn.Module``
    with the attributes the FlashInfer code paths read.
    """

    # The class name starts with "Test", so pytest would try to collect it and
    # warn about the generated __init__; mark it as a non-test helper.
    __test__ = False

    hidden_states: torch.Tensor
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
    ) -> "TestData":
        """Create random FP8 MoE test tensors and a configured dummy layer.

        Args:
            m: Number of rows of ``hidden_states``.
            k: Size of the last dim of ``hidden_states`` and of ``w13``.
            n: Size of the last dim of ``w2``; ``w13`` has ``2 * n`` rows per
                expert unless ``activation`` is ``"relu2_no_mul"``, in which
                case it has ``n``.
            e: Number of experts (leading dim of both weight tensors).
            is_trtllm: When true, additionally rotate the layer weights and
                register the scales on the layer via the TRT-LLM helpers.
            activation: Activation name; only used to decide whether the
                w13 weight is created with ``2 * n`` or ``n`` rows.

        Returns:
            A ``TestData`` holding the bf16 hidden states, the quantized
            weights and scales, and the populated layer.
        """
        is_gated = activation != "relu2_no_mul"

        # Keep the randn calls in this order so a fixed seed reproduces the
        # same tensors.
        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn(
            (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
        )
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        a1_scale = 1.0 / a1_scale
        a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.orig_dtype = torch.bfloat16
        # The layer gets clones so the transformations below do not touch the
        # w13_quantized / w2_quantized tensors returned in TestData.
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale
        # Setup dummy config.
        layer.moe_parallel_config = mk.FusedMoEParallelConfig(
            tp_size=1,
            pcp_size=1,
            dp_size=1,
            ep_size=1,
            tp_rank=0,
            pcp_rank=0,
            dp_rank=0,
            ep_rank=0,
            use_ep=False,
            all2all_backend="naive",
        )

        # flashinfer expects swapped rows for w13
        # NOTE(review): the swap is applied even when is_gated is False --
        # confirm that is intended for non-gated activations.
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if is_trtllm:
            rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
                layer.w13_weight, layer.w2_weight
            )
            register_scales_for_trtllm_fp8_per_tensor_moe(
                layer,
                layer.w13_weight_scale,
                layer.w13_input_scale,
                layer.w2_weight_scale,
                layer.w2_input_scale,
            )
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Check ``apply_fi_trtllm_fp8_per_tensor_moe`` against ``fused_experts``.

    Both paths receive the same hidden states and router scores; the
    reference path gets the un-rotated quantized weights from ``TestData``,
    while the FlashInfer path reads the prepared weights from ``td.layer``.
    Skipped unless the device reports capability 100.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, is_trtllm=True)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        # Routing for the reference path; the FlashInfer call below is given
        # the raw scores instead.
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference output.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # Output under test.
        flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
223
224
225
226
227


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: str,
    monkeypatch,
    workspace_init,
):
    """Check the ``FlashInferExperts`` modular kernel against ``fused_experts``.

    Both paths share one quant config, the same hidden states and the same
    routing result; the reference path gets the quantized weights from
    ``TestData`` and the kernel under test gets the prepared weights from
    ``td.layer``.  ``workspace_init`` is requested as a fixture only; it is
    not referenced in the body.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, is_trtllm=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = Llama4MoE.custom_routing_function(
            hidden_states=td.hidden_states,
            gating_output=score,
            topk=topk,
            renormalize=False,
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference output.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        td.layer.dp_size = 1

        # The dummy layer doubles as its own quant_method and always hands
        # back the config built above.  The argument is unused; it is named
        # ``layer`` so it does not shadow the test's ``n`` parameter.
        def get_fused_moe_quant_config(layer: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(
                defer_input_quant=quant_config.is_block_quantized
            ),
            FlashInferExperts(
                out_dtype=td.layer.orig_dtype,
                quant_config=quant_config,
                ep_rank=td.layer.moe_parallel_config.ep_rank,
                ep_size=td.layer.moe_parallel_config.ep_size,
                tp_rank=td.layer.moe_parallel_config.tp_rank,
                tp_size=td.layer.moe_parallel_config.tp_size,
                use_dp=False,
                use_deepseek_fp8_block_scale=False,
            ),
        )

        # Output under test.
        flashinfer_cutlass_output = kernel(
            td.hidden_states,
            td.layer.w13_weight,
            td.layer.w2_weight,
            topk_weights,
            topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )
        torch.testing.assert_close(
            output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
        )