# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

import pytest
import torch
import triton.language as tl

from tests.kernels.moe.utils import (batched_moe,
                                     make_quantized_test_activations,
                                     make_test_weights, triton_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform

# (m, n, k) problem sizes used to parametrize test_fused_moe_batched_experts.
# There, the activations are created with shape (m, k); n and k are passed on
# to make_test_weights.
MNK_FACTORS = [
    (1, 128, 128),
    (1, 128, 2048),
    (1, 512, 512),
    (1, 1024, 128),
    (1, 1024, 2048),
    (32, 128, 128),
    (32, 512, 512),
    (32, 1024, 2048),
    (45, 128, 128),
    (45, 128, 2048),
    (45, 512, 512),
    (45, 1024, 128),
    (45, 1024, 2048),
    (64, 128, 128),
    (64, 512, 512),
    (64, 1024, 2048),
    (222, 128, 128),
    (222, 128, 2048),
    (222, 512, 512),
    (222, 1024, 128),
    (222, 1024, 2048),
]
# Expert counts ("e") swept by test_fused_moe_batched_experts.
NUM_EXPERTS = [8, 64]
# Values of "topk" swept by test_fused_moe_batched_experts.
TOP_KS = [1, 2, 6]

# Module-level config that test_fused_moe_batched_experts installs with
# set_current_vllm_config() around its fused-MoE calls.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


@dataclass
class BatchedMMConfig:
    """Dtypes and dimensions describing one batched matmul problem.

    ``BatchedMMTensors.make_tensors`` reads these fields to allocate the
    operand and output tensors.
    """
    # dtype of the A and B operands.
    in_dtype: torch.dtype
    # Optional quantized dtype; make_tensors does not read this field.
    quant_dtype: Optional[torch.dtype]
    # dtype of the output tensor C.
    out_dtype: torch.dtype
    # Number of experts: the leading dimension of A, B and C.
    num_experts: int
    # Number of rows reserved per expert in A and C.
    max_tokens_per_expert: int
    # Last dimension of both A and B as allocated by make_tensors.
    K: int
    # Last dimension of C (and middle dimension of B as allocated).
    N: int


@dataclass
class BatchedMMTensors:
    """Operands, output buffer and per-expert token counts for one problem."""
    A: torch.Tensor  # [E, max_tokens, K]
    B: torch.Tensor  # [E, K, N] column major, i.e. allocated as [E, N, K]
    C: torch.Tensor  # [E, max_tokens, N]
    num_expert_tokens: torch.Tensor  # [E]

    @staticmethod
    def make_tensors(config: "BatchedMMConfig",
                     device: str = "cuda") -> "BatchedMMTensors":
        """Allocate random operands and a zeroed output for ``config``.

        Args:
            config: Supplies ``num_experts``, ``max_tokens_per_expert``,
                ``K``, ``N``, ``in_dtype`` and ``out_dtype``.
            device: Device on which all tensors are created. Defaults to
                ``"cuda"``, the value this helper previously hard-coded.

        Returns:
            A ``BatchedMMTensors`` with normally distributed ``A`` (scaled
            by 1/10) and ``B``, a zero-filled ``C``, and int32
            ``num_expert_tokens`` drawn from
            ``[0, max_tokens_per_expert)``.
        """
        A = torch.randn(
            (config.num_experts, config.max_tokens_per_expert, config.K),
            device=device,
            dtype=config.in_dtype) / 10
        B = torch.randn((config.num_experts, config.N, config.K),
                        device=device,
                        dtype=config.in_dtype)
        C = torch.zeros(
            (config.num_experts, config.max_tokens_per_expert, config.N),
            device=device,
            dtype=config.out_dtype)

        # torch.randint's upper bound is exclusive, so an expert never gets
        # the full max_tokens_per_expert rows.
        num_expert_tokens = torch.randint(low=0,
                                          high=config.max_tokens_per_expert,
                                          size=(config.num_experts, ),
                                          device=device,
                                          dtype=torch.int32)

        return BatchedMMTensors(A, B, C, num_expert_tokens)


@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
                         [32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                    N: int, dtype: torch.dtype,
                    block_shape: Optional[list[int]],
                    per_act_token_quant: bool):
    """Check the batched Triton MoE matmul kernel against native references.

    The output of ``invoke_moe_batched_triton_kernel`` is compared with
    ``native_batched_masked_quant_matmul`` evaluated twice: once on the
    unquantized operands without scales, and once on the (possibly
    quantized) operands with their scales.
    """
    current_platform.seed_everything(7)

    # With the current parametrization dtype is never float8_e4m3fn, so this
    # flag is always False and the two skips below never trigger.
    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
        pytest.skip("Don't test blocking for non-quantized types.")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # One-byte dtypes are treated as quantized types: activations are then
    # generated in bfloat16 and quantized to `dtype`.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # Per-expert valid row counts in [0, max_tokens_per_expert).
    num_expert_tokens = torch.randint(low=0,
                                      high=max_tokens_per_expert,
                                      size=(num_experts, ),
                                      device="cuda",
                                      dtype=torch.int32)

    A, A_q, A_scale = make_quantized_test_activations(
        num_experts,
        max_tokens_per_expert,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant)

    # NOTE(review): N // 2 is presumably passed because the first weight
    # returned by make_test_weights has twice that many output rows, giving
    # N in total -- confirm against make_test_weights.
    B, B_q, B_scale, _, _, _ = make_test_weights(
        num_experts,
        N // 2,
        K,
        in_dtype=act_dtype,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
    )

    out_shape = (num_experts, max_tokens_per_expert, N)
    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")

    # Triton compute type matching the output tensor's torch dtype.
    compute_tl_dtype = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32
    }[test_output.dtype]

    assert A_q.dtype == B_q.dtype

    invoke_moe_batched_triton_kernel(
        A_q,
        B_q,
        test_output,
        num_expert_tokens,
        compute_tl_dtype,
        # Quantization data
        A_scale,
        B_scale,
        None,
        # Quantization schemes
        use_fp8_w8a8,
        False,
        False,
        config={
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 16,
            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
        },
        block_shape=block_shape,
    )

    # Reference on the unquantized operands, with no scales or block shape.
    ref_output = native_batched_masked_quant_matmul(
        A,
        B,
        ref_output,
        num_expert_tokens,
        None,
        None,
        None,
    )

    # Reference on the same operands and scales that were given to the kernel.
    q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
                                                      num_expert_tokens,
                                                      A_scale, B_scale,
                                                      block_shape)

    # Tolerances are selected by output dtype; the 16-bit types get the
    # looser bound.
    rtol, atol = {
        torch.float16: (6e-2, 6e-2),
        torch.bfloat16: (6e-2, 6e-2),
        torch.float32: (1e-2, 1e-2),
    }[test_output.dtype]

    torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)


