test_flashinfer_moe.py 5.74 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

6
from tests.kernels.moe.utils import make_test_quant_config
7
8
9
10
11
from tests.kernels.quantization.nvfp4_utils import (
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    dequantize_nvfp4_to_dtype,
)
12
13
14
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
15
from vllm.model_executor.layers.fused_moe import fused_topk
16
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
17
18
19
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
20
21
22
23
24
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    RoutingMethodType,
)
25
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
26
27
28
    FlashInferExperts,
    is_valid_flashinfer_cutlass_fused_moe,
)
29
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEKernel
30
31
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
32
from vllm.utils.math_utils import next_power_of_2
33
from vllm.utils.torch_utils import set_random_seed
34

35
36
37
38
39
40
41
# Skip every test in this module unless FlashInfer's CUTLASS fused-MoE kernel
# is available AND the platform reports device capability 100 (queried via
# `current_platform.has_device_capability`). Both are needed for the nvfp4
# path exercised below.
if not (
    has_flashinfer_cutlass_fused_moe()
    and current_platform.has_device_capability(100)
):
    pytest.skip(
        "Requires flashinfer_cutlass_fused_moe and nvfp4 support",
        allow_module_level=True,
    )

# (m, n, k) problem sizes for the parametrized test, where the test uses
# m = number of tokens (rows of the activations), n = intermediate size per
# partition, and k = hidden dimension. Grouped by token count m.
MNK_FACTORS = [
    (num_tokens, inter_size, hidden)
    for num_tokens, shapes in (
        (2, ((1024, 1024), (3072, 1024), (3072, 1536))),
        (64, ((1024, 1536), (3072, 1024), (2048, 1536))),
        (224, ((1024, 1024), (1024, 1536))),
    )
    for inter_size, hidden in shapes
]


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
    "activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
)
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    activation: MoEActivation,
    workspace_init,
):
    """Check the FlashInfer CUTLASS nvfp4 MoE kernel against a torch reference.

    ``FlashInferExperts`` is run through the modular ``FusedMoEKernel``
    (eagerly, no graph capture). Its output is compared with ``torch_moe``
    evaluated on nvfp4 round-tripped (quantized, then dequantized)
    activations and weights.

    Args:
        m: Number of tokens (rows of the input activations).
        n: Intermediate size per partition.
        k: Hidden dimension.
        e: Number of experts; all are local (no expert parallelism).
        topk: Number of experts selected per token.
        dtype: Dtype of activations, router scores and dequantized weights.
        activation: MoE activation. For gated activations w1 has ``2 * n``
            rows per expert, otherwise ``n``.
        workspace_init: pytest fixture requested for its setup side effects;
            it is not referenced in the body.
    """
    set_random_seed(7)
    with set_current_vllm_config(
        VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
    ):
        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        quant_blocksize = 16
        is_gated_act = activation.is_gated

        w1_q, w2_q, quant_config = make_test_quant_config(
            e,
            n,
            k,
            in_dtype=dtype,
            quant_dtype="nvfp4",
            block_shape=None,
            per_act_token_quant=False,
            make_gate=is_gated_act,
        )

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)

        moe_config = FusedMoEConfig(
            num_experts=e,
            experts_per_token=topk,
            hidden_dim=k,
            intermediate_size_per_partition=n,
            num_local_experts=e,
            num_logical_experts=e,
            activation=activation,
            device="cuda",
            moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
            in_dtype=dtype,
            is_act_and_mul=is_gated_act,
            routing_method=RoutingMethodType.TopK,
            max_num_tokens=next_power_of_2(m),
        )

        flashinfer_experts = FusedMoEKernel(
            maybe_make_prepare_finalize(
                moe=moe_config,
                quant_config=quant_config,
                allow_new_interface=True,
                use_monolithic=False,
            ),
            FlashInferExperts(moe_config=moe_config, quant_config=quant_config),
            inplace=False,
        )

        flashinfer_output = flashinfer_experts.apply(
            hidden_states=a,
            w1=w1_q,
            w2=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=False,
        )

        # Reference check: round-trip `a` through nvfp4 so the torch path
        # also sees quantized activations. The global scale must be derived
        # from the absolute maximum. Using the signed max under-covers inputs
        # whose most negative value has a larger magnitude, which can push
        # block scales past the FP8 E4M3 range.
        a_global_scale = (
            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)
            / torch.amax(a.abs().flatten(), dim=-1)
        ).to(torch.float32)
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
        a_in_dtype = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=a.dtype,
            device=a.device,
            block_size=quant_blocksize,
        )

        # Dequantize every expert's weights. w1 has 2 * n rows when the
        # activation is gated (make_gate=True above), n rows otherwise.
        w1_d = torch.empty(
            (e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
        )
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
        for idx in range(e):
            # NOTE(review): the dequant helper takes a global scale while the
            # quant config exposes g*_alphas; the reciprocal is passed, which
            # presumes alpha == 1 / weight_global_scale. Confirm against
            # make_test_quant_config.
            w1_d[idx] = dequantize_nvfp4_to_dtype(
                w1_q[idx],
                quant_config.w1_scale[idx],
                1 / quant_config.g1_alphas[idx],
                dtype=dtype,
                device=w1_q.device,
                block_size=quant_blocksize,
            )
            w2_d[idx] = dequantize_nvfp4_to_dtype(
                w2_q[idx],
                quant_config.w2_scale[idx],
                1 / quant_config.g2_alphas[idx],
                dtype=dtype,
                device=w2_q.device,
                block_size=quant_blocksize,
            )

        torch_output = torch_moe(
            a_in_dtype, w1_d, w2_d, score, topk, activation=activation
        )

        # Loose tolerances: both sides carry 4-bit quantization error.
        torch.testing.assert_close(
            torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
        )


if __name__ == "__main__":
    # Run one representative configuration directly, bypassing pytest
    # parametrization. Arguments are passed by keyword so they cannot drift
    # out of order with the test's signature. The dtype matches the
    # parametrized bfloat16 case.
    # NOTE(review): the `workspace_init` fixture is not executed outside
    # pytest. If the kernel relies on its setup, run this file via pytest.
    test_flashinfer_fp4_moe_no_graph(
        m=2,
        n=1024,
        k=1024,
        e=40,
        topk=1,
        dtype=torch.bfloat16,
        activation=MoEActivation.SILU,
        workspace_init=None,
    )