test_block_fp8.py 9.86 KB
Newer Older
bnellnm's avatar
bnellnm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                       native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform

# Probe for the optional DeepGEMM package; the DeepGEMM test is skipped when
# it cannot be imported.
try:
    import deep_gemm
except ImportError:
    dg_available = False
else:
    dg_available = True

# Block-FP8 kernels need compute capability >= 9.0. `has_device_capability`
# returns False when the platform reports no capability (e.g. no GPU), where
# comparing the raw `get_device_capability()` result (None) with a tuple would
# raise TypeError instead of skipping.
if not current_platform.has_device_capability((9, 0)):
    pytest.skip("FP8 Triton requires compute capability 9.0 or higher",
                allow_module_level=True)

# Shared config installed via set_current_vllm_config in the tests below.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

# Test configurations
# Activation dtypes under test; the dtypes in the trailing comment are not
# currently exercised.
DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
# (M, N, K) sizes for the Triton test: the activation tensor `a` is (M, K) and
# N is forwarded to make_test_weights.
# Deepseek-V3's intermediate size is 18432, so N is 18432*2/8=4608 at TP8,
# and its hidden size is 7168.
MNK_FACTORS = [
    (1, 128, 128),
    (1, 512, 512),
    (1, 128, 7168),
    (1, 1024, 7168),
    (1, 4608, 128),
    (1, 4608, 512),
    (1, 4608, 7168),
    (83, 128, 128),
    (83, 512, 512),
    (83, 1024, 7168),
    (83, 4608, 512),
    (83, 4608, 7168),
    (128, 128, 128),
    (128, 512, 512),
    (128, 1024, 7168),
    (128, 4608, 512),
    (128, 4608, 7168),
    (2048, 128, 128),
    (2048, 1024, 7168),
    (2048, 4608, 512),
    (2048, 4608, 7168),
    (8192, 128, 128),
    (8192, 512, 512),
    (8192, 128, 7168),
    (8192, 1024, 7168),
    (8192, 4608, 512),
    (8192, 4608, 7168),
]

# (M, N, K) sizes for the DeepGEMM test. Combinations rejected by
# _valid_deep_gemm_shape are skipped at runtime.
MNK_FACTORS_DG = [
    (128, 128, 128),
    (128, 512, 512),
    (128, 128, 7168),
    (128, 1024, 7168),
    (128, 4608, 128),
    (128, 4608, 512),
    (128, 4608, 7168),
    (192, 128, 128),
    (192, 512, 512),
    (192, 1024, 7168),
    (192, 4608, 512),
    (192, 4608, 7168),
    (1335, 128, 128),
    (1335, 1024, 7168),
    (1335, 4608, 512),
    (1335, 4608, 7168),
    (2048, 128, 128),
    (2048, 512, 512),
    (2048, 128, 7168),
    (2048, 1024, 7168),
    (2048, 4608, 128),
    (2048, 4608, 512),
    (2048, 4608, 7168),
]

# Quantization block shapes for the Triton test. The torch reference uses
# element 1 as the activation quantization group size.
BLOCK_SIZE = [[128, 128]]
# Expert counts; larger counts in the trailing comment are not exercised.
# Configurations with topk > E are skipped inside the tests.
E = [2, 8, 16]  # [128, 256]
# Number of experts each row is routed to.
TOP_KS = [1, 2, 6]
# RNG seeds passed to torch.manual_seed.
SEEDS = [0]


def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
                             block_shape):
    """Reference MoE forward pass with block-wise FP8 quantization in torch.

    Each row of ``a`` is replicated once per routed top-k slot and quantized
    in groups of ``block_shape[1]`` columns. Each replica is then sent through
    its expert's ``w1 -> SiluAndMul -> w2`` pipeline, with both matmuls done
    by ``native_w8a8_block_matmul``. The per-slot results are scaled by
    ``topk_weight`` and summed over the slots.

    Args:
        a: 2-D activations of shape ``(num_rows, hidden)``.
        w1, w2: Per-expert quantized weights; dim 0 indexes the expert.
        w1_s, w2_s: Per-expert block scales for ``w1`` / ``w2``.
        topk_weight: Routing weights, viewable as ``(num_rows, topk)``.
        topk_ids: Routed expert ids of shape ``(num_rows, topk)``.
        block_shape: Quantization block shape; element 1 is the group size
            used when quantizing activations.

    Returns:
        Tensor of shape ``(num_rows, w2.shape[1])`` with ``a``'s dtype.
    """
    num_rows, hidden = a.shape
    num_slots = topk_ids.size(1)
    group_size = block_shape[1]
    out_dim = w2.shape[1]

    # Row r of the expanded input is row r // num_slots routed to slot
    # r % num_slots, matching the row-major flattening of topk_ids.
    expanded = (a.view(num_rows, -1, hidden)
                .repeat(1, num_slots, 1)
                .reshape(-1, hidden))
    flat_weights = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)
    out = torch.zeros(num_rows * num_slots,
                      out_dim,
                      dtype=expanded.dtype,
                      device=expanded.device)

    a_q, a_s = native_per_token_group_quant_fp8(expanded, group_size)
    a_q = a_q.to(torch.float32)

    for expert in range(w1.shape[0]):
        rows = flat_ids == expert
        if not rows.any():
            continue  # nothing routed to this expert
        hidden_out = native_w8a8_block_matmul(a_q[rows],
                                              w1[expert],
                                              a_s[rows],
                                              w1_s[expert],
                                              block_shape,
                                              output_dtype=a.dtype)
        activated = SiluAndMul().forward_native(hidden_out)
        act_q, act_s = native_per_token_group_quant_fp8(activated, group_size)
        out[rows] = native_w8a8_block_matmul(act_q,
                                             w2[expert],
                                             act_s,
                                             w2_s[expert],
                                             block_shape,
                                             output_dtype=a.dtype)

    weights = flat_weights.view(num_rows, -1, 1).to(out.dtype)
    return (out.view(num_rows, -1, out_dim) * weights).sum(dim=1)


