# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config, make_test_weights
from tests.kernels.quantization.nvfp4_utils import (
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    dequantize_nvfp4_to_dtype,
)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    CutlassExpertsFp4,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed

# The NVFP4 CUTLASS MoE path tested here needs device capability >= 100
# (i.e. compute capability 10.0+); skip the whole module on older devices.
if not current_platform.has_device_capability(100):
    pytest.skip(
        "Nvfp4 Requires compute capability of 10 or above.", allow_module_level=True
    )

# Problem sizes as (m, n, k):
#   m - number of tokens (the activation `a` is built with shape (m, k)),
#   n - expert intermediate size (w1 dequantizes to (e, 2 * n, k)),
#   k - hidden size (w2 dequantizes to (e, k, n)).
# Grouped by token count.
MNK_FACTORS = [
    # m = 2
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    # m = 64
    (64, 1024, 1024),
    (64, 3072, 1024),
    (64, 2048, 1536),
    # m = 224
    (224, 1024, 1024),
    (224, 1024, 1536),
]


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
):
    """Compare CutlassExpertsFp4 against a dequantized torch reference MoE.

    The kernel under test runs through ``FusedMoEModularKernel`` with
    ``MoEPrepareAndFinalizeNoEP`` on NVFP4-quantized expert weights. The
    reference quantizes ``a`` to NVFP4 and back, dequantizes every expert's
    weights, and evaluates ``torch_moe`` on the dequantized tensors.

    Args:
        m: Number of tokens (rows of the activation ``a``).
        n: Expert intermediate size; ``w1`` dequantizes to ``(e, 2 * n, k)``.
        k: Hidden size; ``w2`` dequantizes to ``(e, k, n)``.
        e: Number of experts.
        topk: Number of experts selected per token.
        dtype: Dtype of the activations and of the dequantized reference.
        workspace_init: pytest fixture; not referenced in the body
            (presumably requested for its setup side effects).
    """
    set_random_seed(7)
    with set_current_vllm_config(
        VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
    ):
        quant_blocksize = 16

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        (_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = (
            make_test_weights(
                e,
                n,
                k,
                in_dtype=dtype,
                quant_dtype="nvfp4",
                block_shape=None,  # use quant_blocksize?
                per_out_ch_quant=False,
            )
        )

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        # Per-expert activation global scales handed to the kernel (all ones).
        a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        quant_config = nvfp4_moe_quant_config(
            g1_alphas=(1 / w1_gs),
            g2_alphas=(1 / w2_gs),
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
        )

        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            CutlassExpertsFp4(
                moe_config=make_dummy_moe_config(),
                quant_config=quant_config,
            ),
            inplace=False,
        )

        cutlass_output = kernel(
            hidden_states=a,
            w1=w1_q,
            w2=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

        # Reference: round-trip `a` through NVFP4 so the reference input also
        # carries NVFP4 rounding, then dequantize all expert weights.
        #
        # FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX is meant to map the tensor's largest
        # *magnitude* onto the top of the representable range, so the
        # denominator must be the absolute max. A signed max is too small when
        # the most extreme element is negative (and is itself negative if every
        # element is).
        a_absmax = torch.amax(a.abs().flatten(), dim=-1)
        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / a_absmax).to(
            torch.float32
        )
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)

        a_in_dtype = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=a.dtype,
            device=a.device,
            block_size=quant_blocksize,
        )

        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

        for idx in range(e):
            w1_d[idx] = dequantize_nvfp4_to_dtype(
                w1_q[idx],
                w1_blockscale[idx],
                w1_gs[idx],
                dtype=dtype,
                device=w1_q.device,
                block_size=quant_blocksize,
            )
            w2_d[idx] = dequantize_nvfp4_to_dtype(
                w2_q[idx],
                w2_blockscale[idx],
                w2_gs[idx],
                dtype=dtype,
                device=w2_q.device,
                block_size=quant_blocksize,
            )

        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

        torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)


