test_block_fp8.py 9.1 KB
Newer Older
bnellnm's avatar
bnellnm committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

7
8
9
10
11
12
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import (
    make_dummy_moe_config,
    make_test_quant_config,
    make_test_weights,
)
13
14
15
16
from tests.kernels.quant_utils import (
    native_per_token_group_quant_fp8,
    native_w8a8_block_matmul,
)
bnellnm's avatar
bnellnm committed
17
18
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
19
20
21
22
from vllm.model_executor.layers.fused_moe import (
    fused_experts,
    fused_topk,
)
23
24
25
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config,
)
bnellnm's avatar
bnellnm committed
26
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
27
28
    _valid_deep_gemm_shape,
)
bnellnm's avatar
bnellnm committed
29
from vllm.model_executor.layers.fused_moe.fused_moe import (
30
31
    modular_triton_fused_moe,
)
32
33
34
35
36
37
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts,
)
bnellnm's avatar
bnellnm committed
38
from vllm.platforms import current_platform
39
40
41
42
from vllm.utils.deep_gemm import (
    get_mk_alignment_for_contiguous_layout,
    is_deep_gemm_e8m0_used,
)
43
from vllm.utils.import_utils import has_deep_gemm
bnellnm's avatar
bnellnm committed
44

45
46
dg_available = has_deep_gemm()

# The FP8 Triton kernels exercised here need compute capability >= 9.0.
# ``has_device_capability`` returns False when the platform reports no
# capability at all, whereas comparing the raw ``get_device_capability()``
# result against a tuple would raise ``TypeError`` on ``None`` at import time.
if not current_platform.has_device_capability((9, 0)):
    pytest.skip(
        "FP8 Triton requires compute capability 9.0 or higher",
        allow_module_level=True,
    )

# Platforms using the fnuz FP8 variant cannot run these float8_e4m3fn tests.
if current_platform.is_fp8_fnuz():
    pytest.skip(
        "Tests in this file require float8_e4m3fn, "
        "which this platform does not support",
        allow_module_level=True,
    )
bnellnm's avatar
bnellnm committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

# Shared config that the tests below enter with ``set_current_vllm_config``
# (done there to avoid warning spam).
vllm_config = VllmConfig()

# Test configurations
# Activation dtypes under test; only bfloat16 is currently exercised.
DTYPES = [torch.bfloat16]  # other candidates: torch.half, torch.float32
# (M, N, K) shapes for the Triton test. M is the number of tokens and K the
# hidden size (the test builds activations of shape (M, K)); N is passed to
# the weight helpers as the per-expert N dimension.
# Deepseek-V3's intermediate size is 18432, so N is 18432*2/8=4608 at TP8,
# and its hidden size is 7168.
MNK_FACTORS = [
    (1, 128, 128),
    (1, 128, 7168),
    (1, 1024, 7168),
    (1, 4608, 128),
    (1, 4608, 7168),
    (83, 128, 128),
    (83, 512, 512),
    (83, 4608, 512),
    (83, 4608, 7168),
    (128, 512, 512),
    (128, 1024, 7168),
    (128, 4608, 7168),
    (2048, 128, 128),
    (2048, 1024, 7168),
    (2048, 4608, 512),
    (2048, 4608, 7168),
    (8192, 128, 128),
    (8192, 128, 7168),
    (8192, 1024, 7168),
    (8192, 4608, 7168),
]

# (M, N, K) shapes for the DeepGEMM test. Each shape is additionally filtered
# at run time by ``_valid_deep_gemm_shape``. With that test's chunk size of
# 1024, CUDA-graph capture is only attempted for M > 1024 (and N, K >= 1024).
MNK_FACTORS_DG = [
    (128, 128, 128),
    (128, 128, 7168),
    (128, 1024, 7168),
    (128, 4608, 128),
    (128, 4608, 7168),
    (192, 512, 512),
    (192, 1024, 7168),
    (192, 4608, 7168),
    (1335, 128, 128),
    (1335, 1024, 7168),
    (1335, 4608, 512),
    (1335, 4608, 7168),
    (2048, 128, 128),
    (2048, 128, 7168),
    (2048, 1024, 7168),
    (2048, 4608, 7168),
]

# Quantization block shape; the torch reference uses its second entry as the
# activation quantization group size.
BLOCK_SIZE = [[128, 128]]
# Numbers of experts; larger counts (e.g. 128, 256) are not exercised here.
E = [2, 8, 16]
# Experts selected per token; combinations with topk > E are skipped.
TOP_KS = [1, 2, 6]
SEEDS = [0]