# Skip all tests if CUDA is not available. Importing the `torch.cuda`
# submodule always succeeds, so `pytest.importorskip("torch.cuda")` could
# never trigger this skip; check device availability instead.
if not torch.cuda.is_available():
    pytest.skip("CUDA is not available", allow_module_level=True)


@pytest.fixture(autouse=True)
def setup_cuda():
    """Make CUDA the default tensor device for each test in this module.

    The previously active default device is restored on teardown, so the
    setting does not leak into tests from other modules that run later in
    the same process.
    """
    prev_device = torch.get_default_device()
    torch.set_default_device("cuda")
    yield
    torch.set_default_device(prev_device)


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
                                  monkeypatch):
    """Check the Triton block-FP8 MoE paths against the torch reference.

    The functional ``fused_experts`` entry point and the modular kernel
    built by ``modular_triton_fused_moe`` both run on the same seeded random
    inputs and are compared with ``torch_w8a8_block_fp8_moe``.
    """
    if topk > E:
        pytest.skip(f"Skipping test; topk={topk} > E={E}")

    torch.manual_seed(seed)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")

    # Keep the RNG draw order fixed (a, score, weights) so the seeded inputs
    # stay reproducible.
    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)
    _, w1, w1_s, _, w2, w2_s = make_test_weights(
        E,
        N,
        K,
        dtype,
        torch.float8_e4m3fn,
        per_act_token_quant=False,
        block_shape=block_size,
    )

    modular_moe = modular_triton_fused_moe(
        use_fp8_w8a8=True,
        use_int8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        per_act_token_quant=False,
        block_shape=block_size,
    )

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
    scale_kwargs = {"w1_scale": w1_s, "w2_scale": w2_s}

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s,
                                           topk_weights, topk_ids,
                                           block_size)
        out = fused_experts(a,
                            w1,
                            w2,
                            topk_weights,
                            topk_ids,
                            use_fp8_w8a8=True,
                            block_shape=block_size,
                            **scale_kwargs)
        m_out = modular_moe(a, w1, w2, topk_weights, topk_ids,
                            **scale_kwargs)

    # The looser 0.039 tolerance was only needed for one M=40000
    # configuration; everything smaller uses 0.035.
    tol = 0.039 if M >= 40000 else 0.035
    for actual in (out, m_out):
        torch.testing.assert_close(actual, ref_out, atol=tol, rtol=tol)


@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
                                            monkeypatch):
    """Check ``deep_gemm_moe_fp8`` against the torch reference.

    The quantization block is square, with side equal to DeepGEMM's M
    alignment for the contiguous layout. The CUDA-graph path runs when
    ``M > chunk_size``, ``N >= 1024``, ``K >= 1024`` and the platform is
    CUDA-like. In that case the eager output is verified, the same call is
    captured in a CUDA graph, and the replayed output is verified as well.
    """
    if topk > E:
        pytest.skip(f"Skipping test: topk={topk} > E={E}")

    if not _valid_deep_gemm_shape(M, N, K):
        pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")

    chunk_size = 1024
    tol = 0.035

    torch.manual_seed(seed)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [block_m, block_m]
    dtype = torch.bfloat16

    a = torch.randn((M, K), dtype=dtype) / 10
    score = torch.randn((M, E), dtype=dtype)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
                                                 N,
                                                 K,
                                                 dtype,
                                                 torch.float8_e4m3fn,
                                                 per_act_token_quant=False,
                                                 block_shape=block_size)

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
                     and current_platform.is_cuda_alike())

    topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

    # Set the context to avoid lots of warning spam.
    with set_current_vllm_config(vllm_config):
        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
                                           topk_ids, block_size)

        if use_compile:
            deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
                                                 backend="inductor",
                                                 fullgraph=True)
            torch._dynamo.mark_dynamic(a, 0)
            torch._dynamo.mark_dynamic(topk_weights, 0)
            torch._dynamo.mark_dynamic(topk_ids, 0)
        else:
            deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8

        out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
                                   topk_ids)

        if use_cudagraph:
            # `out` is rebound to the graph's output below. Verify the eager
            # result now, otherwise it is never checked on this path.
            torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)

            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
                                           topk_ids)
            # Capture records the kernels without running them. Zero the
            # captured output buffer so the final check can only pass if
            # graph.replay() actually writes it.
            out.fill_(0)
            torch.cuda.synchronize()
            graph.replay()
            torch.cuda.synchronize()

    torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)