@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
):
    """Compare batched, Triton and torch-reference MoE expert outputs.

    Runs ``batched_moe``, ``torch_experts`` and ``triton_moe`` on the same
    activations, weights and top-k routing, then asserts that the Triton
    output is close to each of the other two (atol = rtol = 2e-2).
    """
    current_platform.seed_everything(7)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    # Reported separately from the quantization skip below so the skip
    # reason describes the actual cause.
    if topk > e:
        pytest.skip("Skip test where topk exceeds the number of experts.")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization test.")

    # One-byte dtypes are treated as quantized types backed by bfloat16
    # activations; otherwise the requested dtype is used directly.
    if dtype.itemsize == 1:
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        act_dtype = dtype
        quant_dtype = None

    # Activations use act_dtype so they match the weights built below. With
    # the current bfloat16-only parametrization this is torch.bfloat16.
    a = torch.randn((m, k), device="cuda", dtype=act_dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
                                                 n,
                                                 k,
                                                 block_shape=block_shape,
                                                 in_dtype=act_dtype,
                                                 quant_dtype=quant_dtype)

    with set_current_vllm_config(vllm_config):
        # NOTE(review): the trailing False is presumably fused_topk's
        # renormalize flag -- confirm against its signature.
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        batched_output = batched_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )
        baseline_output = torch_experts(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape)

        triton_output = triton_moe(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids,
            w1_scale=w1_s,
            w2_scale=w2_s,
            quant_dtype=quant_dtype,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        )

    torch.testing.assert_close(triton_output,
                               baseline_output,
                               atol=2e-2,
                               rtol=2e-2)

    torch.testing.assert_close(triton_output,
                               batched_output,
                               atol=2e-2,
                               rtol=2e-2)