test_block_fp8.py 9.71 KB
Newer Older
bnellnm's avatar
bnellnm committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

7
8
9
10
11
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import (
    make_dummy_moe_config,
    make_test_quant_config,
    make_test_weights,
12
    modular_triton_fused_moe,
13
)
14
15
16
17
from tests.kernels.quant_utils import (
    native_per_token_group_quant_fp8,
    native_w8a8_block_matmul,
)
bnellnm's avatar
bnellnm committed
18
19
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
20
21
22
23
from vllm.model_executor.layers.fused_moe import (
    fused_experts,
    fused_topk,
)
24
25
26
27
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
28
29
30
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config,
)
bnellnm's avatar
bnellnm committed
31
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
32
33
    _valid_deep_gemm_shape,
)
34
35
36
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts,
)
bnellnm's avatar
bnellnm committed
37
from vllm.platforms import current_platform
38
39
40
41
from vllm.utils.deep_gemm import (
    get_mk_alignment_for_contiguous_layout,
    is_deep_gemm_e8m0_used,
)
42
from vllm.utils.import_utils import has_deep_gemm
# DeepGEMM is optional; the DeepGEMM test below is skipped when it is missing.
dg_available = has_deep_gemm()

# The FP8 Triton paths exercised here need compute capability >= 9.0. The
# platform may report no capability at all (None, e.g. without a GPU); treat
# that as unsupported instead of failing the tuple comparison with TypeError.
_device_capability = current_platform.get_device_capability()
if _device_capability is None or _device_capability < (9, 0):
    pytest.skip(
        "FP8 Triton requires compute capability 9.0 or higher",
        allow_module_level=True,
    )

# These tests build float8_e4m3fn weights; fnuz-only platforms cannot run them.
if current_platform.is_fp8_fnuz():
    pytest.skip(
        "Tests in this file require float8_e4m3fn and platform does not support",
        allow_module_level=True,
    )

# Shared config installed with set_current_vllm_config() inside the tests to
# avoid warning spam from code that looks up the current vLLM config.
vllm_config = VllmConfig()

# Test configurations
DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
# (M, N, K) shapes for the Triton test; M is the number of input rows (tokens)
# and K the hidden size of the activations (see `a = torch.randn((M, K))`).
# DeepSeek-V3's intermediate size is 18432, so N is 18432*2/8=4608 at TP8,
# and its hidden size is 7168.
MNK_FACTORS = [
    (1, 128, 128),
    (1, 128, 7168),
    (1, 1024, 7168),
    (1, 4608, 128),
    (1, 4608, 7168),
    (83, 128, 128),
    (83, 512, 512),
    (83, 4608, 512),
    (83, 4608, 7168),
    (128, 512, 512),
    (128, 1024, 7168),
    (128, 4608, 7168),
    (2048, 128, 128),
    (2048, 1024, 7168),
    (2048, 4608, 512),
    (2048, 4608, 7168),
    (8192, 128, 128),
    (8192, 128, 7168),
    (8192, 1024, 7168),
    (8192, 4608, 7168),
]

# (M, N, K) shapes for the DeepGEMM test; combinations rejected by
# _valid_deep_gemm_shape() are skipped at runtime.
MNK_FACTORS_DG = [
    (128, 128, 128),
    (128, 128, 7168),
    (128, 1024, 7168),
    (128, 4608, 128),
    (128, 4608, 7168),
    (192, 512, 512),
    (192, 1024, 7168),
    (192, 4608, 7168),
    (1335, 128, 128),
    (1335, 1024, 7168),
    (1335, 4608, 512),
    (1335, 4608, 7168),
    (2048, 128, 128),
    (2048, 128, 7168),
    (2048, 1024, 7168),
    (2048, 4608, 7168),
]

# Quantization block shapes; the reference uses entry [1] as the K-group size.
BLOCK_SIZE = [[128, 128]]
# Expert counts; the larger values in the trailing comment are presumably
# disabled to keep runtime down.
E = [2, 8, 16]  # [128, 256]
# Experts selected per token; combinations with topk > E are skipped.
TOP_KS = [1, 2, 6]
SEEDS = [0]


def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape):
    """Fused MoE with block-wise FP8 quantization, computed with native torch.

    Reference implementation the fused kernels are checked against. Each row
    of ``a`` is replicated once per selected expert, quantized per token group
    of ``block_shape[1]`` elements, pushed through that expert's
    ``w1 -> SiLU-and-mul -> w2`` path with block-quantized matmuls, and the
    per-expert results are combined with the router weights.

    Args:
        a: Activations of shape ``(B, D)``.
        w1, w2: Per-expert weights, indexed by expert along dim 0.
        w1_s, w2_s: Block scales for ``w1`` / ``w2``, indexed by expert.
        topk_weight: Router weights, ``B * topk`` elements (viewed as
            ``(B, topk)``).
        topk_ids: Selected expert ids of shape ``(B, topk)``.
        block_shape: Quantization block shape; entry 1 is the K-group size.

    Returns:
        Tensor of shape ``(B, w2.shape[1])`` with dtype ``a.dtype``.
    """
    B, D = a.shape
    topk = topk_ids.size(1)
    # Replicate each token once per selected expert -> (B * topk, D); row
    # b * topk + j is token b and lines up with topk_ids.view(-1)[b * topk + j].
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    # Only the K-group size is needed for per-token-group activation quant.
    block_k = block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if not mask.any():
            # Nothing was routed to expert i.
            continue
        inter_out = native_w8a8_block_matmul(
            a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
        )
        act_out = SiluAndMul().forward_native(inter_out)
        # Re-quantize the intermediate activation before the down projection.
        act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
        out[mask] = native_w8a8_block_matmul(
            act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
        )
    # Scale each expert's output by its router weight and sum over the topk slots.
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


