test_batched_moe.py 9.69 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4

from dataclasses import dataclass
bnellnm's avatar
bnellnm committed
5
from typing import Optional
6
7
8
9

import pytest
import torch

bnellnm's avatar
bnellnm committed
10
11
from tests.kernels.moe.utils import (batched_moe,
                                     make_quantized_test_activations,
12
                                     make_test_weights, naive_batched_moe)
bnellnm's avatar
bnellnm committed
13
14
15
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
16
17
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel)
bnellnm's avatar
bnellnm committed
18
19
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
20
from vllm.triton_utils import tl
bnellnm's avatar
bnellnm committed
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

# Problem sizes for test_fused_moe_batched_experts, keyed by m (the number of
# rows of the input activations `a`). Each value lists the (n, k) pairs run
# for that m: k is the hidden size (columns of `a`) and n is the size handed
# to make_test_weights. Insertion order is preserved so the flattened list
# below keeps a stable parametrization order.
_NK_BY_M = {
    1: [(128, 128), (128, 2048), (512, 512), (1024, 128), (1024, 2048)],
    32: [(128, 128), (512, 512), (1024, 2048)],
    45: [(128, 128), (128, 2048), (512, 512), (1024, 128), (1024, 2048)],
    64: [(512, 512), (1024, 2048)],
    222: [(128, 128), (128, 2048), (1024, 128), (1024, 2048)],
}
MNK_FACTORS = [(m, n, k) for m, nk_pairs in _NK_BY_M.items()
               for n, k in nk_pairs]
# Expert counts and top-k values swept by the fused MoE test.
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]

def _build_test_vllm_config() -> VllmConfig:
    """Return the VllmConfig installed as current config by the MoE tests.

    Only two scheduler limits are overridden; every other setting keeps the
    value produced by ``VllmConfig()``.
    """
    config = VllmConfig()
    scheduler = config.scheduler_config
    scheduler.max_num_seqs = 128
    scheduler.max_model_len = 8192
    return config


vllm_config = _build_test_vllm_config()


@dataclass
class BatchedMMConfig:
bnellnm's avatar
bnellnm committed
53
54
55
    in_dtype: torch.dtype
    quant_dtype: Optional[torch.dtype]
    out_dtype: torch.dtype
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    num_experts: int
    max_tokens_per_expert: int
    K: int
    N: int


@dataclass
class BatchedMMTensors:
    """Random operands plus a zeroed output buffer for one batched matmul."""

    # Activations, shape [E, max_tokens, K].
    A: torch.Tensor
    # Weights, allocated with shape [E, N, K]: per expert a K x N operand
    # stored column-major (i.e. transposed in memory).
    B: torch.Tensor
    # Zero-initialised output, shape [E, max_tokens, N].
    C: torch.Tensor
    # Per-expert token counts, shape [E] (presumably the number of valid
    # rows of A/C for each expert — confirm against the kernel contract).
    num_expert_tokens: torch.Tensor

    @staticmethod
    def make_tensors(config: BatchedMMConfig) -> "BatchedMMTensors":
        """Allocate random CUDA tensors sized and typed by *config*.

        A and B use ``config.in_dtype``; C is zero-filled in
        ``config.out_dtype``. A is scaled down by 10 to keep the products
        small.
        """
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device="cuda",
            dtype=config.in_dtype) / 10
        B = torch.randn((config.num_experts, config.N, config.K),
                        device="cuda",
                        dtype=config.in_dtype)
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device="cuda",
            dtype=config.out_dtype)
        # torch.randint's `high` is exclusive, so every count lies in
        # [0, max_tokens_per_expert) and a completely full expert is never
        # generated.
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device="cuda",
                                          dtype=torch.int32)
        return BatchedMMTensors(A, B, C, num_expert_tokens)
90
91


