# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE (mixture-of-experts) layers.

Run `pytest tests/kernels/test_moe.py`.
"""
import functools
from typing import Callable, Optional, Union

import pytest
import torch
from torch.nn import Parameter
from torch.nn import functional as F
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_permute_bias)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    rand_marlin_weight_mxfp4_like, rand_marlin_weight_nvfp4_like)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    awq_marlin_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

# Parametrization values shared by the fused-MoE tests in this module.
NUM_EXPERTS = [8, 64, 192]
EP_SIZE = [1, 4]
TOP_KS = [2, 6]

# Module-level config that the tests enter via ``set_current_vllm_config``
# before calling into vLLM MoE code.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def run_moe_test(
    baseline: Union[Callable, torch.Tensor],
    moe_fn: Callable,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    padding: bool = False,
    use_compile: bool = False,
    use_cudagraph: bool = False,
    atol: float = 2e-2,
    rtol: float = 0,
) -> torch.Tensor:
    """Run ``moe_fn`` and assert that it matches a reference result.

    Args:
        baseline: Either a precomputed reference output, or a callable with
            the same signature as ``moe_fn`` that is invoked on the
            (unpadded) inputs to produce one.
        moe_fn: The MoE implementation under test. It is called as
            ``moe_fn(a, w1, w2, score, topk, global_num_experts=...,
            expert_map=...)``.
        a, w1, w2, score, topk, global_num_experts, expert_map: Forwarded
            unchanged to ``baseline`` (when callable) and ``moe_fn``.
        padding: If True, ``w1``/``w2`` are replaced by views into a copy
            padded with 128 zeros along the last dim. The logical shape is
            unchanged but the last-dim stride grows by 128.
        use_compile: Wrap ``moe_fn`` in ``torch.compile`` (inductor,
            fullgraph) and mark dim 0 of ``a`` and ``score`` as dynamic.
        use_cudagraph: After the eager call, capture ``moe_fn`` in a CUDA
            graph, replay it, and check the replayed output instead.
        atol, rtol: Tolerances passed to ``torch.testing.assert_close``.

    Returns:
        The reference output, so callers can reuse it as ``baseline``.

    Raises:
        AssertionError: If the output differs from the reference.
    """
    moe_kwargs = {
        "global_num_experts": global_num_experts,
        "expert_map": expert_map,
    }

    if isinstance(baseline, torch.Tensor):
        reference = baseline
    else:
        reference = baseline(a, w1, w2, score, topk, **moe_kwargs)

    # Exercise non-contiguous weights: pad the last dim, then slice the
    # padding back off so only the strides change.
    if padding:
        w1, w2 = (F.pad(w, (0, 128), "constant", 0)[..., 0:-128]
                  for w in (w1, w2))

    fn = moe_fn
    if use_compile:
        fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
        torch._dynamo.mark_dynamic(a, 0)
        torch._dynamo.mark_dynamic(score, 0)

    result = fn(a, w1, w2, score, topk, **moe_kwargs)

    if use_cudagraph:
        # Clear the eager result before capture. NOTE(review): this only
        # matters if the implementation reuses its output buffer across
        # calls, in which case stale eager values would otherwise mask a
        # broken replay.
        result.fill_(0)
        capture_stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=capture_stream):
            result = fn(a, w1, w2, score, topk, **moe_kwargs)
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

    torch.testing.assert_close(result, reference, atol=atol, rtol=rtol)

    return reference


@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    padding: bool,
    chunk_size: int,
    monkeypatch,
):
    """Check the unquantized fused MoE implementations against ``torch_moe``.

    ``iterative_moe`` is first checked against ``torch_moe``. The torch
    output then serves as the baseline for ``fused_moe`` and for the
    modular Triton kernel. The last two also run under a CUDA graph for
    the larger problem sizes. With ``ep_size > 1`` only a random subset of
    distinct experts is kept locally and an expert map is supplied.
    """
    current_platform.seed_everything(7)

    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

    #
    # Setup test data
    #

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if ep_size > 1:
        local_e = e // ep_size
        # Pick ``local_e`` *distinct* global experts for this rank. With
        # randint, duplicate ids would leave some local slots unmapped and
        # make ``e_map[e_ids] = ...`` depend on write order.
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        # Global expert id -> local slot, or -1 if not local.
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    #
    # Setup test functions
    #

    m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
                                              use_int8_w8a8=False,
                                              use_int8_w8a16=False,
                                              use_int4_w4a16=False,
                                              use_mxfp4_w4a4=False,
                                              per_act_token_quant=False,
                                              block_shape=None)

    def m_fused_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        score: torch.Tensor,
        topk: int,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Adapt the modular kernel, which takes precomputed top-k
        # weights/ids, to the ``run_moe_test`` calling convention.
        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        return m_fused_moe_fn(a,
                              w1,
                              w2,
                              topk_weights,
                              topk_ids,
                              global_num_experts=global_num_experts,
                              expert_map=expert_map)

    fused_moe_fn = functools.partial(fused_moe, renormalize=False)

    #
    # Run tests
    #
    runner = functools.partial(
        run_moe_test,
        a=a,
        w1=w1,
        w2=w2,
        score=score,
        topk=topk,
        global_num_experts=e,
        expert_map=e_map,
        padding=padding,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    use_compile = False

    use_cudagraph = (n >= 1024 and k >= 1024
                     and current_platform.is_cuda_alike())

    with set_current_vllm_config(vllm_config):
        baseline_output = runner(torch_moe, iterative_moe)
        runner(baseline_output,
               fused_moe_fn,
               use_compile=use_compile,
               use_cudagraph=use_cudagraph)
        runner(baseline_output,
               m_fused_moe,
               use_compile=use_compile,
               use_cudagraph=use_cudagraph)


@pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
                        ep_size: int, dtype: torch.dtype, group_size: int,
                        has_zp: bool, weight_bits: int):
    """Check the int4/int8 weight-only (w4a16 / w8a16) path of ``fused_moe``.

    Each expert's w1/w2 is group-quantized with ``quantize_weights``. The
    quantized values are packed into uint8 tensors (two 4-bit values per
    byte when ``weight_bits == 4``) and passed to ``fused_moe``. The result
    is compared with ``torch_moe`` run on the reference weights returned by
    ``quantize_weights``.
    """
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    if weight_bits == 4:
        # Two 4-bit values are packed into each uint8.
        pack_factor = 2
        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
    elif weight_bits == 8:
        pack_factor = 1
        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
    else:
        raise ValueError(f"unsupported weight_bits: {weight_bits}")

    w1_ref = w1.clone()
    w2_ref = w2.clone()
    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w2_qweight = torch.empty((e, k, n // pack_factor),
                             device="cuda",
                             dtype=torch.uint8)
    w1_scales = torch.empty((e, 2 * n, k // group_size),
                            device="cuda",
                            dtype=dtype)
    w2_scales = torch.empty((e, k, n // group_size),
                            device="cuda",
                            dtype=dtype)
    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
                            device="cuda",
                            dtype=torch.uint8)
    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
                            device="cuda",
                            dtype=torch.uint8)

    # First pass (i < e) fills the w1 tensors, second pass fills w2.
    for i in range(e * 2):
        expert_id = i % e
        if i // e == 0:
            w, w_ref, w_qweight, w_scales, w_qzeros = \
                w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
        else:
            w, w_ref, w_qweight, w_scales, w_qzeros = \
                w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
        weight, qweight, scales, qzeros = quantize_weights(
            w[expert_id].T, quant_type, group_size, has_zp, False)
        weight = weight.T
        qweight = qweight.T.contiguous().to(torch.uint8)
        scales = scales.T
        if has_zp:
            qzeros = qzeros.T.contiguous().to(torch.uint8)
        if weight_bits == 4:
            # Pack pairs: odd columns go in the high nibble, even columns
            # in the low nibble (zero points are packed along rows).
            qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
            if has_zp:
                qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]

        w_ref[expert_id] = weight
        w_qweight[expert_id] = qweight
        w_scales[expert_id] = scales
        if has_zp:
            w_qzeros[expert_id] = qzeros

    if ep_size > 1:
        local_e = e // ep_size
        # Distinct expert ids, so every local slot maps to a unique expert.
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1_ref = w1_ref[e_ids]
        w2_ref = w2_ref[e_ids]
        w1_qweight = w1_qweight[e_ids]
        w2_qweight = w2_qweight[e_ids]
        w1_scales = w1_scales[e_ids]
        w2_scales = w2_scales[e_ids]
        w1_qzeros = w1_qzeros[e_ids]
        w2_qzeros = w2_qzeros[e_ids]
    else:
        e_map = None

    with set_current_vllm_config(vllm_config):
        triton_output = fused_moe(a,
                                  w1_qweight,
                                  w2_qweight,
                                  score,
                                  topk,
                                  renormalize=False,
                                  use_int4_w4a16=weight_bits == 4,
                                  use_int8_w8a16=weight_bits == 8,
                                  global_num_experts=e,
                                  expert_map=e_map,
                                  w1_scale=w1_scales,
                                  w2_scale=w2_scales,
                                  w1_zp=w1_qzeros if has_zp else None,
                                  w2_zp=w2_qzeros if has_zp else None,
                                  block_shape=[0, group_size])
        torch_output = torch_moe(a,
                                 w1_ref,
                                 w2_ref,
                                 score,
                                 topk,
                                 expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize(
    "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
                     monkeypatch):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface.

    vLLM's ``MixtralMoE`` is loaded with the weights of a randomly
    initialised ``MixtralSparseMoeBlock``, and both are run on the same
    input. Optionally the vLLM expert weights are re-wrapped as padded
    (strided) views, and on ROCm the AITER MoE path can be enabled.
    """
    # clear the cache before every test
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        is_rocm_aiter_moe_enabled)
    is_rocm_aiter_moe_enabled.cache_clear()
    if use_rocm_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

        if dtype == torch.float32:
            pytest.skip("AITER ROCm test skip for float32")

    # Single-process distributed environment for the vLLM layer.
    monkeypatch.setenv('RANK', "0")
    monkeypatch.setenv('LOCAL_RANK', "0")
    monkeypatch.setenv('WORLD_SIZE', "1")
    monkeypatch.setenv('MASTER_ADDR', 'localhost')
    monkeypatch.setenv('MASTER_PORT', '12345')
    init_distributed_environment()

    # Instantiate our and huggingface's MoE blocks
    vllm_config.compilation_config.static_forward_context = {}
    with (set_current_vllm_config(vllm_config),
          set_forward_context(None, vllm_config)):
        config = MixtralConfig()
        hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
        vllm_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            params_dtype=dtype,
            tp_size=1,
            dp_size=1,
        ).cuda()

        # Load the weights. vLLM's w13 holds HF's w1 and w3 weights
        # concatenated along dim 0.
        vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
        for i in range(config.num_local_experts):
            weights = (hf_moe.experts[i].w1.weight.data,
                       hf_moe.experts[i].w3.weight.data)
            vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
            vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

        # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
        hf_inputs = torch.randn(
            (1, 64, config.hidden_size)).to(dtype).to("cuda")
        # vLLM uses 1D query [num_tokens, hidden_dim]
        vllm_inputs = hf_inputs.flatten(0, 1)

        # Pad the weight if moe padding is enabled: pad the last dim by 128
        # and slice it back off, so only the strides change.
        if padding:
            vllm_moe.experts.w13_weight = Parameter(F.pad(
                vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
                                                                      0:-128],
                                                    requires_grad=False)
            torch.cuda.empty_cache()
            vllm_moe.experts.w2_weight = Parameter(F.pad(
                vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
                                                                     0:-128],
                                                   requires_grad=False)
            torch.cuda.empty_cache()

        # Run forward passes for both MoE blocks
        hf_states, _ = hf_moe.forward(hf_inputs)
        vllm_states = vllm_moe.forward(vllm_inputs)

    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    if use_rocm_aiter:
        # The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
        # https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174  # noqa: E501
        torch.testing.assert_close(hf_states.flatten(0, 1),
                                   vllm_states,
                                   rtol=0.01,
                                   atol=100)
    else:
        torch.testing.assert_close(hf_states.flatten(0, 1),
                                   vllm_states,
                                   rtol=mixtral_moe_tol[dtype],
                                   atol=mixtral_moe_tol[dtype])