109
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
    """Fused moe with block-wise quantization using native torch.

    Reference implementation used to check the fused kernels:

    1. each token row of ``a`` is repeated once per top-k slot;
    2. the repeated activations are quantized with
       ``native_per_token_group_quant_fp8`` using group size ``block_shape[1]``;
    3. for every expert ``i`` that received at least one slot, the rows routed
       to it go through ``w1[i]`` (block matmul), ``SiluAndMul``, a second
       activation quantization, and ``w2[i]``;
    4. the per-slot results are scaled by ``topk_weight`` and summed per token.

    Returns a ``(a.shape[0], w2.shape[1])`` tensor with ``a``'s dtype.
    """
    num_tokens, hidden_dim = a.shape
    num_slots = topk_ids.size(1)
    out_dim = w2.shape[1]

    # Row t * num_slots + j of `tokens` is a copy of token t for its j-th slot,
    # matching the row order of the flattened topk_ids / topk_weight below.
    tokens = a.view(num_tokens, -1, hidden_dim).repeat(1, num_slots, 1)
    tokens = tokens.reshape(-1, hidden_dim)
    out = torch.zeros(
        num_tokens * num_slots, out_dim, dtype=tokens.dtype, device=tokens.device
    )

    flat_weights = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)

    group_size = block_shape[1]
    tokens_q, tokens_s = native_per_token_group_quant_fp8(tokens, group_size)
    tokens_q = tokens_q.to(torch.float32)

    for expert in range(w1.shape[0]):
        routed = flat_ids == expert
        if not routed.sum():
            # No slot was routed to this expert.
            continue
        w1_out = native_w8a8_block_matmul(
            tokens_q[routed],
            w1[expert],
            tokens_s[routed],
            w1_s[expert],
            block_shape,
            output_dtype=tokens.dtype,
        )
        activated = SiluAndMul().forward_native(w1_out)
        activated_q, activated_s = native_per_token_group_quant_fp8(
            activated, group_size
        )
        out[routed] = native_w8a8_block_matmul(
            activated_q,
            w2[expert],
            activated_s,
            w2_s[expert],
            block_shape,
            output_dtype=tokens.dtype,
        )

    # Weight each slot's result and collapse the slots back to one row per token.
    slot_weights = flat_weights.view(num_tokens, -1, 1).to(out.dtype)
    return (out.view(num_tokens, -1, out_dim) * slot_weights).sum(dim=1)
bnellnm's avatar
bnellnm committed
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153


# Skip all tests if CUDA is not available. ``pytest.importorskip("torch.cuda")``
# is not sufficient for this: it only checks that the ``torch.cuda`` submodule
# imports, which holds for every torch build, with or without a usable GPU.
if not torch.cuda.is_available():
    pytest.skip("CUDA is not available", allow_module_level=True)


@pytest.fixture(autouse=True)
def setup_cuda():
    """Make CUDA the default torch device for the duration of each test.

    The previous default device is restored afterwards, so the setting does
    not leak into tests from other modules that run later in the session.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cuda")
    try:
        yield
    finally:
        torch.set_default_device(previous_device)


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(
    M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
):
    """Compare the Triton block-FP8 MoE paths against the torch reference.

    Both the functional ``fused_experts`` entry point and the modular kernel
    built by ``modular_triton_fused_moe`` are checked against
    ``torch_w8a8_block_fp8_moe``. ``workspace_init`` is requested only for its
    fixture side effects; the body does not use it.
    """
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

    # Random draws happen in a fixed order after seeding, so inputs are
    # reproducible: activations first, then router logits, then the weights.
    hidden_states = torch.randn((M, K), dtype=dtype) / 10
    router_logits = torch.randn((M, E), dtype=dtype)

    w1, w2, quant_config = make_test_quant_config(
        E,
        N,
        K,
        dtype,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        block_shape=block_size,
    )
    modular_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)

    topk_weights, topk_ids, _ = fused_topk(
        hidden_states, router_logits.float(), topk, False
    )

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(
            hidden_states,
            w1,
            w2,
            quant_config.w1_scale,
            quant_config.w2_scale,
            topk_weights,
            topk_ids,
            block_size,
        )
        triton_out = fused_experts(
            hidden_states, w1, w2, topk_weights, topk_ids, quant_config=quant_config
        )
        modular_out = modular_moe(hidden_states, w1, w2, topk_weights, topk_ids)

    # The looser 0.039 tolerance is only needed for M >= 8192.
    tol = 0.039 if M >= 8192 else 0.035
    for result in (triton_out, modular_out):
        torch.testing.assert_close(result, ref_out, atol=tol, rtol=tol)


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
    """Compare the DeepGEMM-or-Triton modular MoE kernel to the torch reference.

    The kernel is a ``FusedMoEModularKernel`` built from
    ``MoEPrepareAndFinalizeNoEP`` and ``TritonOrDeepGemmExperts``. When M
    exceeds the 1024-token chunk size and N, K >= 1024 on a CUDA-alike
    platform, the call is also captured into a CUDA graph, and the output of
    ``graph.replay()`` is what gets checked.
    """
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
    block_size = get_mk_alignment_for_contiguous_layout()
    dtype = torch.bfloat16

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        E,
        N,
        K,
        dtype,
        torch.float8_e4m3fn,
        per_out_ch_quant=False,
        block_shape=block_size,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    # CUDA-graph capture is only exercised for M above the chunk size with
    # N, K >= 1024 on CUDA-alike platforms.
    use_cudagraph = (
        chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
    )

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=block_size,
    )

    deep_gemm_experts = mk.FusedMoEModularKernel(
        prepare_finalize=MoEPrepareAndFinalizeNoEP(),
        fused_experts=TritonOrDeepGemmExperts(
            moe_config=make_dummy_moe_config(),
            quant_config=quant_config,
        ),
    )

    def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
        # w1_s / w2_s are accepted only so every call site has the same shape;
        # the scales reach the kernel through ``quant_config`` above.
        return deep_gemm_experts(
            hidden_states=a,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(
            a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size
        )

        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(
                deep_gemm_moe_fp8, backend="inductor", fullgraph=True
            )
            torch._dynamo.mark_dynamic(a, 0)
            torch._dynamo.mark_dynamic(topk_weights, 0)
            torch._dynamo.mark_dynamic(topk_ids, 0)
        else:
            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

        if use_cudagraph:
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                graph_out = deep_gemm_moe_fp8_fn(
                    a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
                )
            # Capture only records the kernels, so ``graph_out`` holds no result
            # yet. Zero the graph's static output (not the eager result) so the
            # check below can only pass if ``graph.replay()`` writes it.
            graph_out.fill_(0)
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()
            out = graph_out

    torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)