"vllm/vscode:/vscode.git/clone" did not exist on "bd45912b99e3bad6621fd4d6bc103352ff31bcb7"
test_moe.py 30.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""
7

8
import functools
9
from collections.abc import Callable
10

11
12
import pytest
import torch
13
14
from torch.nn import Parameter
from torch.nn import functional as F
15
16
17
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

18
import vllm.model_executor.layers.fused_moe  # noqa
19
from tests.kernels.moe.utils import fused_moe
20
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
21
from vllm.config import VllmConfig, set_current_vllm_config
bnellnm's avatar
bnellnm committed
22
from vllm.distributed.parallel_state import init_distributed_environment
23
from vllm.forward_context import set_forward_context
24
from vllm.model_executor.layers.fused_moe.config import (
25
26
27
28
    FUSED_MOE_UNQUANTIZED_CONFIG,
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
29
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
30
from vllm.model_executor.layers.fused_moe.fused_moe import (
31
32
33
    fused_topk,
    modular_triton_fused_moe,
)
34
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
35
36
    fused_moe as iterative_moe,
)
37
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
38
39
    marlin_permute_bias,
)
40
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
41
42
43
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
44
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
45
46
    marlin_quant_fp8_torch,
)
47
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
48
49
50
51
    awq_marlin_quantize,
    marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
52
from vllm.model_executor.models.mixtral import MixtralMoE
53
from vllm.platforms import current_platform
54
from vllm.scalar_type import ScalarType, scalar_types
55

56
# Parametrization grids shared by the MoE tests in this file.
NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]

# (m, n, k) problem sizes used to parametrize test_fused_moe.
FUSED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (1, 2048, 128),
    (33, 2048, 128),
    (222, 1024, 1024),
    (32768, 128, 128),
    (32768, 2048, 511),
    (40000, 1024, 1024),
]

# (m, n, k) problem sizes used to parametrize test_fused_moe_wn16.
FUSED_MOE_WN16_MNK_FACTORS = [
    (1, 128, 128),
    (1, 1024, 1024),
    (32, 2048, 128),
    (32, 1024, 1024),
    (222, 2048, 1024),
]

78
79
80
81
# Module-level vLLM config shared by the tests in this file; tests activate it
# via `set_current_vllm_config(vllm_config)`. The scheduler limits below are
# overridden on the default-constructed config.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

82

83
def run_moe_test(
84
    baseline: Callable | torch.Tensor,
85
86
87
88
89
90
91
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
92
    expert_map: torch.Tensor | None = None,
93
94
95
96
97
98
99
100
101
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    if isinstance(baseline, torch.Tensor):
        baseline_output = baseline
    else:
102
103
104
105
106
107
108
109
110
        baseline_output = baseline(
            a,
            w1,
            w2,
            score,
            topk,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )
111
112
113
114
115
116
117
118
119
120
121

    # Pad the weight if moe padding is enabled
    if padding:
        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]

    if use_compile:
        moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

122
123
124
125
126
127
128
129
130
    test_output = moe_fn(
        a,
        w1,
        w2,
        score,
        topk,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )
131
132
133
134
135
136

    if use_cudagraph:
        test_output.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
137
138
139
140
141
142
143
144
145
            test_output = moe_fn(
                a,
                w1,
                w2,
                score,
                topk,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )
146
147
148
149
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

150
    torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
151
152
153
154

    return baseline_output