# step3.5-flash uses the swiglustep activation (clipped SwiGLU with limit=7.0)
# for MoE layers 43-44. These sizes exercise the non-fused activation fallback
# path in run_cutlass_moe_fp4 (apply_moe_activation followed by a separate fp4
# quantization).
# Model dims: e=288, topk=8, n=1280 (moe_intermediate_size), k=4096 (hidden).
# Only the token count m varies; n and k are pinned to the model's dims.
SWIGLUSTEP_MNK_FACTORS = [(m, 1280, 4096) for m in (2, 64, 224)]


@pytest.mark.parametrize("m,n,k", SWIGLUSTEP_MNK_FACTORS)
@pytest.mark.parametrize("e", [64, 288])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_swiglustep(
    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
):
    """Compare CutlassExpertsFp4 with SWIGLUSTEP activation against torch_moe.

    Same setup as the default-activation comparison, but both the kernel call
    and the ``torch_moe`` reference are run with
    ``activation=MoEActivation.SWIGLUSTEP``.

    Args:
        m: Number of tokens (rows of the activation ``a``).
        n: Expert intermediate size; ``w1`` dequantizes to ``(e, 2 * n, k)``.
        k: Hidden size; ``w2`` dequantizes to ``(e, k, n)``.
        e: Number of experts.
        topk: Number of experts selected per token.
        dtype: Dtype of the activations and of the dequantized reference.
        workspace_init: pytest fixture; not referenced in the body
            (presumably requested for its setup side effects).
    """
    set_random_seed(7)
    with set_current_vllm_config(
        VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
    ):
        quant_blocksize = 16

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        (_, w1_q, w1_blockscale, w1_gs), (_, w2_q, w2_blockscale, w2_gs) = (
            make_test_weights(
                e,
                n,
                k,
                in_dtype=dtype,
                quant_dtype="nvfp4",
                block_shape=None,
                per_out_ch_quant=False,
            )
        )

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        # Per-expert activation global scales handed to the kernel (all ones).
        a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
        a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)

        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        quant_config = nvfp4_moe_quant_config(
            g1_alphas=(1 / w1_gs),
            g2_alphas=(1 / w2_gs),
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
        )

        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            CutlassExpertsFp4(
                moe_config=make_dummy_moe_config(),
                quant_config=quant_config,
            ),
            inplace=False,
        )

        cutlass_output = kernel(
            hidden_states=a,
            w1=w1_q,
            w2=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=MoEActivation.SWIGLUSTEP,
        )

        # Reference: dequantize everything and run torch_moe with swiglustep.
        #
        # FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX is meant to map the tensor's largest
        # *magnitude* onto the top of the representable range, so the
        # denominator must be the absolute max. A signed max is too small when
        # the most extreme element is negative (and is itself negative if every
        # element is).
        a_absmax = torch.amax(a.abs().flatten(), dim=-1)
        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / a_absmax).to(
            torch.float32
        )
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)

        a_in_dtype = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=a.dtype,
            device=a.device,
            block_size=quant_blocksize,
        )

        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)

        for idx in range(e):
            w1_d[idx] = dequantize_nvfp4_to_dtype(
                w1_q[idx],
                w1_blockscale[idx],
                w1_gs[idx],
                dtype=dtype,
                device=w1_q.device,
                block_size=quant_blocksize,
            )
            w2_d[idx] = dequantize_nvfp4_to_dtype(
                w2_q[idx],
                w2_blockscale[idx],
                w2_gs[idx],
                dtype=dtype,
                device=w2_q.device,
                block_size=quant_blocksize,
            )

        torch_output = torch_moe(
            a_in_dtype,
            w1_d,
            w2_d,
            score,
            topk,
            activation=MoEActivation.SWIGLUSTEP,
        )

        torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    # Run the module through pytest rather than calling a test function
    # directly: the tests take positional (m, n, k, e, topk, dtype) arguments
    # plus the `workspace_init` fixture, which only pytest can supply.
    pytest.main([__file__])