test_deepgemm.py 6.06 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
"""
Unit-test DeepGEMM FP8 kernels (no DeepEP).
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
"""

import importlib
import math

import pytest
import torch

# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
15
16
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_dummy_moe_config
17
18
19
20
21
22
from vllm.model_executor.layers.fused_moe.activation import (
    MoEActivation,
)
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
23
24
25
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config,
)
26
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
27
28
29
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts,
)
30
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
31
32
33
34
35
36
37
    per_token_group_quant_fp8,
)
from vllm.utils.deep_gemm import (
    calc_diff,
    is_deep_gemm_supported,
    per_block_cast_to_fp8,
)
38

39
# Quantization tile shape, ordered as [block_n, block_k].
BLOCK_SIZE: list[int] = [128, 128]
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55


def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
):
    """
    Create random FP8 block-quantized expert weights and their scales.

    Random bf16 weights are drawn on the CUDA device, scaled down by 10,
    clamped to the float8_e4m3fn range and then quantized one expert at a
    time with ``per_block_cast_to_fp8``.

    Args:
        e: number of experts (leading dimension of every returned tensor).
        n: w1 has 2*n rows per expert, w2 has n columns per expert.
        k: w1 has k columns per expert, w2 has k rows per expert.
        block_size: ``[block_n, block_k]`` quantization tile.

    Returns:
        ``(w1, w2, w1_s, w2_s)`` with

          w1 shape: (E, 2N, K)
          w2 shape: (E, K, N)
          w1_s shape: (E, ceil(2N / block_n), ceil(K / block_k))
          w2_s shape: (E, ceil(K / block_n), ceil(N / block_k))

        The weights are float8_e4m3fn and the scales are float32.
    """
    fp8_dtype = torch.float8_e4m3fn
    fp8_range = torch.finfo(fp8_dtype)
    tile_n, tile_k = block_size

    # bf16 reference weights; w1 is sampled before w2 so the RNG stream is
    # consumed in a fixed order.
    ref_w1 = torch.randn(e, 2 * n, k, device="cuda", dtype=torch.bfloat16) / 10
    ref_w2 = torch.randn(e, k, n, device="cuda", dtype=torch.bfloat16) / 10
    for ref in (ref_w1, ref_w2):
        ref.clamp_(fp8_range.min, fp8_range.max)

    # Output buffers: quantized weights plus one scale per tile.
    w1 = torch.empty_like(ref_w1, dtype=fp8_dtype)
    w2 = torch.empty_like(ref_w2, dtype=fp8_dtype)
    w1_s = torch.empty(
        e,
        math.ceil((2 * n) / tile_n),
        math.ceil(k / tile_k),
        device="cuda",
        dtype=torch.float32,
    )
    w2_s = torch.empty(
        e,
        math.ceil(k / tile_n),
        math.ceil(n / tile_k),
        device="cuda",
        dtype=torch.float32,
    )

    # Quantize expert by expert, filling the preallocated buffers.
    for expert, (src_w1, src_w2) in enumerate(zip(ref_w1, ref_w2)):
        w1[expert], w1_s[expert] = per_block_cast_to_fp8(
            src_w1, block_size=block_size, use_ue8m0=True
        )
        w2[expert], w2_s[expert] = per_block_cast_to_fp8(
            src_w2, block_size=block_size, use_ue8m0=True
        )

    return w1, w2, w1_s, w2_s


def run_single_case(m, n, k, topk, num_experts, block_size):
    """
    Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
    Triton baseline within tolerance.

    Args:
        m: number of input tokens (rows of the activation matrix).
        n: size forwarded to ``make_block_quant_fp8_weights`` as ``n``.
        k: width of each token and ``k`` of the expert weights.
        topk: number of experts selected per token from the router logits.
        num_experts: total number of experts.
        block_size: ``[block_n, block_k]`` quantization tile; the second
            entry is also used as the activation quantization group size.

    Raises:
        AssertionError: if ``calc_diff`` between the two outputs is not
            below the tolerance.
    """
    # Single source of truth for the tolerance: the failure message is
    # formatted from it so the text cannot disagree with the check.
    max_diff = 0.001

    tokens_bf16 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16).clamp_(-1, 1)
    # Only the activation scale is kept; both paths below receive the bf16
    # tokens, not the quantized ones.
    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])

    # expert weight tensors
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k, block_size)

    # Random routing: pick the top-k logits per token and normalize them
    # with a softmax over the selected experts only.
    router_logits = torch.randn(m, num_experts, device="cuda", dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
    )
    moe_config = make_dummy_moe_config()

    deep_gemm_experts = mk.FusedMoEKernel(
        prepare_finalize=maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=False,
        ),
        fused_experts=TritonOrDeepGemmExperts(
            moe_config=moe_config,
            quant_config=quant_config,
        ),
        inplace=False,
    )

    # triton reference
    out_triton = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=quant_config,
    )

    # DeepGemm
    out_deepgemm = deep_gemm_experts.apply(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        global_num_experts=num_experts,
        activation=MoEActivation.SILU,
        apply_router_weight_on_input=False,
        expert_map=None,
    )

    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < max_diff, f"Diff exceeded {max_diff:.1%}: {diff}"
155
156


157
# (M, N, K) problem sizes swept by the parametrized test.
# Note: N <= 512 will disable the deepgemm path due to performance issues.
MNKs: list[tuple[int, int, int]] = [
    (1024, 768, 128),
    (2048, 768, 512),
    (512, 1024, 1024),
    (4096, 4096, 1024),
]

# Number of experts selected per token.
TOPKS: list[int] = [2, 6]
# Total number of experts.
NUM_EXPERTS: list[int] = [32]


169
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init):
    """
    Compare the DeepGEMM path with the Triton reference for one configuration
    and verify that ``DeepGemmExperts.apply`` was called exactly once.

    ``workspace_init`` is requested as a fixture only; the body never reads it.
    """
    # Guard first: there is no point in patching anything when the routing
    # configuration cannot be satisfied.
    if topk > num_experts:
        pytest.skip(f"topk={topk} > num_experts={num_experts}")

    with monkeypatch.context() as mp:
        mp.setenv("VLLM_USE_DEEP_GEMM", "1")

        # Resolved at run time, after the environment variable is set.
        _DeepGemmExperts = importlib.import_module(
            "vllm.model_executor.layers.fused_moe.deep_gemm_moe"
        ).DeepGemmExperts

        call_counter = {"cnt": 0}
        orig_fn = _DeepGemmExperts.apply

        def _spy_apply(*args, **kwargs):
            # Count the call, then defer to the real implementation.
            call_counter["cnt"] += 1
            return orig_fn(*args, **kwargs)

        # Patch through the context-scoped handle so the spy is removed
        # together with the environment variable when the block exits.
        mp.setattr(_DeepGemmExperts, "apply", _spy_apply)

        run_single_case(
            m=m,
            n=n,
            k=k,
            topk=topk,
            num_experts=num_experts,
            block_size=BLOCK_SIZE,
        )

        # ensure that the DeepGEMM path was indeed taken.
        assert call_counter["cnt"] == 1, (
            f"DeepGEMM path was not executed during the test. "
            f"Call counter: {call_counter['cnt']}"
        )