test_nvfp4_moe.py 4.39 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import pytest
import torch

6
from tests.kernels.moe.utils import make_test_weights
7
8
9
10
11
from tests.kernels.quantization.nvfp4_utils import (
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    dequantize_nvfp4_to_dtype,
)
12
13
14
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
15
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
16
17
18
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
19
from vllm.utils.torch_utils import set_random_seed
20
21

# Device capability threshold for this module; the skip message below
# describes it as "compute capability of 10".
_NVFP4_MIN_DEVICE_CAPABILITY = 100

# Skip the whole module at import/collection time when the current platform
# does not report the required capability.
if not current_platform.has_device_capability(_NVFP4_MIN_DEVICE_CAPABILITY):
    pytest.skip(
        "Nvfp4 Requires compute capability of 10 or above.",
        allow_module_level=True,
    )
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

# (m, n, k) problem sizes fed to the "m,n,k" parametrization below, grouped
# here by their m value. Expansion order matches the flat listing:
# all m=2 shapes, then m=64, then m=224.
MNK_FACTORS = [
    (m_size, n_size, k_size)
    for m_size, nk_shapes in (
        (2, ((1024, 1024), (1024, 1536), (3072, 1024))),
        (64, ((1024, 1024), (3072, 1024), (2048, 1536))),
        (224, ((1024, 1024), (1024, 1536))),
    )
    for n_size, k_size in nk_shapes
]


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(
    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
):
    """Compare ``cutlass_moe_fp4`` with a ``torch_moe`` reference.

    The reference path quantizes the activations with
    ``ops.scaled_fp4_quant``, dequantizes both the activations and every
    expert's weights back to ``dtype`` and runs ``torch_moe`` on them. The
    two results must agree within ``atol=1e-1`` / ``rtol=1e-1``.

    Args:
        m: Number of activation rows; ``a`` is created as ``(m, k)``.
        n: Weight dimension; reference buffers are ``(e, 2 * n, k)`` and
            ``(e, k, n)``.
        k: Activation width.
        e: Number of experts; router scores are created as ``(m, e)``.
        topk: Forwarded to ``fused_topk`` and ``torch_moe``.
        dtype: Dtype of the activations, scores and dequantized weights.
        workspace_init: Not referenced in the body; presumably a pytest
            fixture requested for its setup side effect (confirm in
            conftest).
    """
    set_random_seed(7)
    vllm_config = VllmConfig(
        parallel_config=ParallelConfig(pipeline_parallel_size=1)
    )
    with set_current_vllm_config(vllm_config):
        # Block size handed to every dequantize_nvfp4_to_dtype call below.
        nvfp4_block = 16

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        w1_pack, w2_pack = make_test_weights(
            e,
            n,
            k,
            in_dtype=dtype,
            quant_dtype="nvfp4",
            block_shape=None,  # use nvfp4_block?
            per_out_ch_quant=False,
        )
        # The first element of each 4-tuple is not needed by this test.
        _, w1_q, w1_blockscale, w1_gs = w1_pack
        _, w2_q, w2_blockscale, w2_gs = w2_pack

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

        # Activation global scales given to the kernel are all ones.
        ones_kwargs = {"device": "cuda", "dtype": torch.float32}
        a1_gs = torch.ones((e,), **ones_kwargs)
        a2_gs = torch.ones((e,), **ones_kwargs)

        for required in (w1_gs, w2_gs, w1_blockscale, w2_blockscale):
            assert required is not None

        g1_alphas = 1 / w1_gs
        g2_alphas = 1 / w2_gs
        quant_config = nvfp4_moe_quant_config(
            g1_alphas=g1_alphas,
            g2_alphas=g2_alphas,
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
        )

        cutlass_output = cutlass_moe_fp4(
            a=a,
            w1_fp4=w1_q,
            w2_fp4=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=quant_config,
            m=m,
            n=n,
            k=k,
            e=e,
        )

        # Reference check: round-trip the activations through NVFP4.
        # NOTE(review): torch.amax here is the signed maximum of ``a``, not
        # the maximum absolute value - confirm that this is intended.
        a_max = torch.amax(a.flatten(), dim=-1)
        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / a_max).to(
            torch.float32
        )
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)

        a_in_dtype = dequantize_nvfp4_to_dtype(
            a_fp4,
            a_scale_interleaved,
            a_global_scale,
            dtype=a.dtype,
            device=a.device,
            block_size=nvfp4_block,
        )

        # Dequantize each expert's weights into dense reference buffers.
        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
        weight_sets = (
            (w1_d, w1_q, w1_blockscale, w1_gs),
            (w2_d, w2_q, w2_blockscale, w2_gs),
        )
        for expert in range(e):
            for dense, packed, block_scales, global_scales in weight_sets:
                dense[expert] = dequantize_nvfp4_to_dtype(
                    packed[expert],
                    block_scales[expert],
                    global_scales[expert],
                    dtype=dtype,
                    device=packed.device,
                    block_size=nvfp4_block,
                )

        reference_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

        torch.testing.assert_close(
            reference_output, cutlass_output, atol=1e-1, rtol=1e-1
        )
138
139
140
141


if __name__ == "__main__":
    # test_cutlass_fp4_moe_no_graph takes seven arguments
    # (m, n, k, e, topk, dtype, workspace_init). The previous direct call
    # passed only four, with the sizes packed into a single tuple, so it
    # raised TypeError before running anything. Delegate to pytest instead
    # so that the parametrization and the ``workspace_init`` fixture are
    # resolved the same way as in a normal test run, and propagate pytest's
    # exit code to the shell.
    raise SystemExit(
        pytest.main([f"{__file__}::test_cutlass_fp4_moe_no_graph"])
    )