155
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
    chunk_size: int,
    monkeypatch,
):
    """Compare the fused MoE implementations against the torch reference.

    `torch_moe` is first checked against `iterative_moe`; its output is then
    reused as the baseline for `fused_moe` and for the modular Triton kernel.
    """
    current_platform.seed_everything(7)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    #
    # Setup test data
    #

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if ep_size > 1:
        # Keep only a subset of experts locally. e_map maps a global expert id
        # to its local slot, or -1 when the expert is not present.
        # NOTE(review): randint samples with replacement, so e_ids may repeat;
        # test_fused_marlin_moe uses randperm instead - confirm this is intended.
        local_e = e // ep_size
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    #
    # Setup test functions
    #
    quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

    m_fused_moe_fn = modular_triton_fused_moe(quant_config)

    def m_fused_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        score: torch.Tensor,
        topk: int,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Adapter: the modular kernel takes precomputed top-k weights/ids
        # rather than raw router scores.
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        return m_fused_moe_fn(
            a,
            w1,
            w2,
            topk_weights,
            topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    fused_moe_fn = functools.partial(fused_moe, renormalize=False)

    #
    # Run tests
    #
    runner = functools.partial(
        run_moe_test,
        a=a,
        w1=w1,
        w2=w2,
        score=score,
        topk=topk,
        global_num_experts=e,
        expert_map=e_map,
        padding=padding,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()

    with set_current_vllm_config(vllm_config):
        baseline_output = runner(torch_moe, iterative_moe)
        runner(
            baseline_output,
            fused_moe_fn,
            use_compile=use_compile,
            use_cudagraph=use_cudagraph,
        )
        runner(
            baseline_output,
            m_fused_moe,
            use_compile=use_compile,
            use_cudagraph=use_cudagraph,
        )
267
268


269
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    has_zp: bool,
    weight_bits: int,
):
    """Compare weight-only int4/int8 fused MoE against the torch reference.

    Each expert weight is quantized with `quantize_weights`; the dequantized
    weights feed `torch_moe` and the packed weights feed `fused_moe`.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    # Everything that depends on the bit width is selected in one place so an
    # unexpected value fails loudly instead of leaving names unbound.
    if weight_bits == 4:
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
        quant_config_builder = int4_w4a16_moe_quant_config
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
        quant_config_builder = int8_w8a16_moe_quant_config
    else:
        raise ValueError(f"unsupported weight_bits: {weight_bits}")

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty(
        (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
    )
    w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
    w1_qzeros = torch.empty(
        (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
    )
    w2_qzeros = torch.empty(
        (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
    )

    # Quantize all experts of w1 first, then all experts of w2, filling the
    # preallocated reference / packed buffers in place.
    for w, w_ref, w_qweight, w_scales, w_qzeros in (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    ):
        for expert_id in range(e):
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False
            )
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)
            if weight_bits == 4:
                # Pack two 4-bit values per byte: odd index in the high
                # nibble, even index in the low nibble.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    if ep_size > 1:
        # Keep only a subset of experts locally; e_map maps global expert id
        # to local slot, or -1 when the expert is not present.
        local_e = e // ep_size
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1_ref = w1_ref[e_ids]
        w2_ref = w2_ref[e_ids]
        w1_qweight = w1_qweight[e_ids]
        w2_qweight = w2_qweight[e_ids]
        w1_scales = w1_scales[e_ids]
        w2_scales = w2_scales[e_ids]
        w1_qzeros = w1_qzeros[e_ids]
        w2_qzeros = w2_qzeros[e_ids]
    else:
        e_map = None

    quant_config = quant_config_builder(
        w1_scale=w1_scales,
        w2_scale=w2_scales,
        w1_zp=w1_qzeros if has_zp else None,
        w2_zp=w2_qzeros if has_zp else None,
        block_shape=[0, group_size],
    )

    with set_current_vllm_config(vllm_config):
        triton_output = fused_moe(
            a,
            w1_qweight,
            w2_qweight,
            score,
            topk,
            renormalize=False,
            global_num_experts=e,
            expert_map=e_map,
            quant_config=quant_config,
        )
        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


400
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@torch.inference_mode()
def test_mixtral_moe(
    dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""

    # clear the cache before every test
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled,
    )

    is_rocm_aiter_moe_enabled.cache_clear()
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

        if dtype == torch.float32:
            pytest.skip("AITER ROCm test skip for float32")

    # Single-process distributed environment for init_distributed_environment.
    monkeypatch.setenv("RANK", "0")
    monkeypatch.setenv("LOCAL_RANK", "0")
    monkeypatch.setenv("WORLD_SIZE", "1")
    monkeypatch.setenv("MASTER_ADDR", "localhost")
    monkeypatch.setenv("MASTER_PORT", "12345")
    init_distributed_environment()

    # Instantiate our and huggingface's MoE blocks
    vllm_config.compilation_config.static_forward_context = dict()
    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
        config = MixtralConfig()
        hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
        vllm_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            params_dtype=dtype,
            tp_size=1,
            dp_size=1,
        ).cuda()

        # Load the weights: vLLM stores w1 and w3 concatenated along dim 0
        # in w13_weight.
        vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
        for i in range(config.num_local_experts):
            weights = (
                hf_moe.experts[i].w1.weight.data,
                hf_moe.experts[i].w3.weight.data,
            )
            vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

        # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
        hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
        # vLLM uses 1D query [num_tokens, hidden_dim]
        vllm_inputs = hf_inputs.flatten(0, 1)

        # Pad the weight if moe padding is enabled
        if padding:
            vllm_moe.experts.w13_weight = Parameter(
                F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[
                    ..., 0:-128
                ],
                requires_grad=False,
            )
            vllm_moe.experts.w2_weight = Parameter(
                F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
                requires_grad=False,
            )
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Run forward passes for both MoE blocks
        hf_states, _ = hf_moe.forward(hf_inputs)
        vllm_states = vllm_moe.forward(vllm_inputs)

    # Per-dtype tolerance, used as both rtol and atol on the non-AITER path.
    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    if use_rocm_aiter:
        # The values of rtol and atol are set based on the tests in ROCM AITER package.
        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174
        torch.testing.assert_close(
            hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100
        )
    else:
        torch.testing.assert_close(
            hf_states.flatten(0, 1),
            vllm_states,
            rtol=mixtral_moe_tol[dtype],
            atol=mixtral_moe_tol[dtype],
        )
499
500


501
502
def marlin_moe_generate_valid_test_cases():
    """Build the parameter tuples used to parametrize test_fused_marlin_moe.

    Takes the cartesian product of all parameter lists below and keeps only
    the combinations accepted by the `is_valid` filter.

    Returns:
        A list of tuples `(m, n, k, e, topk, ep_size, dtype, group_size,
        act_order, quant_type, is_k_full)`.
    """
    import itertools

    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [-1, 16, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(
        m_list,
        n_list,
        k_list,
        e_list,
        topk_list,
        ep_size_list,
        dtype_list,
        group_size_list,
        act_order_list,
        quant_type_list,
        is_k_full_list,
    )

    def is_valid(
        m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
    ) -> bool:
        # Returns True when the combination should be kept as a test case.

        # fp8 is only exercised with group_size -1 or 128.
        if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
            return False
        # fp4 is only exercised with group_size 16 or 32, and group_size 32
        # is not combined with float16.
        if quant_type == scalar_types.float4_e2m1f:
            if group_size not in [16, 32]:
                return False
            if dtype == torch.float16 and group_size == 32:
                return False
        # group_size 16 is reserved for fp4.
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
            return False

        # Filter act_order
        if act_order:
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            # is_k_full=False is only kept together with act_order.
            return False

        return True

    return [case for case in all_combinations if is_valid(*case)]


567
def _marlin_quantize_experts(
    w: torch.Tensor,
    perm_size: int,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    has_zp: bool,
):
    """Quantize every expert in `w` for Marlin and stack the results.

    The quantization routine is chosen from `quant_type` / `has_zp`:
    fp4 (nvfp4-like for group_size 16, mxfp4-like otherwise), fp8, AWQ-style
    with zero points, or `marlin_quantize` with a random permutation of
    length `perm_size`.

    Returns:
        A tuple `(w_ref, qweight, scales, global_scale, g_idx, zeros,
        sort_indices)`. Entries a given path does not produce are None.
    """
    w_ref_l = []
    qweight_l = []
    scales_l = []
    global_scale_l = []
    zeros_l = []
    g_idx_l = []
    sort_indices_l = []

    for i in range(w.shape[0]):
        if quant_type == scalar_types.float4_e2m1f:
            if group_size == 16:
                w_ref, qweight, scales, global_scale = rand_marlin_weight_nvfp4_like(
                    w[i], group_size
                )
            else:
                w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                    w[i], group_size
                )
                global_scale = None
            if global_scale is not None:
                global_scale_l.append(global_scale)
        elif quant_type == scalar_types.float8_e4m3fn:
            w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
        elif has_zp:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size
            )
            zeros_l.append(zeros)
        else:
            test_perm = torch.randperm(perm_size)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
            )
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)

        # Every path contributes a reference weight, packed weight and scales.
        w_ref_l.append(w_ref.T)
        qweight_l.append(qweight)
        scales_l.append(scales)

    return (
        stack_and_dev(w_ref_l),
        stack_and_dev(qweight_l).contiguous(),
        stack_and_dev(scales_l),
        stack_and_dev(global_scale_l) if global_scale_l else None,
        stack_and_dev(g_idx_l) if g_idx_l else None,
        stack_and_dev(zeros_l) if zeros_l else None,
        stack_and_dev(sort_indices_l) if sort_indices_l else None,
    )


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
    marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    """Compare `fused_marlin_moe` against `torch_moe` on quantized weights."""
    torch.cuda.manual_seed(0)
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    if ep_size > 1:
        # Keep a random subset of distinct experts locally; e_map maps global
        # expert id to local slot, or -1 when the expert is not present.
        local_e = e // ep_size
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    # w1 is quantized before w2, matching the original statement order.
    (
        w_ref1,
        qweight1,
        scales1,
        global_scale1,
        g_idx1,
        zeros1,
        sort_indices1,
    ) = _marlin_quantize_experts(w1, k, quant_type, group_size, act_order, has_zp)

    (
        w_ref2,
        qweight2,
        scales2,
        global_scale2,
        g_idx2,
        zeros2,
        sort_indices2,
    ) = _marlin_quantize_experts(w2, n, quant_type, group_size, act_order, has_zp)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)

    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        None,
        None,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=e_map,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
    """Compare `fused_marlin_moe` with per-expert biases against `torch_moe`.

    Uses a fixed uint4b8 / group_size 128 configuration without act_order.
    """
    torch.cuda.manual_seed(0)

    e, topk = 32, 4
    n, k = 2048, 2048
    group_size = 128
    act_order = False
    is_k_full = True
    quant_type = scalar_types.uint4b8
    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
    b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10

    def quantize_experts(w: torch.Tensor, bias: torch.Tensor, perm_size: int):
        # Quantize each expert with marlin_quantize, permute its bias with
        # marlin_permute_bias, and stack the per-expert results.
        w_ref_l = []
        qweight_l = []
        scales_l = []
        g_idx_l = []
        sort_indices_l = []
        bias_l = []

        for i in range(w.shape[0]):
            test_perm = torch.randperm(perm_size)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
            )

            w_ref_l.append(w_ref.T)
            qweight_l.append(qweight)
            scales_l.append(scales)
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)
            bias_l.append(marlin_permute_bias(bias[i]))

        return (
            stack_and_dev(w_ref_l),
            stack_and_dev(qweight_l).contiguous(),
            stack_and_dev(scales_l),
            stack_and_dev(g_idx_l) if g_idx_l else None,
            stack_and_dev(sort_indices_l) if sort_indices_l else None,
            stack_and_dev(bias_l) if bias_l else None,
        )

    # w1 is quantized before w2, matching the original statement order.
    w_ref1, qweight1, scales1, g_idx1, sort_indices1, marlin_bias1 = quantize_experts(
        w1, b_bias1, k
    )
    global_scale1 = None
    zeros1 = None

    w_ref2, qweight2, scales2, g_idx2, sort_indices2, marlin_bias2 = quantize_experts(
        w2, b_bias2, n
    )
    global_scale2 = None
    zeros2 = None

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        # The reference takes the original (unpermuted) biases.
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)

    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        marlin_bias1,
        marlin_bias2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
867
868
869
870
871


def test_moe_align_block_size_opcheck():
    """Run `opcheck` on the `_moe_C.moe_align_block_size` custom op."""
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")

    # Upper bound used to size the output buffers: every expert may need up
    # to (block_size - 1) extra slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

    opcheck(
        torch.ops._moe_C.moe_align_block_size,
        (
            topk_ids,
            num_experts,
            block_size,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )
bnellnm's avatar
bnellnm committed
896
897


898
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    """Check `_moe_C.moe_sum` against `Tensor.sum(dim=1)` and run opcheck."""
    # Named `inp` rather than `input` to avoid shadowing the builtin.
    inp = torch.randn((m, topk, k), device="cuda", dtype=dtype)
    actual = torch.empty((m, k), device="cuda", dtype=dtype)

    expected = inp.sum(dim=1)
    # The op writes its result into `actual`.
    torch.ops._moe_C.moe_sum(inp, actual)

    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)

    opcheck(torch.ops._moe_C.moe_sum, (inp, actual))
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981


@pytest.mark.parametrize("m", [1, 33])
@pytest.mark.parametrize("n,k", [(128, 128)])
@pytest.mark.parametrize("e", [8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("activation", ["silu"])
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
    """Compare `CPUFusedMOE` against `torch_moe`, with and without biases."""
    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE

    device = "cpu"
    torch.manual_seed(7)

    def randn(*shape):
        # Random tensor on the test device/dtype; draw order matters because
        # of the fixed seed.
        return torch.randn(shape, device=device, dtype=dtype)

    a = randn(m, k) / 10
    w13 = randn(e, 2 * n, k) / 10
    w2 = randn(e, k, n) / 10
    router_logits = randn(m, e)

    if with_bias:
        b1 = randn(e, 2 * n) / 10
        b2 = randn(e, k) / 10
        ref = torch_moe(a, w13, w2, router_logits, topk, b1, b2)
    else:
        b1 = b2 = None
        ref = torch_moe(a, w13, w2, router_logits, topk)

    class _Dummy(torch.nn.Module):
        # Minimal layer exposing the weight (and optional bias) parameters
        # that are handed to CPUFusedMOE.
        def __init__(self, w13, w2, b1=None, b2=None):
            super().__init__()
            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
            for attr, bias in (("w13_bias", b1), ("w2_bias", b2)):
                if bias is not None:
                    setattr(self, attr, torch.nn.Parameter(bias, requires_grad=False))

    layer = _Dummy(w13, w2, b1, b2).to(dtype)
    fused = CPUFusedMOE(layer)
    out = fused(
        layer=layer,
        x=a,
        use_grouped_topk=False,
        top_k=topk,
        router_logits=router_logits,
        renormalize=False,
        global_num_experts=e,
        expert_map=None,
        custom_routing_function=None,
        scoring_func="softmax",
        routed_scaling_factor=1.0,
        e_score_correction_bias=None,
        apply_router_weight_on_input=False,
        activation=activation,
    )

    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
    if dtype == torch.float32:
        atol = 1e-3
    else:
        atol = 8e-2 if with_bias else 5e-2
    torch.testing.assert_close(out, ref, atol=atol, rtol=0)