# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                    FLOAT8_E4M3_MAX,
                                                    dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform

# Skip the entire module unless the device reports capability 100 (i.e. 10.0)
# or newer, which the NVFP4 path under test requires.
if not current_platform.has_device_capability(100):
    pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
                allow_module_level=True)

# (m, n, k) problem sizes for the MoE test: `m` is the number of activation
# rows, `k` the activation width, and `n` the intermediate width (w1 is
# dequantized to (e, 2 * n, k) and w2 to (e, k, n) in the test below).
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 2048, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
]


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
                                  dtype: torch.dtype):
    """Compare ``cutlass_moe_fp4`` against a dequantized ``torch_moe`` reference.

    Random activations of shape (m, k) and NVFP4 expert weights produced by
    ``make_test_weights`` are routed with ``fused_topk`` and run through the
    CUTLASS FP4 MoE kernel. The reference path round-trips the activations
    through ``ops.scaled_fp4_quant`` / ``dequantize_nvfp4_to_dtype``,
    dequantizes every expert's weights back to ``dtype``, and runs
    ``torch_moe``. The two outputs must agree within ``atol=rtol=1e-1``.

    Args:
        m: Number of activation rows (presumably tokens).
        n: Intermediate width; w1 dequantizes to (e, 2 * n, k), w2 to (e, k, n).
        k: Activation width.
        e: Number of experts.
        topk: ``topk`` value passed to both ``fused_topk`` and ``torch_moe``.
        dtype: Activation dtype and dtype of the dequantized reference weights.
    """
    current_platform.seed_everything(7)
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):

        # Block size handed to the reference dequantization helper.
        quant_blocksize = 16

        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10

        (_, w1_q, w1_blockscale,
         w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
             e,
             n,
             k,
             in_dtype=dtype,
             quant_dtype="nvfp4",
             block_shape=None,  # TODO: should this be quant_blocksize?
             per_out_ch_quant=False,
         )
        # The NVFP4 path needs both the per-block and the global weight scales.
        assert w1_gs is not None
        assert w2_gs is not None
        assert w1_blockscale is not None
        assert w2_blockscale is not None

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids, _ = fused_topk(a,
                                               score,
                                               topk,
                                               renormalize=False)

        # Unit activation global scales for the kernel's input quantization.
        # NOTE(review): the reference below quantizes `a` with an
        # amax-derived global scale instead; the loose tolerance presumably
        # absorbs the difference -- confirm this is intentional.
        a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
        a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)

        quant_config = nvfp4_moe_quant_config(
            g1_alphas=(1 / w1_gs),
            g2_alphas=(1 / w2_gs),
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
        )

        cutlass_output = cutlass_moe_fp4(
            a=a,
            w1_fp4=w1_q,
            w2_fp4=w2_q,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            quant_config=quant_config,
            m=m,
            n=n,
            k=k,
            e=e,
        )

        # Reference check: quantize `a` to NVFP4 and dequantize it again.
        # The global scale must come from the *absolute* maximum. With the
        # signed maximum, a block whose most negative value exceeds max(a) in
        # magnitude would need an FP8 block scale above FLOAT8_E4M3_MAX and
        # would be clipped.
        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                          a.abs().amax()).to(torch.float32)
        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
        a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
                                               a_scale_interleaved,
                                               a_global_scale,
                                               dtype=a.dtype,
                                               device=a.device,
                                               block_size=quant_blocksize)

        # Dequantize every expert's FP4 weights into dense `dtype` tensors.
        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
        for idx in range(e):
            w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
                                                  w1_blockscale[idx],
                                                  w1_gs[idx],
                                                  dtype=dtype,
                                                  device=w1_q.device,
                                                  block_size=quant_blocksize)
            w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
                                                  w2_blockscale[idx],
                                                  w2_gs[idx],
                                                  dtype=dtype,
                                                  device=w2_q.device,
                                                  block_size=quant_blocksize)

        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

        torch.testing.assert_close(torch_output,
                                   cutlass_output,
                                   atol=1e-1,
                                   rtol=1e-1)


if __name__ == "__main__":
    # Run one small configuration outside pytest. Arguments follow the test's
    # positional signature: (m, n, k, e, topk, dtype). The previous call
    # passed an (m, n, k) tuple as `m`, which shifted every argument and
    # raised TypeError for the missing `topk`/`dtype`.
    test_cutlass_fp4_moe_no_graph(2, 1024, 1024, 40, 1, torch.half)