test_moe.py 28.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""
7

8
9
10
import functools
import itertools
from typing import Callable, Optional, Union

import pytest
import torch
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.moe.utils import fused_moe
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk,
    modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as iterative_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_permute_bias,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    awq_marlin_quantize,
    marlin_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
54

55
# Parametrization axes shared by the fused-MoE tests in this file.
NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]

# (m, n, k) problem sizes for test_fused_moe: activations are (m, k),
# w1 is (e, 2 * n, k) and w2 is (e, k, n).
FUSED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (1, 2048, 128),
    (33, 2048, 128),
    (222, 1024, 1024),
    (32768, 128, 128),
    (32768, 2048, 511),
    (40000, 1024, 1024),
]

# (m, n, k) problem sizes for the weight-only quantized test_fused_moe_wn16.
FUSED_MOE_WN16_MNK_FACTORS = [
    (1, 128, 128),
    (1, 1024, 1024),
    (32, 2048, 128),
    (32, 1024, 1024),
    (222, 2048, 1024),
]

77
78
79
80
# Module-level VllmConfig shared by the tests below, which activate it with
# set_current_vllm_config(vllm_config). Two scheduler limits are overridden
# on the default-constructed config.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

81

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def run_moe_test(
    baseline: Union[Callable, torch.Tensor],
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    """Run ``moe_fn`` and check its output against a baseline.

    Args:
        baseline: Either a precomputed reference tensor, or a callable with
            the same calling convention as ``moe_fn`` that produces it.
        moe_fn: Implementation under test, called as
            ``moe_fn(a, w1, w2, score, topk, global_num_experts=...,
            expert_map=...)``.
        a, w1, w2, score, topk: Positional arguments forwarded to both
            callables.
        global_num_experts: Forwarded as a keyword to both callables.
        expert_map: Forwarded as a keyword to both callables.
        padding: If true, pad the last dimension of ``w1``/``w2`` by 128 and
            slice the padding off again before calling ``moe_fn`` (the
            baseline is computed with the unpadded weights).
        use_compile: If true, wrap ``moe_fn`` with ``torch.compile`` and mark
            dim 0 of ``a`` and ``score`` as dynamic.
        use_cudagraph: If true, additionally capture ``moe_fn`` in a CUDA
            graph, replay it, and compare the replayed output instead.
        atol: Absolute tolerance for the comparison.
        rtol: Relative tolerance for the comparison.

    Returns:
        The baseline output, so callers can reuse it for further comparisons.

    Raises:
        AssertionError: If the outputs differ beyond the tolerances.
    """
    if isinstance(baseline, torch.Tensor):
        baseline_output = baseline
    else:
        baseline_output = baseline(
            a,
            w1,
            w2,
            score,
            topk,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    # Pad the weight if moe padding is enabled
    if padding:
        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]

    if use_compile:
        moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

    test_output = moe_fn(
        a,
        w1,
        w2,
        score,
        topk,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )

    if use_cudagraph:
        # Zero the eager result first so that a passing comparison below can
        # only come from the values produced by the graph replay.
        test_output.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            test_output = moe_fn(
                a,
                w1,
                w2,
                score,
                topk,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

    torch.testing.assert_close(test_output, baseline_output, atol=atol, rtol=rtol)

    return baseline_output


154
@pytest.mark.parametrize("m,n,k", FUSED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
    chunk_size: int,
    monkeypatch,
):
    """Compare several unquantized MoE implementations against ``torch_moe``.

    ``torch_moe`` produces the baseline. ``iterative_moe``, the ``fused_moe``
    test helper and the kernel built by ``modular_triton_fused_moe`` are each
    checked against it through ``run_moe_test``.
    """
    current_platform.seed_everything(7)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    #
    # Setup test data
    #

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if ep_size > 1:
        # Emulate expert parallelism: draw local_e expert ids, build a map
        # from global expert id to local slot (-1 for experts not held
        # locally) and keep only the selected experts' weights.
        local_e = e // ep_size
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    #
    # Setup test functions
    #
    quant_config = FUSED_MOE_UNQUANTIZED_CONFIG

    m_fused_moe_fn = modular_triton_fused_moe(quant_config)

    def m_fused_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        score: torch.Tensor,
        topk: int,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Adapter giving the modular kernel the calling convention that
        # run_moe_test expects: routing is computed here from the scores.
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        return m_fused_moe_fn(
            a,
            w1,
            w2,
            topk_weights,
            topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    fused_moe_fn = functools.partial(fused_moe, renormalize=False)

    #
    # Run tests
    #
    runner = functools.partial(
        run_moe_test,
        a=a,
        w1=w1,
        w2=w2,
        score=score,
        topk=topk,
        global_num_experts=e,
        expert_map=e_map,
        padding=padding,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = n >= 1024 and k >= 1024 and current_platform.is_cuda_alike()

    with set_current_vllm_config(vllm_config):
        # The first call both validates iterative_moe and yields the
        # torch_moe baseline reused by the remaining comparisons.
        baseline_output = runner(torch_moe, iterative_moe)
        runner(
            baseline_output,
            fused_moe_fn,
            use_compile=use_compile,
            use_cudagraph=use_cudagraph,
        )
        runner(
            baseline_output,
            m_fused_moe,
            use_compile=use_compile,
            use_cudagraph=use_cudagraph,
        )
266
267


268
@pytest.mark.parametrize("m,n,k", FUSED_MOE_WN16_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    has_zp: bool,
    weight_bits: int,
):
    """Check ``fused_moe`` with 4/8-bit weight-only quantization.

    Every expert weight is quantized with ``quantize_weights``. The quantized
    tensors go through ``fused_moe`` and the result is compared with
    ``torch_moe`` run on the dequantized reference weights.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit values are packed into each uint8.
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    else:
        assert weight_bits == 8
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty(
        (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
    )
    w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
    w1_qzeros = torch.empty(
        (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
    )
    w2_qzeros = torch.empty(
        (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
    )

    # Quantize every expert of w1, then every expert of w2, writing the
    # results in place into the preallocated buffers above.
    weight_sets = (
        (w1, w1_ref, w1_qweight, w1_scales, w1_qzeros),
        (w2, w2_ref, w2_qweight, w2_scales, w2_qzeros),
    )
    for w, w_ref, w_qweight, w_scales, w_qzeros in weight_sets:
        for expert_id in range(e):
            weight, qweight, scales, qzeros = quantize_weights(
                w[expert_id].T, quant_type, group_size, has_zp, False
            )
            weight = weight.T
            qweight = qweight.T.contiguous().to(torch.uint8)
            scales = scales.T
            if has_zp:
                qzeros = qzeros.T.contiguous().to(torch.uint8)
            if weight_bits == 4:
                # Pack pairs of 4-bit values into one byte: the odd-indexed
                # element goes in the high nibble, the even one in the low.
                qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
                if has_zp:
                    qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

            w_ref[expert_id] = weight
            w_qweight[expert_id] = qweight
            w_scales[expert_id] = scales
            if has_zp:
                w_qzeros[expert_id] = qzeros

    if ep_size > 1:
        # Emulate expert parallelism: keep only a subset of experts locally
        # and map global expert ids to local slots (-1 means not local).
        local_e = e // ep_size
        e_ids = torch.randint(0, e, (local_e,), device="cuda", dtype=torch.int32)
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1_ref = w1_ref[e_ids]
        w2_ref = w2_ref[e_ids]
        w1_qweight = w1_qweight[e_ids]
        w2_qweight = w2_qweight[e_ids]
        w1_scales = w1_scales[e_ids]
        w2_scales = w2_scales[e_ids]
        w1_qzeros = w1_qzeros[e_ids]
        w2_qzeros = w2_qzeros[e_ids]
    else:
        e_map = None

    if weight_bits == 4:
        quant_config_builder = int4_w4a16_moe_quant_config
    else:
        assert weight_bits == 8
        quant_config_builder = int8_w8a16_moe_quant_config

    quant_config = quant_config_builder(
        w1_scale=w1_scales,
        w2_scale=w2_scales,
        w1_zp=w1_qzeros if has_zp else None,
        w2_zp=w2_qzeros if has_zp else None,
        block_shape=[0, group_size],
    )

    with set_current_vllm_config(vllm_config):
        triton_output = fused_moe(
            a,
            w1_qweight,
            w2_qweight,
            score,
            topk,
            renormalize=False,
            global_num_experts=e,
            expert_map=e_map,
            quant_config=quant_config,
        )
        torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


399
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]
)
@torch.inference_mode()
def test_mixtral_moe(
    dist_init, dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, monkeypatch
):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface."""

    # clear the cache before every test
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled,
    )

    is_rocm_aiter_moe_enabled.cache_clear()
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

        if dtype == torch.float32:
            pytest.skip("AITER ROCm test skip for float32")

    # Single-process distributed setup, configured through the environment
    # before init_distributed_environment() is called.
    monkeypatch.setenv("RANK", "0")
    monkeypatch.setenv("LOCAL_RANK", "0")
    monkeypatch.setenv("WORLD_SIZE", "1")
    monkeypatch.setenv("MASTER_ADDR", "localhost")
    monkeypatch.setenv("MASTER_PORT", "12345")
    init_distributed_environment()

    # Instantiate our and huggingface's MoE blocks
    vllm_config.compilation_config.static_forward_context = dict()
    with set_current_vllm_config(vllm_config), set_forward_context(None, vllm_config):
        config = MixtralConfig()
        hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
        vllm_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            params_dtype=dtype,
            tp_size=1,
            dp_size=1,
        ).cuda()

        # Load the weights
        vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
        for i in range(config.num_local_experts):
            # The HF block keeps w1 and w3 as separate layers; the vLLM block
            # stores them concatenated along dim 0 in w13_weight.
            weights = (
                hf_moe.experts[i].w1.weight.data,
                hf_moe.experts[i].w3.weight.data,
            )
            vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

        # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
        hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
        # vLLM uses 1D query [num_tokens, hidden_dim]
        vllm_inputs = hf_inputs.flatten(0, 1)

        # Pad the weight if moe padding is enabled
        if padding:
            vllm_moe.experts.w13_weight = Parameter(
                F.pad(vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[
                    ..., 0:-128
                ],
                requires_grad=False,
            )
            vllm_moe.experts.w2_weight = Parameter(
                F.pad(vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
                requires_grad=False,
            )
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Run forward passes for both MoE blocks
        hf_states, _ = hf_moe.forward(hf_inputs)
        vllm_states = vllm_moe.forward(vllm_inputs)

    # Per-dtype tolerance, used as both rtol and atol in the non-AITER path.
    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    if use_rocm_aiter:
        # The values of rtol and atol are set based on the tests in ROCM AITER package.
        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174
        torch.testing.assert_close(
            hf_states.flatten(0, 1), vllm_states, rtol=0.01, atol=100
        )
    else:
        torch.testing.assert_close(
            hf_states.flatten(0, 1),
            vllm_states,
            rtol=mixtral_moe_tol[dtype],
            atol=mixtral_moe_tol[dtype],
        )
498
499


500
501
def marlin_moe_generate_valid_test_cases():
    """Build the parameter tuples for ``test_fused_marlin_moe``.

    Takes the Cartesian product of all parameter axes and keeps only the
    combinations accepted by the ``is_valid`` filter below.

    Returns:
        A list of ``(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
        quant_type, is_k_full)`` tuples.
    """
    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [-1, 16, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(
        m_list,
        n_list,
        k_list,
        e_list,
        topk_list,
        ep_size_list,
        dtype_list,
        group_size_list,
        act_order_list,
        quant_type_list,
        is_k_full_list,
    )

    def is_valid(
        m, n, k, e, topk, ep_size, dtype, group_size, act_order, quant_type, is_k_full
    ) -> bool:
        """Return True if this parameter combination should be tested."""
        # fp8 is only exercised with group_size -1 or 128.
        if quant_type == scalar_types.float8_e4m3fn and group_size not in [-1, 128]:
            return False
        # fp4 is only exercised with group sizes 16 and 32, and the
        # group_size == 32 variant is excluded for float16.
        if quant_type == scalar_types.float4_e2m1f:
            if group_size not in [16, 32]:
                return False
            if dtype == torch.float16 and group_size == 32:
                return False
        # group_size 16 is reserved for fp4.
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
            return False

        # Filter act_order
        if act_order:
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            # is_k_full=False is only exercised together with act_order.
            return False

        return True

    return [case for case in all_combinations if is_valid(*case)]


566
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    ("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
    marlin_moe_generate_valid_test_cases(),
)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    """Check ``fused_marlin_moe`` against ``torch_moe``.

    Each expert weight is quantized with the helper matching ``quant_type``.
    The Marlin op runs on the quantized tensors and its output is compared
    with ``torch_moe`` evaluated on the dequantized reference weights.
    """
    torch.cuda.manual_seed(0)
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    if ep_size > 1:
        # Emulate expert parallelism: keep a random subset of distinct
        # experts and map global expert ids to local slots (-1 = not local).
        local_e = e // ep_size
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    def stack_or_none(items):
        # Stack a per-expert list, or return None when nothing was collected.
        return stack_and_dev(items) if items else None

    def quantize_experts(w: torch.Tensor, perm_size: int):
        """Quantize every expert in ``w`` and stack the per-expert results.

        ``perm_size`` is the length of the random permutation handed to
        ``marlin_quantize`` in the branch that uses it.

        Returns ``(w_ref, qweight, scales, global_scale, g_idx, zeros,
        sort_indices)``; entries not produced by the selected quantization
        path are None.
        """
        w_ref_l = []
        qweight_l = []
        scales_l = []
        global_scale_l = []
        zeros_l = []
        g_idx_l = []
        sort_indices_l = []

        for i in range(w.shape[0]):
            if quant_type == scalar_types.float4_e2m1f:
                if group_size == 16:
                    w_ref, qweight, scales, global_scale = (
                        rand_marlin_weight_nvfp4_like(w[i], group_size)
                    )
                else:
                    w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                        w[i], group_size
                    )
                    global_scale = None
                if global_scale is not None:
                    global_scale_l.append(global_scale)
            elif quant_type == scalar_types.float8_e4m3fn:
                w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
            elif has_zp:
                w_ref, qweight, scales, zeros = awq_marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size
                )
                zeros_l.append(zeros)
            else:
                test_perm = torch.randperm(perm_size)
                w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
                )
                g_idx_l.append(g_idx)
                sort_indices_l.append(sort_indices)

            # Every path contributes these three per-expert tensors.
            w_ref_l.append(w_ref.T)
            qweight_l.append(qweight)
            scales_l.append(scales)

        return (
            stack_and_dev(w_ref_l),
            stack_and_dev(qweight_l).contiguous(),
            stack_and_dev(scales_l),
            stack_or_none(global_scale_l),
            stack_or_none(g_idx_l),
            stack_or_none(zeros_l),
            stack_or_none(sort_indices_l),
        )

    # w1 is quantized completely before w2, so the order in which random
    # permutations are drawn is the same as with two separate loops.
    (
        w_ref1,
        qweight1,
        scales1,
        global_scale1,
        g_idx1,
        zeros1,
        sort_indices1,
    ) = quantize_experts(w1, k)
    (
        w_ref2,
        qweight2,
        scales2,
        global_scale2,
        g_idx2,
        zeros2,
        sort_indices2,
    ) = quantize_experts(w2, n)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        None,
        None,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=e_map,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
    """Check ``fused_marlin_moe`` with per-expert biases against ``torch_moe``.

    Uses a fixed uint4b8 / group_size 128 / float16 configuration. The biases
    are passed to the Marlin op after ``marlin_permute_bias`` and to the
    reference implementation unpermuted.
    """
    torch.cuda.manual_seed(0)

    e, topk = 32, 4
    n, k = 2048, 2048
    group_size = 128
    act_order = False
    is_k_full = True
    quant_type = scalar_types.uint4b8
    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
    b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10

    def stack_or_none(items):
        # Stack a per-expert list, or return None when nothing was collected.
        return stack_and_dev(items) if items else None

    def quantize_experts(w: torch.Tensor, bias: torch.Tensor, perm_size: int):
        """Quantize every expert in ``w`` and permute its bias.

        Returns ``(w_ref, qweight, scales, g_idx, sort_indices, marlin_bias)``
        stacked over experts.
        """
        w_ref_l = []
        qweight_l = []
        scales_l = []
        g_idx_l = []
        sort_indices_l = []
        bias_l = []

        for i in range(w.shape[0]):
            test_perm = torch.randperm(perm_size)
            w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
            )

            w_ref_l.append(w_ref.T)
            qweight_l.append(qweight)
            scales_l.append(scales)
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)
            bias_l.append(marlin_permute_bias(bias[i]))

        return (
            stack_and_dev(w_ref_l),
            stack_and_dev(qweight_l).contiguous(),
            stack_and_dev(scales_l),
            stack_or_none(g_idx_l),
            stack_or_none(sort_indices_l),
            stack_or_none(bias_l),
        )

    # w1 is quantized completely before w2, so the order in which random
    # permutations are drawn is the same as with two separate loops.
    w_ref1, qweight1, scales1, g_idx1, sort_indices1, marlin_bias1 = quantize_experts(
        w1, b_bias1, k
    )
    w_ref2, qweight2, scales2, g_idx2, sort_indices2, marlin_bias2 = quantize_experts(
        w2, b_bias2, n
    )

    # This configuration has neither global scales nor zero points.
    global_scale1 = None
    global_scale2 = None
    zeros1 = None
    zeros2 = None

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        marlin_bias1,
        marlin_bias2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)
866
867
868
869
870


def test_moe_align_block_size_opcheck():
    """Run ``opcheck`` on the ``moe_align_block_size`` custom op."""
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda")

    # Upper bound on the padded length: each expert may need up to
    # block_size - 1 extra slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # Pre-fill with topk_ids.numel(), a value one past the last valid
    # flattened index into topk_ids.
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

    opcheck(
        torch.ops._moe_C.moe_align_block_size,
        (
            topk_ids,
            num_experts,
            block_size,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )
bnellnm's avatar
bnellnm committed
895
896


897
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    """Check the ``moe_sum`` custom op against ``Tensor.sum`` over dim 1."""
    # Named `inp` rather than `input` to avoid shadowing the builtin.
    inp = torch.randn((m, topk, k), device="cuda", dtype=dtype)
    actual = torch.empty((m, k), device="cuda", dtype=dtype)

    expected = inp.sum(dim=1)
    # The op writes its result into the preallocated `actual` tensor.
    torch.ops._moe_C.moe_sum(inp, actual)

    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)

    opcheck(torch.ops._moe_C.moe_sum, (inp, actual))