"tests/fault_tolerance/deploy/container/Dockerfile.local_vllm" did not exist on "e01c6e99bd5d9d0edf32560e274250d7c32e560f"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""

import functools
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import pytest
import torch
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    batched_fused_marlin_moe,
    fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk,
    modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as iterative_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_permute_bias,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    awq_marlin_quantize,
    marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
60

61
# Parametrization axes shared by the fused-MoE tests in this module.
# Global expert counts ("e").
NUM_EXPERTS = [8, 64, 192]
# Expert-parallel sizes ("ep_size"); a value > 1 makes the tests keep only
# e // ep_size local experts and build an expert map.
EP_SIZE = [1, 4]
# Number of experts selected per token ("topk").
TOP_KS = [2, 6]
64

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# (m, n, k) problem sizes used to parametrize `test_fused_moe`.
# In that test `a` is (m, k), `w1` is (e, 2 * n, k) and `w2` is (e, k, n).
FUSED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (1, 2048, 128),
    (33, 2048, 128),
    (222, 1024, 1024),
    (32768, 128, 128),
    (32768, 2048, 511),
    (40000, 1024, 1024),
]

# (m, n, k) problem sizes used to parametrize `test_fused_moe_wn16`
# (the weight-only int4/int8 quantized variant).
FUSED_MOE_WN16_MNK_FACTORS = [
    (1, 128, 128),
    (1, 1024, 1024),
    (32, 2048, 128),
    (32, 1024, 1024),
    (222, 2048, 1024),
]

83
84
85
86
# Module-level config shared by the tests below; they activate it with
# `set_current_vllm_config(vllm_config)` around the kernels under test.
vllm_config = VllmConfig()
# Override the scheduler limits on the shared config.
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

87

88
def run_moe_test(
89
    baseline: Callable | torch.Tensor,
90
91
92
93
94
95
96
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
97
    expert_map: torch.Tensor | None = None,
98
99
100
101
102
103
104
105
106
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    if isinstance(baseline, torch.Tensor):
        baseline_output = baseline
    else:
107
108
109
110
111
112
113
114
115
        baseline_output = baseline(
            a,
            w1,
            w2,
            score,
            topk,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )
116
117
118
119
120
121
122
123
124
125
126

    # Pad the weight if moe padding is enabled
    if padding:
        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]

    if use_compile:
        moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

127
128
129
130
131
132
133
134
135
    test_output = moe_fn(
        a,
        w1,
        w2,
        score,
        topk,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )
136
137
138
139
140
141

    if use_cudagraph:
        test_output.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
142
143
144
145
146
147
148
149
150
            test_output = moe_fn(
                a,
                w1,
                w2,
                score,
                topk,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )
151
152
153
154
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

155
    torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)
156
157
158
159

    return baseline_output


