# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_flashinfer_per_tensor_scale_fp8,
    flashinfer_cutlass_moe_fp8,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform

try:
    from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
except ImportError:
    # A missing FlashInfer helper is expected on ROCm: skip the whole module.
    if current_platform.is_rocm():
        pytest.skip(
            "flashinfer not supported for vLLM on ROCm", allow_module_level=True
        )
    # Elsewhere the import is expected to succeed. Re-raise so the real
    # ImportError is reported, rather than a NameError on the check below.
    raise

# Skip unless the FlashInfer CUTLASS fused-MoE path is available and the device
# reports compute capability >= 90.
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    90
):
    pytest.skip(
        "Supported for sm >= 90",
        allow_module_level=True,
    )

# Parametrization grids shared by the tests in this module.
NUM_EXPERTS = [16]  # values for ``e``, the number of experts
TOP_KS = [1]  # values for ``topk``, experts selected per row

# (m, n, k) triples in parametrize order "m,n,k": ``m`` rows of hidden states,
# ``n`` intermediate size per expert, ``k`` hidden size (fixed at 5120 here).
MNK_FACTORS = [(256, 8192, 5120), (127, 4096, 5120)] + [
    (rows, inter, 5120) for rows in (10, 1) for inter in (8192, 4096)
]

# Shared config (pipeline_parallel_size=1) that each test installs with
# ``set_current_vllm_config`` before building its MoE inputs.
vllm_config = VllmConfig(
    parallel_config=ParallelConfig(pipeline_parallel_size=1),
)


def quant_fp8_per_tensor_batches(a):
    """Quantize each slice ``a[i]`` along dim 0 to fp8 with its own scale.

    Each slice is passed separately to ``input_to_float8``. The scale that
    call returns is inverted (``1.0 / scale``) before it is kept.

    Returns:
        ``(quantized, scales)``: the per-slice fp8 tensors and the inverted
        per-slice scales, each stacked along a new leading dimension.
    """
    per_slice = [input_to_float8(a[idx]) for idx in range(a.size(0))]
    quantized = torch.stack([q for q, _ in per_slice])
    scales = torch.stack([1.0 / sf for _, sf in per_slice])
    return quantized, scales


@dataclass
class TestData:
    """Quantized fp8 MoE inputs plus a stand-in layer for the kernels under test.

    The ``*_quantized`` weights are kept in their original layout for the
    reference ``fused_experts`` path. ``layer`` holds cloned weights that
    have been rearranged for FlashInfer.
    """

    # The name matches pytest's ``Test*`` pattern, but this is not a test
    # class. Opt out of collection so pytest does not warn about the
    # dataclass-generated ``__init__``. Not annotated, so not a dataclass field.
    __test__ = False

    hidden_states: torch.Tensor
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, reorder: bool, activation: str = "silu"
    ) -> "TestData":
        """Build random bf16 MoE tensors, quantize them to fp8, and wrap a layer.

        Args:
            m: Number of rows of ``hidden_states``.
            k: Hidden size (last dim of ``hidden_states``; rows of ``w2``).
            n: Intermediate size per expert.
            e: Number of experts.
            reorder: If True, also run ``rotate_flashinfer_fp8_moe_weights``
                on the layer's weights.
            activation: Activation name. ``"relu2_no_mul"`` is treated as
                non-gated, so ``w13`` gets ``n`` rows instead of ``2 * n``.

        Returns:
            A ``TestData`` with the reference fp8 weights and scales and a
            ``layer`` whose weights are prepared for FlashInfer.
        """
        is_gated = activation != "relu2_no_mul"

        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn(
            (e, (2 * n) if is_gated else n, k), device="cuda", dtype=torch.bfloat16
        )
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        a1_scale = 1.0 / a1_scale
        a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(dtype=torch.float32)

        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        # The layer gets clones, so rewriting its weights below leaves the
        # reference tensors stored in TestData untouched.
        layer = torch.nn.Module()
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale

        register_moe_scaling_factors(layer)

        # flashinfer expects swapped rows for w13: the stacked w1/w3 halves
        # become w3/w1. A non-gated activation has only one projection in
        # w13. Swapping its halves would permute intermediate rows without a
        # matching w2 change, so it is skipped.
        if is_gated:
            layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if reorder:
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)

        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Compare FlashInfer per-tensor-scale fp8 MoE against ``fused_experts``.

    The reference path routes with ``Llama4MoE.custom_routing_function`` via
    ``FusedMoE.select_experts``. Both paths apply the router weight on the
    input, and the outputs must match within the tolerances below.
    """
    if not current_platform.has_device_capability(100):
        pytest.skip("Test is only supported for sm >= 100")

    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        # reorder=True: the layer weights are also passed through
        # rotate_flashinfer_fp8_moe_weights for this kernel.
        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=td.hidden_states,
            router_logits=score,
            use_grouped_topk=False,
            top_k=topk,
            renormalize=False,
            custom_routing_function=Llama4MoE.custom_routing_function,
            scoring_func="softmax",
        )

        # Reference: vLLM's fused_experts on the original-layout fp8 weights.
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # FlashInfer path: receives the raw router logits (not the topk ids
        # above) and reads weights/scales from td.layer.
        flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    activation: str,
    monkeypatch,
):
    """Compare FlashInfer CUTLASS fp8 MoE against ``fused_experts``.

    Runs for the gated ``silu`` and the non-gated ``relu2_no_mul``
    activations. Both paths use the same precomputed routing and apply the
    router weight on the input.
    """
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(
            m, k, n, e, reorder=False, activation=activation
        )

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=td.hidden_states,
            router_logits=score,
            use_grouped_topk=False,
            top_k=topk,
            renormalize=False,
            custom_routing_function=Llama4MoE.custom_routing_function,
            scoring_func="softmax",
        )

        # One config serves both paths. The g*_alphas / a*_gscale entries are
        # the extra inputs supplied for the CUTLASS kernel.
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
            w2_scale=td.w2_weight_scale,
            g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
            a1_scale=td.a1_scale,
            a1_gscale=td.a1_scale,
            a2_scale=td.a2_scale,
            a2_gscale=1.0 / td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference: vLLM's fused_experts on the original-layout fp8 weights.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # NOTE(review): the attributes below are presumably what
        # flashinfer_cutlass_moe_fp8 reads from the layer. The bare Module
        # acts as its own quant_method and hands back the config built above.
        td.layer.dp_size = 1

        def get_fused_moe_quant_config(layer: torch.nn.Module) -> FusedMoEQuantConfig:
            return quant_config

        td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
        td.layer.quant_method = td.layer

        flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
            td.hidden_states,
            td.layer,
            topk_weights,
            topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(
            output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
        )