92
93
94
95
96
@pytest.mark.parametrize("num_experts", [8, 32])
@pytest.mark.parametrize("max_tokens_per_expert", [32, 224, 512])
@pytest.mark.parametrize("K", [128, 1024])
@pytest.mark.parametrize("N", [128, 1024])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype,
                    block_shape: Optional[list[int]],
                    per_act_token_quant: bool):
    """Check ``invoke_moe_batched_triton_kernel`` against native references.

    Two references are computed with ``native_batched_masked_quant_matmul``:
    one from the unquantized operands (A, B) and one from the quantized
    operands and their scales (A_q, B_q, A_scale, B_scale). The test asserts
    that the quantized reference matches the unquantized one, and that the
    Triton kernel output matches the quantized reference.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    # Per-token and block-wise scaling are only exercised for fp8 inputs.
    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")

    # Per-token activation scales combined with block scales is treated as
    # an illegal combination.
    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # A 1-byte dtype is a quantization target: generate bf16 data and
    # quantize it. Otherwise use the dtype directly with no quantization.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # Per-expert token counts; randint's `high` is exclusive.
    num_expert_tokens = torch.randint(low=0,
                                      high=max_tokens_per_expert,
                                      size=(num_experts, ),
                                      device="cuda",
                                      dtype=torch.int32)

    A, A_q, A_scale = make_quantized_test_activations(
        num_experts,
        max_tokens_per_expert,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
    )

    # NOTE(review): N // 2 presumably because make_test_weights returns a
    # first-layer weight with 2 * n output rows, so B ends up with N output
    # features — confirm against tests.kernels.moe.utils.make_test_weights.
    (B, B_q, B_scale, _), _ = make_test_weights(
        num_experts,
        N // 2,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )

    out_shape = (num_experts, max_tokens_per_expert, N)
    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")

    # Triton accumulation/compute dtype matching the output tensor.
    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]

    assert A_q.dtype == B_q.dtype

    invoke_moe_batched_triton_kernel(
        A_q,
        B_q,
        test_output,
        num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        A_scale,
        B_scale,
        None,
        # Quantization schemes. NOTE(review): the two trailing False flags
        # presumably disable the other (non-fp8) quantization modes —
        # confirm against the kernel signature.
        use_fp8_w8a8,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            # 1-byte (fp8) inputs use a larger K tile.
            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
        },
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    # Reference on the original, unquantized operands.
    ref_output = native_batched_masked_quant_matmul(
        A,
        B,
        ref_output,
        num_expert_tokens,
    )

    # Reference on the quantized operands and their scales.
    q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                      num_expert_tokens,
                                                      A_scale, B_scale,
                                                      block_shape,
                                                      per_act_token_quant)

    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("input_scales", [False])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    input_scales: bool,
):
    """Compare the batched MoE implementations against a torch baseline.

    ``naive_batched_moe`` is checked against ``torch_experts``, and the Triton
    ``batched_moe`` is checked against ``naive_batched_moe``, all fed the same
    activations, routing and (possibly quantized) weights.
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if topk > e:
        pytest.skip("topk > e")

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # A 1-byte dtype is a quantization target: generate bf16 data and
    # quantize it. Otherwise use the dtype directly with no quantization.
    # (Computed before any random draws, so the RNG sequence is unaffected.)
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # Activations follow act_dtype so weights and inputs share a dtype.
    a = torch.randn((m, k), device="cuda", dtype=act_dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    # The unquantized copies of the weights are not needed here.
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
        block_shape=block_shape,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        per_out_ch_quant=per_act_token_quant,
    )

    # Optional static (unit) input scales; disabled by the current params.
    if input_scales and quant_dtype is not None:
        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
    else:
        a1_scale = None
        a2_scale = None

    with set_current_vllm_config(vllm_config):
        # NOTE(review): the trailing False is presumably `renormalize`.
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

        baseline_output = torch_experts(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

        batched_output = naive_batched_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

        triton_output = batched_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

    torch.testing.assert_close(batched_output,
                               baseline_output,
                               atol=3e-2,
                               rtol=2e-2)

    torch.testing.assert_close(triton_output,
                               batched_output,
                               atol=2e-2,
                               rtol=2e-2)