160
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
    chunk_size: int,
    monkeypatch,
):
    """Compare the fused and modular Triton MoE paths against `torch_moe`.

    The reference output is produced by `torch_moe` (and cross-checked
    against `iterative_moe`), then reused as the baseline for `fused_moe`
    and for the kernel built by `modular_triton_fused_moe`.
    """
    current_platform.seed_everything(7)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    #
    # Setup test data
    #
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if ep_size > 1:
        # Keep only a subset of experts locally; e_map maps a global expert
        # id to its local slot, or -1 when the expert is not local.
        local_e = e // ep_size
        # NOTE(review): randint can repeat ids, whereas test_fused_marlin_moe
        # draws distinct ids with randperm -- confirm this is intended.
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    #
    # Setup test functions
    #
    quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

    m_fused_moe_fn = modular_triton_fused_moe(quant_config)

    def m_fused_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        score: torch.Tensor,
        topk: int,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Adapt the modular kernel (which takes precomputed top-k weights and
        # ids) to the calling convention expected by run_moe_test.
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        return m_fused_moe_fn(
            a,
            w1,
            w2,
            topk_weights,
            topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    fused_moe_fn = functools.partial(fused_moe, renormalize=False)

    #
    # Run tests
    #
    runner = functools.partial(
        run_moe_test,
        a=a,
        w1=w1,
        w2=w2,
        score=score,
        topk=topk,
        global_num_experts=e,
        expert_map=e_map,
        padding=padding,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()

    with set_current_vllm_config(vllm_config):
        baseline_output = runner(torch_moe, iterative_moe)
        runner(
            baseline_output,
            fused_moe_fn,
            use_compile=use_compile,
            use_cudagraph=use_cudagraph,
        )
        runner(
            baseline_output,
            m_fused_moe,
            use_compile=use_compile,
            use_cudagraph=use_cudagraph,
        )
272
273


274
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    has_zp: bool,
    weight_bits: int,
):
    """Check `fused_moe` with int4/int8 weight-only quantized experts.

    Each expert's weights are quantized with `quantize_weights`; the
    dequantized weights it returns feed the `torch_moe` reference while the
    packed weights, scales and (optional) zero points feed `fused_moe`.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit values are packed into each uint8.
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty(
        (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
    )
    w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
    w1_qzeros = torch.empty(
        (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
    )
    w2_qzeros = torch.empty(
        (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
    )

    # Quantize every expert of w1 first, then every expert of w2.
    weight_sets = (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    )
    for w, w_ref, w_qweight, w_scales, w_qzeros in weight_sets:
        for expert_id in range(e):
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False
            )
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)
            if weight_bits == 4:
                # Pack pairs of 4-bit values into one byte: odd index goes to
                # the high nibble, even index to the low nibble.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    if ep_size > 1:
        # Keep only a subset of experts locally; e_map maps a global expert
        # id to its local slot, or -1 when the expert is not local.
        local_e = e // ep_size
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1_ref = w1_ref[e_ids]
        w2_ref = w2_ref[e_ids]
        w1_qweight = w1_qweight[e_ids]
        w2_qweight = w2_qweight[e_ids]
        w1_scales = w1_scales[e_ids]
        w2_scales = w2_scales[e_ids]
        w1_qzeros = w1_qzeros[e_ids]
        w2_qzeros = w2_qzeros[e_ids]
    else:
        e_map = None

    if weight_bits == 4:
        quant_config_builder = int4_w4a16_moe_quant_config
    else:
        assert weight_bits == 8
        quant_config_builder = int8_w8a16_moe_quant_config

    quant_config = quant_config_builder(
        w1_scale=w1_scales,
        w2_scale=w2_scales,
        w1_zp=w1_qzeros if has_zp else None,
        w2_zp=w2_qzeros if has_zp else None,
        block_shape=[0, group_size],
    )

    with set_current_vllm_config(vllm_config):
        triton_output = fused_moe(
            a,
            w1_qweight,
            w2_qweight,
            score,
            topk,
            renormalize=False,
            global_num_experts=e,
            expert_map=e_map,
            quant_config=quant_config,
        )
        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


405
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@torch.inference_mode()
def test_mixtral_moe(
    dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""

    # clear the cache before every test
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled,
    )

    is_rocm_aiter_moe_enabled.cache_clear()
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

        if dtype == torch.float32:
            pytest.skip("AITER ROCm test skip for float32")

    # Single-process distributed environment for this test.
    monkeypatch.setenv("RANK", "0")
    monkeypatch.setenv("LOCAL_RANK", "0")
    monkeypatch.setenv("WORLD_SIZE", "1")
    monkeypatch.setenv("MASTER_ADDR", "localhost")
    monkeypatch.setenv("MASTER_PORT", "12345")
    init_distributed_environment()

    # Instantiate our and huggingface's MoE blocks
    vllm_config.compilation_config.static_forward_context = dict()
    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
        config = MixtralConfig()
        hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
        vllm_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            params_dtype=dtype,
            tp_size=1,
            dp_size=1,
        ).cuda()

        # Load the weights
        vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
        for i in range(config.num_local_experts):
            # vLLM stores w1 and w3 concatenated along dim 0 in w13_weight.
            weights = (
                hf_moe.experts[i].w1.weight.data,
                hf_moe.experts[i].w3.weight.data,
            )
            vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

        # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
        hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
        # vLLM uses 1D query [num_tokens, hidden_dim]
        vllm_inputs = hf_inputs.flatten(0, 1)

        # Pad the weight if moe padding is enabled
        if padding:
            vllm_moe.experts.w13_weight = Parameter(
                F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[
                    ..., 0:-128
                ],
                requires_grad=False,
            )
            vllm_moe.experts.w2_weight = Parameter(
                F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
                requires_grad=False,
            )
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Run forward passes for both MoE blocks
        hf_states, _ = hf_moe.forward(hf_inputs)
        vllm_states = vllm_moe.forward(vllm_inputs)

    # Per-dtype tolerance, used for both rtol and atol on the non-AITER path.
    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    if use_rocm_aiter:
        # The values of rtol and atol are set based on the tests in ROCM AITER package.
        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174
        torch.testing.assert_close(
            hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100
        )
    else:
        torch.testing.assert_close(
            hf_states.flatten(0, 1),
            vllm_states,
            rtol=mixtral_moe_tol[dtype],
            atol=mixtral_moe_tol[dtype],
        )
504
505


506
507
def marlin_moe_generate_valid_test_cases():
    """Build the parameter tuples used to parametrize `test_fused_marlin_moe`.

    Returns:
        A list of tuples
        `(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
        quant_type, is_k_full)` drawn from the Cartesian product of the
        lists below, keeping only the combinations accepted by `is_valid`.
    """
    import itertools

    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [-1, 16, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(
        m_list,
        n_list,
        k_list,
        e_list,
        topk_list,
        ep_size_list,
        dtype_list,
        group_size_list,
        act_order_list,
        quant_type_list,
        is_k_full_list,
    )

    def is_valid(
        m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
    ):
        # Return True when the combination should be tested.
        # FP8 is only exercised with group_size -1 or 128.
        if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
            return False
        # FP4 is only exercised with group_size 16 or 32, and group_size 32
        # is not combined with float16.
        if quant_type == scalar_types.float4_e2m1f:
            if group_size not in [16, 32]:
                return False
            if dtype == torch.float16 and group_size == 32:
                return False
        # group_size 16 is reserved for FP4.
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
            return False

        # Filter act_order
        if act_order:
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            # is_k_full=False is only exercised together with act_order.
            return False

        return True

    return [case for case in all_combinations if is_valid(*case)]


572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
@dataclass
class MarlinMoEWeightData:
    """Per-expert Marlin-quantized weights plus their reference weights.

    Instances are built with `make`, which quantizes each expert of a 3-D
    weight tensor and stacks the per-expert results with `stack_and_dev`.
    Optional fields are None when the selected quantization path does not
    produce them.
    """

    # Stacked reference weights (each expert's `w_ref` is transposed before
    # stacking); the tests pass this to `torch_moe`.
    w_ref: torch.Tensor
    # Stacked quantized weights, made contiguous.
    qweight: torch.Tensor
    # Stacked quantization scales.
    scales: torch.Tensor
    # Only populated on the float4_e2m1f path with group_size == 16.
    global_scale: torch.Tensor | None
    # Only populated on the `marlin_quantize` path.
    g_idx: torch.Tensor | None
    # Only populated on the zero-point (`awq_marlin_quantize`) path.
    zeros: torch.Tensor | None
    # Only populated on the `marlin_quantize` path.
    sort_indices: torch.Tensor | None
    # Only populated when a `bias` is passed to `make`.
    marlin_bias: torch.Tensor | None

    @staticmethod
    def make(
        w: torch.Tensor,
        quant_type: ScalarType,
        group_size: int,
        act_order: bool | None = None,
        bias: torch.Tensor | None = None,
    ) -> "MarlinMoEWeightData":
        """Quantize every expert in `w` and collect the stacked results.

        Args:
            w: 3-D weight tensor; dim 0 is iterated as the expert dimension.
            quant_type: Selects the quantization helper (FP4, FP8,
                zero-point uint4/uint8, or the default `marlin_quantize`).
            group_size: Forwarded to the quantization helper; on the FP4
                path, 16 selects the nvfp4-like helper and any other value
                the mxfp4-like helper.
            act_order: Forwarded to `marlin_quantize` only.
            bias: Optional per-expert bias; each `bias[i]` is passed through
                `marlin_permute_bias`.

        Returns:
            A `MarlinMoEWeightData` holding the stacked tensors.
        """
        assert w.ndim == 3
        has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
        k = w.shape[-1]

        # Per-expert accumulators; lists that stay empty become None below.
        w_ref_l: list[torch.Tensor] = []
        qweight_l: list[torch.Tensor] = []
        scales_l: list[torch.Tensor] = []
        global_scale_l: list[torch.Tensor] = []
        zeros_l: list[torch.Tensor] = []
        g_idx_l: list[torch.Tensor] = []
        sort_indices_l: list[torch.Tensor] = []
        bias_l: list[torch.Tensor] = []

        for i in range(w.shape[0]):
            if quant_type == scalar_types.float4_e2m1f:
                if group_size == 16:
                    w_ref, qweight, scales, global_scale = (
                        rand_marlin_weight_nvfp4_like(w[i], group_size)
                    )
                else:
                    w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                        w[i], group_size
                    )
                    global_scale = None

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                if global_scale is not None:
                    global_scale_l.append(global_scale)
            elif quant_type == scalar_types.float8_e4m3fn:
                w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
            elif has_zp:
                # Zero-point types: note the weight is passed transposed here,
                # unlike the FP4/FP8 branches above.
                w_ref, qweight, scales, zeros = awq_marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size
                )

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                zeros_l.append(zeros)
            else:
                # A fresh random permutation of size k is drawn per expert and
                # handed to marlin_quantize together with act_order.
                test_perm = torch.randperm(k)
                w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
                )

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                g_idx_l.append(g_idx)
                sort_indices_l.append(sort_indices)

            if bias is not None:
                bias_l.append(marlin_permute_bias(bias[i]))

        # presumably stack_and_dev stacks the list and moves it to the test
        # device -- it is imported from tests.kernels.utils; confirm there.
        w_ref = stack_and_dev(w_ref_l)
        qweight = stack_and_dev(qweight_l).contiguous()
        scales = stack_and_dev(scales_l)
        global_scale = stack_and_dev(global_scale_l) if global_scale_l else None
        g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
        zeros = stack_and_dev(zeros_l) if zeros_l else None
        sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
        marlin_bias = stack_and_dev(bias_l) if bias_l else None

        return MarlinMoEWeightData(
            w_ref=w_ref,
            qweight=qweight,
            scales=scales,
            global_scale=global_scale,
            g_idx=g_idx,
            zeros=zeros,
            sort_indices=sort_indices,
            marlin_bias=marlin_bias,
        )


671
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
    marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    """Compare `fused_marlin_moe` against `torch_moe` on reference weights.

    The weights are quantized through `MarlinMoEWeightData.make`; its
    `w_ref` tensors feed the reference while the packed tensors feed the
    Marlin kernel.
    """
    torch.cuda.manual_seed(0)

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    if ep_size > 1:
        # Keep a random subset of distinct experts locally; e_map maps a
        # global expert id to its local slot, or -1 when it is not local.
        local_e = e // ep_size
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    w1_data = MarlinMoEWeightData.make(
        w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
    )

    w2_data = MarlinMoEWeightData.make(
        w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(
            a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
        )

    marlin_output = fused_marlin_moe(
        a,
        w1_data.qweight,
        w2_data.qweight,
        # No bias tensors in this test.
        None,
        None,
        w1_data.scales,
        w2_data.scales,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=e_map,
        global_scale1=w1_data.global_scale,
        global_scale2=w2_data.global_scale,
        g_idx1=w1_data.g_idx,
        g_idx2=w2_data.g_idx,
        sort_indices1=w1_data.sort_indices,
        sort_indices2=w2_data.sort_indices,
        w1_zeros=w1_data.zeros,
        w2_zeros=w2_data.zeros,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
    """Compare `fused_marlin_moe` with per-expert biases against `torch_moe`.

    Uses a fixed uint4b8 / group_size 128 / float16 configuration; only the
    number of tokens `m` is parametrized.
    """
    torch.cuda.manual_seed(0)

    e, topk = 32, 4
    n, k = 2048, 2048
    group_size = 128
    act_order = False
    is_k_full = True
    quant_type = scalar_types.uint4b8
    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
    b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10

    w1_data = MarlinMoEWeightData.make(
        w=w1,
        quant_type=quant_type,
        group_size=group_size,
        act_order=act_order,
        bias=b_bias1,
    )

    w2_data = MarlinMoEWeightData.make(
        w=w2,
        quant_type=quant_type,
        group_size=group_size,
        act_order=act_order,
        bias=b_bias2,
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        # The reference takes the raw biases; the Marlin kernel below takes
        # the permuted ones stored in marlin_bias.
        torch_output = torch_moe(
            a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2
        )

    marlin_output = fused_marlin_moe(
        a,
        w1_data.qweight,
        w2_data.qweight,
        w1_data.marlin_bias,
        w2_data.marlin_bias,
        w1_data.scales,
        w2_data.scales,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=w1_data.global_scale,
        global_scale2=w2_data.global_scale,
        g_idx1=w1_data.g_idx,
        g_idx2=w2_data.g_idx,
        sort_indices1=w1_data.sort_indices,
        sort_indices2=w2_data.sort_indices,
        w1_zeros=w1_data.zeros,
        w2_zeros=w2_data.zeros,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
822
823
824
825
826


def test_moe_align_block_size_opcheck():
    """Run `opcheck` on the `_moe_C.moe_align_block_size` custom op."""
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")

    # Upper bound on the padded length: every expert may need up to
    # block_size - 1 padding slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # Pre-fill with topk_ids.numel(), a value one past the last valid index.
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

    opcheck(
        torch.ops._moe_C.moe_align_block_size,
        (
            topk_ids,
            num_experts,
            block_size,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )
851
852


853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
def test_batched_moe_align_block_size_opcheck():
    """Run `opcheck` on the `_moe_C.batched_moe_align_block_size` custom op."""
    device = "cuda"
    num_experts = 4
    block_size = 16
    max_tokens_per_batch = 512

    # Random token count per expert, drawn from [0, max_tokens_per_batch).
    expert_num_tokens = torch.randint(
        0, max_tokens_per_batch, (num_experts,), dtype=torch.int32, device=device
    )

    # Each expert gets max(max_tokens_per_batch, block_size) slots.
    padded_len = num_experts * max(max_tokens_per_batch, block_size)
    sorted_ids = torch.empty((padded_len,), dtype=torch.int32, device=device)

    assert padded_len % block_size == 0
    num_m_blocks = padded_len // block_size
    expert_ids = torch.empty((num_m_blocks,), dtype=torch.int32, device=device)

    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=device)

    op_args = (
        max_tokens_per_batch,
        block_size,
        expert_num_tokens,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
    )
    opcheck(torch.ops._moe_C.batched_moe_align_block_size, op_args)


888
@pytest.mark.parametrize("m", [1, 33, 64, 222])
bnellnm's avatar
bnellnm committed
889
890
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
891
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
bnellnm's avatar
bnellnm committed
892
893
894
895
896
897
898
899
900
901
902
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    """Check the ``moe_sum`` op against ``Tensor.sum(dim=1)``, then opcheck it.

    The op reads an ``(m, topk, k)`` tensor and writes into a preallocated
    ``(m, k)`` tensor.
    """
    device = "cuda"
    per_expert = torch.randn((m, topk, k), device=device, dtype=dtype)
    reduced = torch.empty((m, k), device=device, dtype=dtype)

    reference = per_expert.sum(dim=1)
    torch.ops._moe_C.moe_sum(per_expert, reduced)

    # Absolute tolerance only; rtol is disabled.
    torch.testing.assert_close(reduced, reference, atol=2e-2, rtol=0)

    opcheck(torch.ops._moe_C.moe_sum, (per_expert, reduced))
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971


@pytest.mark.parametrize("m", [1, 33])
@pytest.mark.parametrize("n,k", [(128, 128)])
@pytest.mark.parametrize("e", [8])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("activation", ["silu"])
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
    """Compare ``CPUFusedMOE`` with the ``torch_moe`` reference on CPU.

    Random activations, expert weights, router logits and (optionally) biases
    are generated from a fixed seed, run through both implementations, and
    compared with a dtype-dependent absolute tolerance.
    """
    from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE

    torch.manual_seed(7)
    cpu_opts = {"device": "cpu", "dtype": dtype}

    # NOTE: the order of the randn calls fixes the generated values.
    a = torch.randn((m, k), **cpu_opts) / 10
    w13 = torch.randn((e, 2 * n, k), **cpu_opts) / 10
    w2 = torch.randn((e, k, n), **cpu_opts) / 10
    router_logits = torch.randn((m, e), **cpu_opts)

    if with_bias:
        b1 = torch.randn((e, 2 * n), **cpu_opts) / 10
        b2 = torch.randn((e, k), **cpu_opts) / 10
    else:
        b1 = None
        b2 = None

    # Biases are passed to the reference only when they exist.
    ref_args = (a, w13, w2, router_logits, topk)
    if with_bias:
        ref_args += (b1, b2)
    ref = torch_moe(*ref_args)

    class _WeightHolder(torch.nn.Module):
        """Minimal module exposing the weight/bias attributes set below."""

        def __init__(self, w13, w2, b1=None, b2=None):
            super().__init__()
            self.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
            self.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
            # Bias attributes exist only when a bias tensor was supplied.
            for attr, bias in (("w13_bias", b1), ("w2_bias", b2)):
                if bias is not None:
                    setattr(self, attr, torch.nn.Parameter(bias, requires_grad=False))

    layer = _WeightHolder(w13, w2, b1, b2).to(dtype)
    fused = CPUFusedMOE(layer)
    call_kwargs = dict(
        layer=layer,
        x=a,
        use_grouped_topk=False,
        top_k=topk,
        router_logits=router_logits,
        renormalize=False,
        global_num_experts=e,
        expert_map=None,
        custom_routing_function=None,
        scoring_func="softmax",
        routed_scaling_factor=1.0,
        e_score_correction_bias=None,
        apply_router_weight_on_input=False,
        activation=activation,
    )
    out = fused(**call_kwargs)

    # Tolerances: fp32 tight; bf16 looser (esp. with bias)
    atol = 1e-3 if dtype == torch.float32 else (8e-2 if with_bias else 5e-2)
    torch.testing.assert_close(out, ref, atol=atol, rtol=0)
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139


@pytest.mark.parametrize("m", [16, 32, 64])
@pytest.mark.parametrize("n", [128])
@pytest.mark.parametrize("k", [128])
@pytest.mark.parametrize("e", [8, 12, 16, 32])
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_batched_fused_marlin_moe(
    m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int
):
    """Compare ``batched_fused_marlin_moe`` against ``fused_marlin_moe``.

    The same ``float4_e2m1f`` Marlin weights (group size 32) and the same
    top-k routing are fed to both entry points:

    * the reference path calls ``fused_marlin_moe`` on the flat ``(m, k)``
      activations;
    * the batched path scatters the activations into an
      ``(e, max_tokens_per_batch, k)`` tensor, calls
      ``batched_fused_marlin_moe``, and gathers the per-expert rows back into
      ``(m, k)`` while applying the top-k weights.

    The case is skipped when some expert receives more than
    ``max_tokens_per_batch`` tokens, since the routing then does not fit the
    batched layout.
    """
    print(
        f"testing m={m}, n={n}, k={k}, e={e}, "
        f"topk={topk}, "
        f"max_tokens_per_batch={max_tokens_per_batch}"
    )
    torch.cuda.manual_seed(0)

    dtype = torch.bfloat16
    quant_dtype = scalar_types.float4_e2m1f
    group_size = 32

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    w1_data = MarlinMoEWeightData.make(
        w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None
    )
    w2_data = MarlinMoEWeightData.make(
        w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    class BatchedRun:
        """Host-side scatter/gather harness around ``batched_fused_marlin_moe``.

        Routing data is mirrored on the CPU so the per-expert layout can be
        built with plain Python loops.
        """

        @staticmethod
        def _make_expert_num_tokens_cpu(
            e: int,  # num_experts
            topk_ids_cpu: torch.Tensor,
        ) -> torch.Tensor:
            """Count how many (token, top-k slot) pairs select each expert."""
            expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu")
            for topk_id in torch.flatten(topk_ids_cpu):
                expert_num_tokens_cpu[topk_id] += 1
            return expert_num_tokens_cpu

        def __init__(
            self,
            max_tokens_per_batch: int,
            num_experts: int,
            _topk_ids: torch.Tensor,
            _topk_weights: torch.Tensor,
        ):
            self.max_tokens_per_batch = max_tokens_per_batch
            self.e = num_experts
            self.topk_ids_cpu = _topk_ids.to("cpu")
            self.topk_weights_cpu = _topk_weights.to("cpu")
            self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu(
                self.e, self.topk_ids_cpu
            )

        def is_valid(self) -> bool:
            """
            Return True only if the input can be represented in a Batched
            format, i.e. no expert receives more than
            ``max_tokens_per_batch`` tokens.
            """
            return bool(
                torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch)
            )

        def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor:
            """Lay tokens out per expert as ``(e, max_tokens_per_batch, K)``.

            Each token is copied once per selected expert, at the next free
            row of that expert. Rows past an expert's token count are left
            uninitialized (``torch.empty``).
            """
            hidden_states_cpu = hidden_states.to("cpu")
            K = hidden_states_cpu.size(1)
            # Use the instance's own sizes rather than the enclosing test's
            # variables so the helper is self-contained.
            batched_hidden_states_cpu = torch.empty(
                (self.e, self.max_tokens_per_batch, K),
                dtype=hidden_states_cpu.dtype,
                device="cpu",
            )

            counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu)
            for t_idx, token in enumerate(hidden_states_cpu):
                for topk_id in self.topk_ids_cpu[t_idx]:
                    pos_in_batch = counter_cpu[topk_id]
                    batched_hidden_states_cpu[topk_id, pos_in_batch] = token
                    counter_cpu[topk_id] += 1
            # Integer counters: require exact equality.
            assert torch.equal(counter_cpu, self.expert_num_tokens_cpu)
            return batched_hidden_states_cpu.to("cuda")

        def _gather(
            self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor
        ) -> torch.Tensor:
            """Fold per-expert output rows back into one row per token.

            Tokens and top-k slots are walked in the same order as in
            ``_scatter``, so the running per-expert counter points at the row
            that was written for that (token, expert) pair. Each row is scaled
            by its top-k weight and summed. The result is written into
            ``gather_outputs`` in place and returned.
            """
            batched_outputs_cpu = batched_outputs.to("cpu")
            # Stage on the host explicitly: a plain zeros_like would inherit
            # the device of ``gather_outputs`` and force a host-to-device copy
            # for every token row.
            gather_outputs_cpu = torch.zeros_like(gather_outputs, device="cpu")

            counter_cpu = torch.zeros((self.e,), device="cpu", dtype=torch.int32)
            md = gather_outputs_cpu.size(0)
            for t_idx in range(md):
                token = None
                for topk_id, topk_weight in zip(
                    self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx]
                ):
                    pos_in_batch = counter_cpu[topk_id]
                    t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight
                    if token is None:
                        token = t
                    else:
                        # ``token`` is a fresh product tensor, so accumulating
                        # in place does not touch ``batched_outputs_cpu``.
                        token += t
                    counter_cpu[topk_id] += 1
                assert token is not None
                gather_outputs_cpu[t_idx] = token
            # Single transfer back to the caller's tensor.
            gather_outputs.copy_(gather_outputs_cpu)
            return gather_outputs

        def run(
            self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any]
        ) -> torch.Tensor:
            """Scatter, call ``batched_fused_marlin_moe``, and gather.

            ``fused_marlin_moe_kwargs`` supplies the shared weight/scale
            arguments; ``hidden_states`` and ``expert_num_tokens`` are added
            here. Returns a tensor shaped like ``hidden_states``.
            """
            assert hidden_states.ndim == 2
            assert self.is_valid()

            batched_hidden_states = self._scatter(hidden_states)

            kwargs = fused_marlin_moe_kwargs | {
                "hidden_states": batched_hidden_states,
                "expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"),
            }
            batched_outputs = batched_fused_marlin_moe(**kwargs)

            output = torch.zeros_like(hidden_states)
            output = self._gather(batched_outputs, output)
            return output

    # Arguments shared by the reference and the batched call.
    kwargs = {
        "w1": w1_data.qweight,
        "w2": w2_data.qweight,
        "bias1": None,
        "bias2": None,
        "w1_scale": w1_data.scales,
        "w2_scale": w2_data.scales,
        "gating_output": score,
        "global_num_experts": e,
        "expert_map": None,
        "global_scale1": w1_data.global_scale,
        "global_scale2": w2_data.global_scale,
        "g_idx1": w1_data.g_idx,
        "g_idx2": w2_data.g_idx,
        "sort_indices1": w1_data.sort_indices,
        "sort_indices2": w2_data.sort_indices,
        "w1_zeros": w1_data.zeros,
        "w2_zeros": w2_data.zeros,
        "quant_type_id": quant_dtype.id,
        "is_k_full": True,
    }

    # Reference
    fused_marlin_moe_kwargs = kwargs | {
        "hidden_states": a,
        "topk_ids": topk_ids,
        "topk_weights": topk_weights,
    }
    ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs)

    # Batched
    br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights)
    if not br.is_valid():
        pytest.skip("Cannot represent data in Batched Format.")
    marlin_output = br.run(a, kwargs)

    torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)