# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit-test DeepGEMM FP8 kernels (no DeepEP).
Compare DeepGEMM path against the Triton fallback inside vLLM's fused_experts.
"""

import importlib
import math

import pytest
import torch

14
15
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config)
16
17
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
18
19
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
20
21
from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported,
                                  per_block_cast_to_fp8)
22

23
# FP8 block-quantization tile shape as [block_n, block_k], shared by the
# weight/activation quantization helpers and the MoE quant config below.
BLOCK_SIZE = [128, 128]


def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
):
    """
    Generate (w1, w2) expert weights and their per-block scale tensors
    in FP8 block-quantized format.

      w1 shape:   (E, 2N, K)
      w2 shape:   (E, K, N)
      w1_s shape: (E, ceil(2N / block_n), ceil(K / block_k))
      w2_s shape: (E, ceil(K / block_n), ceil(N / block_k))

    Args:
        e: number of experts.
        n: per-expert N dimension; w1 has ``2 * n`` rows per expert.
        k: per-expert K dimension.
        block_size: ``[block_n, block_k]`` quantization tile shape.

    Returns:
        ``(w1, w2, w1_s, w2_s)``: ``torch.float8_e4m3fn`` weights and
        ``torch.float32`` scales, all allocated on the CUDA device.
    """
    dtype = torch.bfloat16
    fp8_max, fp8_min = torch.finfo(torch.float8_e4m3fn).max, torch.finfo(
        torch.float8_e4m3fn).min

    # bf16 reference weights, scaled down and clamped into the FP8 range
    # before quantization.
    w1_bf16 = torch.randn(e, 2 * n, k, device="cuda", dtype=dtype) / 10
    w2_bf16 = torch.randn(e, k, n, device="cuda", dtype=dtype) / 10
    w1_bf16.clamp_(fp8_min, fp8_max)
    w2_bf16.clamp_(fp8_min, fp8_max)

    # Number of quantization tiles along each weight dimension; the last
    # tile may be partial, hence ceil.
    block_n, block_k = block_size
    n_tiles_w1 = math.ceil((2 * n) / block_n)
    k_tiles_w1 = math.ceil(k / block_k)
    n_tiles_w2 = math.ceil(k / block_n)
    k_tiles_w2 = math.ceil(n / block_k)

    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
    w1_s = torch.empty(e,
                       n_tiles_w1,
                       k_tiles_w1,
                       device="cuda",
                       dtype=torch.float32)
    w2_s = torch.empty(e,
                       n_tiles_w2,
                       k_tiles_w2,
                       device="cuda",
                       dtype=torch.float32)

    # Quantize each expert's 2-D weight slice independently; the helper
    # returns (quantized_weight, per_block_scales).
    for i in range(e):
        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
                                               block_size=block_size,
                                               use_ue8m0=True)
        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
                                               block_size=block_size,
                                               use_ue8m0=True)

    return w1, w2, w1_s, w2_s


def run_single_case(m, n, k, topk, num_experts, block_size):
    """
    Run one (M,N,K) configuration on a single GPU and assert DeepGEMM ==
    Triton baseline within tolerance.

    Both paths go through ``fused_experts`` with identical inputs and the
    same FP8 quant config; only ``allow_deep_gemm`` differs.

    Args:
        m: number of tokens.
        n, k: expert weight dimensions (w1 is (E, 2N, K), w2 is (E, K, N)).
        topk: number of experts selected per token.
        num_experts: total number of experts.
        block_size: ``[block_n, block_k]`` FP8 quantization tile shape.

    Raises:
        AssertionError: if ``calc_diff`` between the two outputs is not
            below 0.001.
    """
    tokens_bf16 = torch.randn(
        m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
    # Only the per-token-group activation scales are kept; the quantized
    # tokens are discarded and the scales are passed via quant_config.
    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])

    # expert weight tensors
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
                                                      block_size)

    # Random routing: pick the top-k logits per token and renormalize their
    # weights with a softmax over just the selected experts.
    router_logits = torch.randn(m,
                                num_experts,
                                device="cuda",
                                dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        a1_scale=a1_scale,
        block_shape=block_size,
    )

    # triton reference
    out_triton = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=quant_config,
        allow_deep_gemm=False,
    )

    # DeepGemm
    out_deepgemm = fused_experts(
        hidden_states=tokens_bf16,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=quant_config,
        allow_deep_gemm=True,
    )

    diff = calc_diff(out_deepgemm, out_triton)
    # The threshold is 0.001 (i.e. 0.1%), not 1% as the old message claimed.
    assert diff < 0.001, f"Diff exceeded threshold 0.001: {diff}"


# Note: N <= 512 will disable the deepgemm path due to performance issues.
MNKs = [
    (1024, 768, 128),
    (1024, 768, 512),
    (2048, 768, 512),
    (512, 1024, 1024),
    (512, 2048, 2048),
    (4096, 4096, 1024),
]

TOPKS = [2, 6]
NUM_EXPERTS = [32]


@pytest.mark.parametrize(("m", "n", "k"), MNKs)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.skipif(not is_deep_gemm_supported(),
                    reason="Requires deep_gemm kernels")
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
    """
    Compare ``fused_experts`` with and without DeepGEMM and verify that the
    DeepGEMM entry point (``deep_gemm_moe_fp8``) was called exactly once.
    """
    # Bail out before touching the environment or installing the spy.
    if topk > num_experts:
        pytest.skip(f"topk={topk} > num_experts={num_experts}")

    with monkeypatch.context() as mp:
        mp.setenv("VLLM_USE_DEEP_GEMM", "1")

        _fused_moe_mod = importlib.import_module(
            "vllm.model_executor.layers.fused_moe.fused_moe")

        call_counter = {"cnt": 0}

        orig_fn = _fused_moe_mod.deep_gemm_moe_fp8

        def _spy_deep_gemm_moe_fp8(*args, **kwargs):
            call_counter["cnt"] += 1
            return orig_fn(*args, **kwargs)

        # Patch through `mp` so the spy is undone together with the env var
        # when this context exits.
        mp.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
                   _spy_deep_gemm_moe_fp8)

        run_single_case(
            m=m,
            n=n,
            k=k,
            topk=topk,
            num_experts=num_experts,
            block_size=BLOCK_SIZE,
        )

        # ensure that the DeepGEMM path was indeed taken.
        # run_single_case calls fused_experts twice (allow_deep_gemm=False,
        # then True); only the second call is expected to reach DeepGEMM.
        assert call_counter["cnt"] == 1, \
            f"DeepGEMM path was not executed during the test. " \
            f"Call counter: {call_counter['cnt']}"