# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_flashinfer_per_tensor_scale_fp8,
    flashinfer_cutlass_moe_fp8,
    register_moe_scaling_factors,
    rotate_flashinfer_fp8_moe_weights,
    swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

# Skip the whole module unless FlashInfer's CUTLASS fused-MoE kernels are
# importable and the device passes has_device_capability(100) (presumably
# compute capability 10.0 or newer -- confirm against vllm.platforms).
if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
    100
):
    pytest.skip(
        "Requires flashinfer_cutlass_fused_moe and nvfp4 support",
        allow_module_level=True,
    )

# Parametrization grid shared by the tests in this module.
NUM_EXPERTS = [16]
TOP_KS = [1]

# (m, n, k) triples: m = number of tokens, n = MoE intermediate size,
# k = hidden size (see how TestData.make_moe_tensors_8bit shapes its tensors).
MNK_FACTORS = [
    (256, 8192, 5120),
    (256, 4096, 5120),
    (127, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Minimal single-stage config installed via set_current_vllm_config() by
# each test below.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def quant_fp8_per_tensor_batches(a):
    """Quantize every slice ``a[i]`` along dim 0 to FP8 with its own scale.

    ``input_to_float8`` is called independently on each slice, and the scale
    it returns is inverted (``1.0 / scale``) before being collected.

    Args:
        a: Tensor whose leading dimension indexes the slices to quantize
            (for the MoE weights in this file, presumably one per expert).

    Returns:
        ``(quantized, scales)``: the per-slice FP8 tensors stacked along a new
        leading dimension, and the inverted per-slice scales stacked likewise.
    """
    per_slice = [input_to_float8(a[idx]) for idx in range(a.size(0))]
    quantized = torch.stack([fp8_slice for fp8_slice, _ in per_slice])
    scales = torch.stack([1.0 / slice_scale for _, slice_scale in per_slice])
    return quantized, scales


@dataclass
class TestData:
    """FP8 MoE test inputs plus a mock layer prepared for the FlashInfer paths.

    The ``*_quantized`` and ``*_scale`` fields feed the reference
    ``fused_experts`` path unchanged. ``layer`` carries clones of the same FP8
    weights after the FlashInfer-specific preprocessing performed in
    ``make_moe_tensors_8bit``.
    """

    # This is a data holder, not a test class: without this flag pytest tries
    # to collect it (``Test`` prefix) and warns because it defines __init__.
    __test__ = False

    hidden_states: torch.Tensor
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, reorder: bool
    ) -> "TestData":
        """Build random bf16 MoE tensors on CUDA and their per-tensor FP8 forms.

        Args:
            m: Number of tokens (rows of ``hidden_states``).
            k: Hidden size.
            n: Intermediate size; ``w13`` has ``2 * n`` rows per expert.
            e: Number of experts.
            reorder: When True, also run ``rotate_flashinfer_fp8_moe_weights``
                on the layer's weights. The per-tensor-scale test in this file
                passes True; the CUTLASS test passes False.

        Returns:
            A ``TestData`` holding the bf16 input, the FP8 weights with their
            per-expert scales, the activation scales, and a bare
            ``torch.nn.Module`` populated with the attributes the FlashInfer
            helpers are given.
        """
        hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Per-tensor activation scales: a1 is derived from the actual input
        # (inverted like the weight scales below), a2 is fixed at 1.0.
        _, a1_scale = input_to_float8(hidden_states)
        a1_scale = 1.0 / a1_scale
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)

        # One FP8 scale per expert for each weight tensor.
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        # Clone the weights so the in-place FlashInfer preprocessing below
        # does not alter the tensors handed to the reference path.
        layer = torch.nn.Module()
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale

        register_moe_scaling_factors(layer)

        # flashinfer expects swapped rows for w13
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if reorder:
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Compare FlashInfer's per-tensor-scale FP8 MoE against ``fused_experts``.

    The reference path routes with ``FusedMoE.select_experts`` (Llama4 custom
    routing) and runs ``fused_experts`` on the unmodified FP8 weights. The
    FlashInfer path receives the same router logits and the swapped/rotated
    weights stored on ``td.layer``. Both outputs must agree within the FP8
    tolerance used below.
    """
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=td.hidden_states,
            router_logits=score,
            use_grouped_topk=False,
            top_k=topk,
            renormalize=False,
            custom_routing_function=Llama4MoE.custom_routing_function,
            scoring_func="softmax",
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result from vLLM's Triton fused-MoE path.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # FlashInfer is given the raw router logits rather than the
        # precomputed top-k selection.
        flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)


@pytest.mark.skip(
    reason="Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Compare FlashInfer's CUTLASS FP8 MoE against ``fused_experts``.

    Both paths consume the same ``topk_weights`` / ``topk_ids`` produced by
    ``FusedMoE.select_experts``. The CUTLASS path uses the w13-swapped (but
    not rotated, ``reorder=False``) weights stored on ``td.layer``.
    Currently skipped until the required FlashInfer change is available.
    """
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=td.hidden_states,
            router_logits=score,
            use_grouped_topk=False,
            top_k=topk,
            renormalize=False,
            custom_routing_function=Llama4MoE.custom_routing_function,
            scoring_func="softmax",
        )

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            per_act_token_quant=False,
        )

        # Reference result from vLLM's Triton fused-MoE path.
        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
            quant_config=quant_config,
        )

        # The mock layer has no data-parallel info; presumably
        # flashinfer_cutlass_moe_fp8 reads dp_size from it -- confirm.
        td.layer.dp_size = 1

        flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
            td.hidden_states,
            td.layer,
            topk_weights,
            topk_ids,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(
            output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
        )