"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""
import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                 torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, moe_align_block_size)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

# Expert counts and top-k values shared by several parametrized tests below.
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]


@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    """Check ``fused_moe`` and the iterative torch MoE against ``torch_moe``.

    Inputs are random, scaled by 1/10, and built on CUDA in ``dtype``:
    ``a`` is (m, k), ``w1`` is (e, 2 * n, k), ``w2`` is (e, k, n) and the
    router ``score`` is (m, e). Both implementations must match the
    reference within an absolute tolerance of 2e-2.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)

    iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
    torch.testing.assert_close(iterative_output,
                               torch_output,
                               atol=2e-2,
                               rtol=0)


@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        dtype: torch.dtype, group_size: int, has_zp: bool,
                        weight_bits: int):
    """Compare ``fused_moe`` on 4/8-bit quantized weights with ``torch_moe``.

    Each expert slice of ``w1`` and ``w2`` is quantized with
    ``quantize_weights``. Its first output is used as the reference weight
    for ``torch_moe``. The packed ``uint8`` codes, per-group scales and
    (when ``has_zp``) packed zero points are passed to ``fused_moe`` with
    ``use_int4_w4a16`` / ``use_int8_w8a16``.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit codes are packed into each uint8 element (see below).
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
    else:
        # Fail loudly instead of an UnboundLocalError on pack_factor later.
        raise ValueError(f"unsupported weight_bits: {weight_bits}")

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w2_qweight = torch.empty((e, k, n // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size),
                            device="cuda",
                            dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size),
                            device="cuda",
                            dtype=dtype)
    # Only filled (and only passed to fused_moe) when has_zp is True.
    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
                            device="cuda",
                            dtype=torch.uint8)
    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
                            device="cuda",
                            dtype=torch.uint8)

    # Quantize every expert of w1 first, then every expert of w2.
    for w, w_ref, w_qweight, w_scales, w_qzeros in (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    ):
        for expert_id in range(e):
            # quantize_weights is called on the transposed slice, so every
            # output is transposed back to the (out, in) layout used here.
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False)
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)
            if weight_bits == 4:
                # Pack adjacent pairs: even index -> low nibble,
                # odd index -> high nibble.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    # NOTE(review): block_shape=[0, group_size] presumably means "no
    # blocking along N, group_size along K" -- confirm against fused_moe.
    triton_output = fused_moe(a,
                              w1_qweight,
                              w2_qweight,
                              score,
                              topk,
                              renormalize=False,
                              use_int4_w4a16=weight_bits == 4,
                              use_int8_w8a16=weight_bits == 8,
                              w1_scale=w1_scales,
                              w2_scale=w2_scales,
                              w1_zp=w1_qzeros if has_zp else None,
                              w2_zp=w2_qzeros if has_zp else None,
                              block_shape=[0, group_size])
    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface.

    Both blocks are built from a default ``MixtralConfig``, the HF weights
    are copied into the vLLM block, and the outputs on the same random
    input must match within a dtype-dependent rtol/atol.
    """
    # Instantiate our and huggingface's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    vllm_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    ).cuda()

    # Load the weights
    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
        # HF keeps w1 and w3 as separate linears; vLLM's w13_weight holds
        # them concatenated along dim 0 (w1 first, then w3).
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
        vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
    # vLLM uses 1D query [num_tokens, hidden_dim]
    vllm_inputs = hf_inputs.flatten(0, 1)

    # Run forward passes for both MoE blocks; the second HF return value
    # is not compared.
    hf_states, _ = hf_moe.forward(hf_inputs)
    vllm_states = vllm_moe.forward(vllm_inputs)

    # Lower-precision dtypes get a looser tolerance.
    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    torch.testing.assert_close(hf_states.flatten(0, 1),
                               vllm_states,
                               rtol=mixtral_moe_tol[dtype],
                               atol=mixtral_moe_tol[dtype])


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
    num_bits: int,
    is_k_full: bool,
):
    """Compare the Marlin fused MoE op with the Triton ``fused_moe``.

    Every expert of ``w1``/``w2`` is quantized with ``marlin_quantize``.
    The Triton path runs on the returned reference weights and the Marlin
    path on the packed weights, scales, group indices and sort indices.
    If ``ops.supports_moe_ops`` is true, ``opcheck`` is also run on
    ``_moe_C.topk_softmax`` and ``_moe_C.marlin_gemm_moe``.
    """
    current_platform.seed_everything(7)

    # Filter parameter combinations: act_order is only tested with a real
    # group size (not -1 and not equal to k or n), and is_k_full=False is
    # only tested together with act_order.
    if act_order:
        if group_size == -1:
            return
        if group_size in (k, n):
            return
    else:
        if not is_k_full:
            return

    quant_type = (scalar_types.uint4b8
                  if num_bits == 4 else scalar_types.uint8b128)
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    # Quantize each expert of w1 (random act-order permutation of size k).
    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    g_idx1_l = []
    sort_indices1_l = []

    for i in range(w1.shape[0]):
        test_perm = torch.randperm(k)
        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size, act_order,
            test_perm)
        w_ref1_l.append(w_ref1)
        qweight1_l.append(qweight1)
        scales1_l.append(scales1)
        g_idx1_l.append(g_idx1)
        sort_indices1_l.append(sort_indices1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    g_idx1 = stack_and_dev(g_idx1_l)
    sort_indices1 = stack_and_dev(sort_indices1_l)

    # Quantize each expert of w2 (random act-order permutation of size n).
    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    g_idx2_l = []
    sort_indices2_l = []

    for i in range(w2.shape[0]):
        test_perm = torch.randperm(n)
        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size, act_order,
            test_perm)
        w_ref2_l.append(w_ref2)
        qweight2_l.append(qweight2)
        scales2_l.append(scales2)
        g_idx2_l.append(g_idx2)
        sort_indices2_l.append(sort_indices2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    g_idx2 = stack_and_dev(g_idx2_l)
    sort_indices2 = stack_and_dev(sort_indices2_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    # The reference weights were quantized transposed; transpose them back
    # for the Triton path.
    triton_output = fused_moe(
        a,
        w_ref1.transpose(1, 2).contiguous(),
        w_ref2.transpose(1, 2).contiguous(),
        score,
        topk,
        renormalize=False,
    )
    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )

    assert compute_max_diff(marlin_output, triton_output) < 4e-2

    if ops.supports_moe_ops:
        token_expert_indicies = torch.empty(m,
                                            topk,
                                            dtype=torch.int32,
                                            device=a.device)

        opcheck(torch.ops._moe_C.topk_softmax, (
            topk_weights,
            topk_ids,
            token_expert_indicies,
            score.float(),
        ))

        block_size_m = 4

        sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
                                                      e)

        max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                device="cuda",
                                requires_grad=False)

        # Empty zero-point tensor: this path exercises the no-zp variant.
        zp = torch.empty((0, 0),
                         dtype=dtype,
                         device="cuda",
                         requires_grad=False)
        opcheck(torch.ops._moe_C.marlin_gemm_moe,
                (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
                 scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
                 m, 2 * n, k, True, e, topk, block_size_m, True, False))


@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
    num_bits: int,
    is_k_full: bool,
):
    """Compare ``single_marlin_moe`` (one expert matmul) with
    ``torch_moe_single`` on the ``marlin_quantize`` reference weights.

    Debug-only: always skipped by the ``pytest.mark.skip`` marker above.
    """
    # Filter parameter combinations: act_order is only tested with a real
    # group size (not -1 and not equal to k), and is_k_full=False is only
    # tested together with act_order.
    if act_order:
        if group_size == -1:
            return
        if group_size == k:
            return
    else:
        if not is_k_full:
            return

    quant_type = (scalar_types.uint4b8
                  if num_bits == 4 else scalar_types.uint8b128)
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    w_ref_l = []
    qweights_l = []
    scales_l = []
    g_idx_l = []
    sort_indices_l = []

    # Quantize each expert (random act-order permutation of size k).
    for i in range(w.shape[0]):
        test_perm = torch.randperm(k)
        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
            w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
        w_ref_l.append(w_ref)
        qweights_l.append(qweight)
        scales_l.append(scales)
        g_idx_l.append(g_idx)
        sort_indices_l.append(sort_indices)

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweights_l).contiguous()
    scales = stack_and_dev(scales_l)
    g_idx = stack_and_dev(g_idx_l)
    sort_indices = stack_and_dev(sort_indices_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    marlin_output = torch.ops.vllm.single_marlin_moe(
        a,
        qweight,
        scales,
        score,
        topk,
        renormalize=False,
        g_idx=g_idx,
        sort_indices=sort_indices,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )

    # The reference weights were quantized transposed; transpose them back.
    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2


def test_moe_align_block_size_opcheck():
    """Run ``opcheck`` on ``_moe_C.moe_align_block_size``.

    Output buffers are allocated here: ``sorted_ids`` (pre-filled with
    ``topk_ids.numel()``), ``expert_ids`` and a one-element
    ``num_tokens_post_pad``.
    """
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0,
                             num_experts, (3, 4),
                             dtype=torch.int32,
                             device='cuda')

    # Upper bound on the padded token count, assuming each expert's group
    # is padded by at most (block_size - 1) slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    # NOTE(review): numel() is not a valid token index, presumably used as
    # the padding sentinel -- confirm against the kernel.
    sorted_ids = torch.full((max_num_tokens_padded, ),
                            topk_ids.numel(),
                            dtype=torch.int32,
                            device=topk_ids.device)
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1, ),
                                      dtype=torch.int32,
                                      device=topk_ids.device)

    opcheck(torch.ops._moe_C.moe_align_block_size,
            (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
             num_tokens_post_pad))