def marlin_moe_generate_valid_test_cases():
    """Build the parameter tuples for ``test_fused_marlin_moe``.

    Enumerates the Cartesian product of the option lists below and keeps
    only the combinations accepted by ``is_valid``.

    Returns:
        A list of tuples ordered as (m, n, k, e, topk, ep_size, dtype,
        group_size, act_order, quant_type, is_k_full).
    """
    import itertools
    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    ep_size_list = [1, 4]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [-1, 16, 32, 128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.float4_e2m1f,
        scalar_types.float8_e4m3fn,
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.uint8b128,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(m_list, n_list, k_list, e_list,
                                         topk_list, ep_size_list, dtype_list,
                                         group_size_list, act_order_list,
                                         quant_type_list, is_k_full_list)

    def is_valid(m, n, k, e, topk, ep_size, dtype, group_size, act_order,
                 quant_type, is_k_full):
        """Return True if this combination should be tested.

        (Previously misnamed ``is_invalid`` although it returned True for
        the combinations that are kept.)
        """
        # fp8: only channelwise (-1) or group size 128.
        if quant_type == scalar_types.float8_e4m3fn and \
                group_size not in [-1, 128]:
            return False

        # fp4: group size 16 or 32 only; fp16 with group 32 is excluded.
        if quant_type == scalar_types.float4_e2m1f:
            if group_size not in [16, 32]:
                return False
            if dtype == torch.float16 and group_size == 32:
                return False

        # Group size 16 is only used with fp4.
        if quant_type != scalar_types.float4_e2m1f and group_size == 16:
            return False

        # Filter act_order: it needs real grouping and uint4b8 weights, and
        # is_k_full=False is only tested together with act_order.
        if act_order:
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            return False

        return True

    return [case for case in all_combinations if is_valid(*case)]


def _quantize_experts_for_marlin(w: torch.Tensor, quant_type: ScalarType,
                                 group_size: int, act_order: bool,
                                 has_zp: bool):
    """Quantize each expert slice ``w[i]`` for ``fused_marlin_moe``.

    The quantizer is selected from ``quant_type``:
    ``rand_marlin_weight_nvfp4_like`` (group 16) or
    ``rand_marlin_weight_mxfp4_like`` for float4_e2m1f,
    ``marlin_quant_fp8_torch`` for float8_e4m3fn, ``awq_marlin_quantize``
    when ``has_zp``, and otherwise ``marlin_quantize`` with a random
    permutation of size ``w.shape[-1]``.

    Returns:
        ``(w_ref, qweight, scales, global_scale, g_idx, zeros,
        sort_indices)``, each stacked across experts with
        ``stack_and_dev``. ``qweight`` is made contiguous. The optional
        entries are None when the selected quantizer does not produce them.
        Each per-expert ``w_ref`` is transposed (``.T``) before stacking.
    """
    w_ref_l, qweight_l, scales_l = [], [], []
    global_scale_l, zeros_l, g_idx_l, sort_indices_l = [], [], [], []

    for i in range(w.shape[0]):
        if quant_type == scalar_types.float4_e2m1f:
            if group_size == 16:
                w_ref, qweight, scales, global_scale = \
                    rand_marlin_weight_nvfp4_like(w[i], group_size)
            else:
                w_ref, qweight, scales = \
                    rand_marlin_weight_mxfp4_like(w[i], group_size)
                global_scale = None
            if global_scale is not None:
                global_scale_l.append(global_scale)
        elif quant_type == scalar_types.float8_e4m3fn:
            w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
        elif has_zp:
            w_ref, qweight, scales, zeros = awq_marlin_quantize(
                w[i].transpose(1, 0), quant_type, group_size)
            zeros_l.append(zeros)
        else:
            # Random permutation over the last dim (k for w1, n for w2),
            # forwarded to marlin_quantize together with act_order.
            test_perm = torch.randperm(w.shape[-1])
            w_ref, qweight, scales, g_idx, sort_indices, _ = \
                marlin_quantize(w[i].transpose(1, 0), quant_type,
                                group_size, act_order, test_perm)
            g_idx_l.append(g_idx)
            sort_indices_l.append(sort_indices)

        w_ref_l.append(w_ref.T)
        qweight_l.append(qweight)
        scales_l.append(scales)

    def _stack_or_none(tensors):
        # Optional outputs are only produced by some quantizers.
        return stack_and_dev(tensors) if tensors else None

    return (stack_and_dev(w_ref_l),
            stack_and_dev(qweight_l).contiguous(),
            stack_and_dev(scales_l),
            _stack_or_none(global_scale_l),
            _stack_or_none(g_idx_l),
            _stack_or_none(zeros_l),
            _stack_or_none(sort_indices_l))


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size,"
                          "act_order, quant_type, is_k_full"),
                         marlin_moe_generate_valid_test_cases())
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    ep_size: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    """Compare ``fused_marlin_moe`` with ``torch_moe`` on reference weights.

    w1/w2 are quantized per expert by ``_quantize_experts_for_marlin``. The
    Marlin kernel runs on the quantized tensors and ``torch_moe`` on the
    reference weights returned by the quantizers.
    """
    torch.cuda.manual_seed(0)
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    if ep_size > 1:
        local_e = e // ep_size
        e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
        w1 = w1[e_ids]
        w2 = w2[e_ids]
    else:
        e_map = None

    (w_ref1, qweight1, scales1, global_scale1, g_idx1, zeros1,
     sort_indices1) = _quantize_experts_for_marlin(w1, quant_type, group_size,
                                                   act_order, has_zp)
    (w_ref2, qweight2, scales2, global_scale2, g_idx2, zeros2,
     sort_indices2) = _quantize_experts_for_marlin(w2, quant_type, group_size,
                                                   act_order, has_zp)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a,
                                 w_ref1,
                                 w_ref2,
                                 score,
                                 topk,
                                 expert_map=e_map)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        None,  # no w1 bias
        None,  # no w2 bias
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=e_map,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full)

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
def test_fused_marlin_moe_with_bias(m):
    """Check ``fused_marlin_moe`` with per-expert w1/w2 biases.

    Weights are uint4b8, group size 128, without act-order, and are
    quantized with ``marlin_quantize``. The kernel receives the biases
    after ``marlin_permute_bias``; ``torch_moe`` receives them unpermuted.
    """
    torch.cuda.manual_seed(0)

    e, topk = 32, 4
    n, k = 2048, 2048
    group_size = 128
    act_order = False
    is_k_full = True
    quant_type = scalar_types.uint4b8
    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
    b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10

    b_bias1_l = []
    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    g_idx1_l = []
    sort_indices1_l = []

    for i in range(w1.shape[0]):
        test_perm = torch.randperm(k)
        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \
            marlin_quantize(w1[i].transpose(1, 0), quant_type,
                            group_size, act_order, test_perm)

        w_ref1_l.append(w_ref1.T)
        qweight1_l.append(qweight1)
        scales1_l.append(scales1)
        g_idx1_l.append(g_idx1)
        sort_indices1_l.append(sort_indices1)
        b_bias1_l.append(marlin_permute_bias(b_bias1[i]))

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    global_scale1 = None
    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
    zeros1 = None
    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
    marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None

    b_bias2_l = []
    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    g_idx2_l = []
    sort_indices2_l = []

    for i in range(w2.shape[0]):
        test_perm = torch.randperm(n)
        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \
            marlin_quantize(w2[i].transpose(1, 0), quant_type,
                            group_size, act_order, test_perm)

        w_ref2_l.append(w_ref2.T)
        qweight2_l.append(qweight2)
        scales2_l.append(scales2)
        g_idx2_l.append(g_idx2)
        sort_indices2_l.append(sort_indices2)
        b_bias2_l.append(marlin_permute_bias(b_bias2[i]))

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    global_scale2 = None
    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
    zeros2 = None
    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
    marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        # Reference uses the original (unpermuted) biases.
        torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1,
                                 b_bias2)

    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        marlin_bias1,
        marlin_bias2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full)

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


def test_moe_align_block_size_opcheck():
    """Run ``opcheck`` on the ``_moe_C.moe_align_block_size`` custom op."""
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0,
                             num_experts, (3, 4),
                             dtype=torch.int32,
                             device='cuda')

    # Output buffer sizes: presumably the worst case where each expert's
    # token count is padded by up to (block_size - 1).
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    # Pre-fill with an out-of-range token index (numel()); presumably marks
    # padding slots.
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1, ),
                                      dtype=torch.int32,
                                      device=topk_ids.device)

    opcheck(torch.ops._moe_C.moe_align_block_size,
            (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
             num_tokens_post_pad))


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
    """Check ``_moe_C.moe_sum`` against ``sum(dim=1)`` and run ``opcheck``.

    The op reduces an (m, topk, k) input over the topk dimension into a
    preallocated (m, k) output.
    """
    inp = torch.randn((m, topk, k), device="cuda", dtype=dtype)
    actual = torch.empty((m, k), device="cuda", dtype=dtype)

    expected = inp.sum(dim=1)
    torch.ops._moe_C.moe_sum(inp, actual)

    torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0)

    opcheck(torch.ops._moe_C.moe_sum, (inp, actual))