# Skip all tests if CUDA is not available. Importing `torch.cuda` succeeds even
# on machines without a GPU, so ask the runtime instead.
if not torch.cuda.is_available():
    pytest.skip("CUDA is not available", allow_module_level=True)


@pytest.fixture(autouse=True)
def setup_cuda():
    """Make CUDA the default torch device for each test in this module.

    The previous default device is restored on teardown so the setting does
    not leak into tests from other modules in the same pytest session.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cuda")
    yield
    torch.set_default_device(previous_device)


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(
    M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
):
    """Check block-FP8 Triton fused MoE against the native torch reference.

    Both the functional ``fused_experts`` entry point and the modular Triton
    kernel are compared with ``torch_w8a8_block_fp8_moe``.

    ``workspace_init`` is requested only for its fixture setup side effect
    (presumably allocating the MoE workspace; confirm in conftest).
    """
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    w1, w2, quant_config = make_test_quant_config(
        E,
        N,
        K,
        dtype,
        quant_dtype=torch.float8_e4m3fn,
        per_act_token_quant=False,
        block_shape=block_size,
    )

    m_fused_moe = modular_triton_fused_moe(make_dummy_moe_config(), quant_config)

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(
            a,
            w1,
            w2,
            quant_config.w1_scale,
            quant_config.w2_scale,
            topk_weights,
            topk_ids,
            block_size,
        )

        out = fused_experts(
            a, w1, w2, topk_weights, topk_ids, quant_config=quant_config
        )

        m_out = m_fused_moe.apply(
            a,
            w1,
            w2,
            topk_weights,
            topk_ids,
            activation=MoEActivation.SILU,
            apply_router_weight_on_input=False,
            expert_map=None,
            global_num_experts=w1.shape[0],
        )

    # 0.039 only needed for M >= 8192
    tol = 0.035 if M < 8192 else 0.039

    torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
    torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch):
    """Check the modular TritonOrDeepGemmExperts kernel against the torch reference.

    For large problems that get chunked, the kernel is additionally captured
    into a CUDA graph and replayed, and the replayed output is what gets
    compared.
    """
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
    block_size = get_mk_alignment_for_contiguous_layout()
    dtype = torch.bfloat16

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        E,
        N,
        K,
        dtype,
        torch.float8_e4m3fn,
        per_out_ch_quant=False,
        block_shape=block_size,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    # Only exercise CUDA-graph capture for chunked problems (M > chunk_size)
    # with large N and K.
    use_cudagraph = (
        chunk_size < M and N >= 1024 and K >= 1024 and current_platform.is_cuda_alike()
    )

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=block_size,
    )
    moe_config = make_dummy_moe_config()

    deep_gemm_experts = mk.FusedMoEKernel(
        prepare_finalize=maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=False,
        ),
        fused_experts=TritonOrDeepGemmExperts(
            moe_config=moe_config,
            quant_config=quant_config,
        ),
        inplace=False,
    )

    def deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids):
        # w1_s / w2_s are unused: the scales reach the kernel through
        # quant_config. They stay in the signature to mirror the reference.
        return deep_gemm_experts.apply(
            hidden_states=a,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            global_num_experts=E,
            activation=MoEActivation.SILU,
            apply_router_weight_on_input=False,
            # No expert-parallel remapping: None is the "no map" value.
            expert_map=None,
        )

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(
            a, w1, w2, w1_s, w2_s, topk_weights, topk_ids, block_size
        )

        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(
                deep_gemm_moe_fp8, backend="inductor", fullgraph=True
            )
            torch._dynamo.mark_dynamic(a, 0)
            torch._dynamo.mark_dynamic(topk_weights, 0)
            torch._dynamo.mark_dynamic(topk_ids, 0)
        else:
            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

        # The eager call also serves as warm-up before any graph capture.
        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

        if use_cudagraph:
            out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                # `out` is rebound to the graph-owned output tensor, which
                # replay() fills below.
                out = deep_gemm_moe_fp8_fn(
                    a, w1, w2, w1_s, w2_s, topk_weights, topk_ids
                )
            torch.accelerator.synchronize()
            graph.replay()
            torch.accelerator.synchronize()

    torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)