# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the PPLX MoE layers (prepare/finalize and fused experts).

Run `pytest tests/kernels/moe/test_pplx_moe.py`.
"""

import copy
import itertools
import textwrap
import traceback
from collections.abc import Callable

import pytest
import torch

try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (
        nvshmem_alloc_empty_unique_id,
        nvshmem_finalize,
        nvshmem_get_unique_id,
        nvshmem_init,
    )

    has_pplx = True
except ImportError:
    has_pplx = False

from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
from tests.kernels.moe.utils import (
    make_dummy_moe_config,
    make_shared_experts,
    make_test_weights,
    naive_batched_moe,
)
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
)
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager

from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch
# Tests that need the optional pplx_kernels package are skipped without it.
requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)

# (m, n, k) problem sizes for the single-GPU batched-experts test.
BATCHED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (33, 2048, 128),
    (64, 128, 2048),
    (222, 128, 128),
    (222, 2048, 1024),
]

# (m, n, k) problem sizes for the multi-GPU pplx tests.
PPLX_COMBOS = [
    # TODO(bnell): figure out why this fails, seems to be test problem
    # (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (4, 128, 128),
    (32, 1024, 512),
    (45, 512, 2048),
    (64, 1024, 512),
    (222, 2048, 1024),
    (256, 1408, 2048),
]

NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
# The fp8 entry is run with bf16 activations that are quantized to fp8.
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]

# Default vLLM config installed around the reference computations.
vllm_config = VllmConfig()


def torch_prepare(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
92
    max_num_tokens: int | None = None,
93
94
95
96
97
98
99
) -> tuple[torch.Tensor, torch.Tensor]:
    assert topk_ids.dim() == 2
    assert topk_ids.shape[0] == a.shape[0]

    num_tokens, hidden_dim = a.shape
    topk = topk_ids.shape[1]

100
    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
101
102
103
104
105
106

    assert tokens_per_expert.numel() == num_experts

    if max_num_tokens is None:
        max_num_tokens = int(tokens_per_expert.max().item())

107
108
109
    b_a = torch.zeros(
        (num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device
    )
110
111
112
113
114
115
116

    token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)

    for token in range(num_tokens):
        for j in range(topk):
            expert_id = topk_ids[token, j]
            idx = token_counts[expert_id]
117
            b_a[expert_id, idx : idx + 1, :] = a[token, :]
118
119
120
121
122
            token_counts[expert_id] = token_counts[expert_id] + 1

    return b_a, tokens_per_expert


def torch_finalize(
    b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor
) -> torch.Tensor:
    """Gather per-expert rows back into token order, weighting by topk weights.

    Walks tokens in order and, for each of a token's experts, consumes that
    expert's next unused row of ``b_out``. This is the same order in which
    ``torch_prepare`` filled the batches.

    Args:
        b_out: Per-expert outputs of shape (num_experts, max_num_tokens, K).
        topk_weight: Routing weights, shape (num_tokens, topk).
        topk_ids: Routing expert ids, shape (num_tokens, topk).

    Returns:
        Tensor of shape (num_tokens, K) in ``b_out``'s dtype.
    """
    num_tokens = topk_ids.shape[0]
    num_experts = b_out.shape[0]
    K = b_out.shape[-1]
    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)

    # Host-side cursors/ids avoid a device sync per element. topk_weight is
    # still indexed as a tensor so dtype promotion matches tensor * tensor.
    consumed = [0] * num_experts
    for token, expert_ids in enumerate(topk_ids.tolist()):
        for i, expert_id in enumerate(expert_ids):
            idx = consumed[expert_id]
            out[token, :] = (
                out[token, :]
                + b_out[expert_id, idx : idx + 1, :] * topk_weight[token, i]
            )
            consumed[expert_id] = idx + 1

    return out


def torch_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference MoE computed through the per-expert batched layout.

    Tokens are scattered with ``torch_prepare``. Each expert then computes
    ``silu_and_mul(x @ w1[e].T) @ w2[e].T`` on its rows, and the results are
    combined with ``torch_finalize``.
    """
    num_experts = w1.shape[0]
    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
    assert b_a.dim() == 3
    _, max_num_tokens, K = b_a.shape
    assert num_experts == b_a.shape[0] and w2.shape[1] == K

    out = torch.zeros(
        (num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device
    )
    # silu_and_mul writes an output half as wide as the w1 projection.
    tmp = torch.empty(
        (max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device
    )

    # Read all counts in one transfer instead of syncing once per expert.
    for expert, num in enumerate(tokens_per_expert.tolist()):
        if num > 0:
            torch.ops._C.silu_and_mul(
                tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
            )
            out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

    return torch_finalize(out, topk_weight, topk_ids)


@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    workspace_init,
):
    """Check both batched MoE implementations against ``torch_experts``."""
    set_random_seed(7)

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        # Unbatched reference used as the ground truth.
        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
        # Pure-torch batched reference defined in this file.
        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
        # Batched implementation from the shared test utilities.
        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)

    torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)


def create_pplx_prepare_finalize(
    num_tokens: int,
    hidden_dim: int,
    topk: int,
    num_experts: int,
    rank: int,
    dp_size: int,
    world_size: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None,
    block_shape: list[int] | None,
    per_act_token_quant: bool,
    group_name: str | None,
):
    """Build a pplx ``AllToAll`` and wrap it in a ``PplxPrepareAndFinalize``.

    When ``group_name`` is None the internode all-to-all is created. Otherwise
    the intranode one is created for the named process group.

    Returns:
        ``(prepare_finalize, ata)``. Callers in this file call
        ``ata.destroy()`` once they are finished with it.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
        pplx_hidden_dim_scale_bytes,
    )

    # rank_chunk gives rank 0 the largest share, so size the buffers for it.
    # Keep at least one slot.
    max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
    num_local_experts = rank_chunk(num_experts, 0, world_size)

    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
        max_num_tokens,
        hidden_dim,
        in_dtype,
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim_bytes,
        hidden_dim_scale_bytes=scale_bytes,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens=max_num_tokens,
        num_local_experts=num_local_experts,
        num_dispatchers=world_size // dp_size,
    )

    return prepare_finalize, ata


def rank_chunk(num: int, r: int, w: int) -> int:
    """Number of the ``num`` items owned by rank ``r`` out of ``w`` ranks.

    Items are split as evenly as possible; the first ``num % w`` ranks each
    take one extra item.
    """
    base, extra = divmod(num, w)
    return base + 1 if r < extra else base


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    """Return rank ``r``'s contiguous slice of ``t`` (dim 0) out of ``w`` ranks.

    Slice sizes match ``rank_chunk``: the first ``t.shape[0] % w`` ranks get
    one extra row. The slices for ranks ``0..w-1`` tile ``t`` exactly, with
    no overlap and no gaps.
    """
    base, extra = divmod(t.shape[0], w)
    # Rows owned by lower ranks: each has `base` rows, and the first `extra`
    # of them have one more. Using `r * size` here would overlap ranks
    # whenever the rows don't divide evenly.
    start = r * base + min(r, extra)
    size = base + (1 if r < extra else 0)
    return t[start : start + size]


278
def maybe_chunk_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None:
279
280
281
282
283
284
    if t is not None:
        return chunk_by_rank(t, r, w)
    else:
        return t


285
def chunk_scales_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None:
286
287
    if t is not None and t.numel() > 1:
        chunk = rank_chunk(t.shape[0], r, w)
288
        return t[(r * chunk) : (r + 1) * chunk]
289
290
291
292
    else:
        return t


293
def chunk_scales(t: torch.Tensor | None, start: int, end: int) -> torch.Tensor | None:
294
295
296
297
298
299
300
301
302
303
    if t is not None and t.numel() > 1:
        return t[start:end]
    else:
        return t


def dummy_work(a: torch.Tensor) -> torch.Tensor:
    """Scale ``a`` by 1.1, a cheap and easily reproduced stand-in transform.

    The prepare/finalize test applies it to the dispatched tokens in place of
    real expert computation, and to the reference input.
    """
    return torch.mul(a, 1.1)


def pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    quant_dtype: torch.dtype | None,
    block_shape: list[int] | None,
    per_act_token_quant: bool,
    group_name: str | None,
) -> torch.Tensor:
    """Round-trip this rank's share of ``a`` through pplx prepare/finalize.

    The rank's chunk of tokens is dispatched with ``prepare``, dequantized and
    passed through ``dummy_work`` in place of real experts, then combined back
    with ``finalize``.

    Returns:
        The combined output for this rank's chunk of tokens.
    """
    assert torch.cuda.current_device() == pgi.local_rank

    topk = topk_ids.shape[1]
    num_tokens, hidden_dim = a.shape
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    # Release the all-to-all even if prepare/finalize raise: _pplx_test_loop
    # runs many configurations in one process and continues past failures.
    try:
        assert a.shape[0] == topk_ids.shape[0]

        a_chunk = chunk_by_rank(a, rank, world_size).to(device)
        chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
        chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

        assert a_chunk.shape[0] == chunk_topk_ids.shape[0]

        # NaN-filled, presumably so that any row finalize fails to write
        # shows up as a mismatch in the comparison.
        out = torch.full(
            a_chunk.shape,
            torch.nan,
            dtype=a.dtype,
            device=device,
        )

        # Per-tensor quantization gets explicit unit activation scales;
        # per-token and block quantization pass None.
        if quant_dtype is not None and not per_act_token_quant and block_shape is None:
            a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
            a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        else:
            a1_scale = None
            a2_scale = None

        b_a, b_a_scale, _, _, _ = prepare_finalize.prepare(
            a_chunk,
            chunk_topk_weight,
            chunk_topk_ids,
            num_experts,
            None,
            False,
            FusedMoEQuantConfig.make(
                quant_dtype,
                per_act_token_quant=per_act_token_quant,
                per_out_ch_quant=False,
                block_shape=block_shape,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            ),
        )

        # Stand-in for the experts: dequantize, then apply a known transform.
        b_a = dummy_work(
            dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)
        )

        prepare_finalize.finalize(
            out,
            b_a,
            chunk_topk_weight,
            chunk_topk_ids,
            False,
            weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
        )
    finally:
        # Wait for in-flight kernels before the all-to-all buffers are freed.
        torch.cuda.synchronize()
        ata.destroy()

    return out[: a_chunk.shape[0]]


def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    quant_dtype: torch.dtype | None,
    block_shape: list[int] | None,
    per_act_token_quant: bool,
    use_internode: bool,
):
    """Per-process worker comparing pplx prepare/finalize with a torch reference.

    Sets up communication: nvshmem for internode runs, a gloo process group
    for intranode runs. Then routes ``a`` with ``fused_topk`` and checks
    ``pplx_prepare_finalize`` against the topk-weighted sum of
    ``dummy_work(a)`` for this rank's chunk of tokens.
    """
    try:
        if use_internode:
            # Rank 0 creates the nvshmem unique id and broadcasts it.
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        m, k = a.shape

        # Reference: every topk slot sees dummy_work(token); the slots are
        # weighted by their routing weights and summed.
        a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)
        torch_output = (
            a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype)
        ).sum(dim=1)

        pplx_output = pplx_prepare_finalize(
            pgi,
            dp_size,
            a,
            topk_weight,
            topk_ids,
            num_experts,
            quant_dtype,
            block_shape,
            per_act_token_quant,
            group_name,
        )

        torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
            pgi.device
        )

        torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
    finally:
        if use_internode:
            nvshmem_finalize()


@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: list[int] | None,
    use_internode: bool,
):
    """Prepare/finalize test with one launch per parameter combination."""
    if dtype == torch.float8_e4m3fn:
        # fp8 is only the quantization target; activations stay in bf16.
        use_fp8_w8a8 = True
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        act_dtype = dtype
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    set_random_seed(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"

    a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
    score = torch.randn((m, e), device=device, dtype=act_dtype)

    parallel_launch(
        world_size,
        _pplx_prepare_finalize,
        dp_size,
        a,
        score,
        topk,
        e,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        use_internode,
    )


def pplx_moe(
    group_name: str | None,
    rank: int,
    world_size: int,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    use_compile: bool = False,
    use_cudagraphs: bool = True,
    shared_experts: torch.nn.Module | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Run this rank's share of an MoE layer through the pplx modular kernel.

    Inputs, expert weights and scales are chunked by rank and run through a
    ``FusedMoEModularKernel`` built from ``PplxPrepareAndFinalize`` and
    ``BatchedTritonExperts``. With ``use_cudagraphs`` the call is also
    captured in a CUDA graph and replayed, and the replayed output is returned.

    Returns:
        This rank's output. When ``shared_experts`` is given, ``_pplx_moe``
        unpacks the result as ``(shared_output, routed_output)``.
    """
    num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
    topk = topk_ids.shape[1]
    # Expert-side capacity: rank 0's (largest) share, rounded up to 16.
    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 16)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    # Always release the all-to-all, even if building or running the kernel
    # fails; callers may keep running further configurations in this process.
    try:
        topk_ids = topk_ids.to(dtype=torch.uint32)

        # Note: workers with the same dp_rank must use the exact same inputs.
        a_chunk = chunk_by_rank(a, rank, world_size)
        chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
        chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size)

        # Chunking weights like this only works for batched format
        w1_chunk = chunk_by_rank(w1, rank, world_size)
        w2_chunk = chunk_by_rank(w2, rank, world_size)
        w1_scale_chunk = maybe_chunk_by_rank(w1_scale, rank, world_size)
        w2_scale_chunk = maybe_chunk_by_rank(w2_scale, rank, world_size)
        a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
        a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)

        quant_config = FusedMoEQuantConfig.make(
            quant_dtype,
            block_shape=block_shape,
            per_act_token_quant=per_act_token_quant,
            w1_scale=w1_scale_chunk,
            w2_scale=w2_scale_chunk,
            a1_scale=a1_scale_chunk,
            a2_scale=a2_scale_chunk,
        )

        experts = BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            quant_config=quant_config,
            moe_config=make_dummy_moe_config(),
        )

        fused_experts = FusedMoEModularKernel(
            prepare_finalize,
            experts,
            shared_experts,
            inplace=False,
        )

        # Note: for now use_compile will error out if the problem size is
        # large enough to trigger chunking. I'm leaving the flag and
        # setup code in case we are able to revisit this later.
        if use_compile:
            _fused_experts = torch.compile(
                fused_experts, backend="inductor", fullgraph=True
            )
            torch._dynamo.mark_dynamic(a_chunk, 0)
            torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
            torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
        else:
            _fused_experts = fused_experts

        out = _fused_experts(
            a_chunk,
            w1_chunk,
            w2_chunk,
            chunk_topk_weight,
            chunk_topk_ids,
            global_num_experts=num_experts,
        )

        if use_cudagraphs:
            # Zero the eager result before capture (presumably so stale
            # values cannot mask a broken graph replay).
            if isinstance(out, tuple):
                out[0].fill_(0)
                out[1].fill_(0)
            else:
                out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = _fused_experts(
                    a_chunk,
                    w1_chunk,
                    w2_chunk,
                    chunk_topk_weight,
                    chunk_topk_ids,
                    global_num_experts=num_experts,
                )

            torch.cuda.synchronize()
            graph.replay()
    finally:
        # Wait for in-flight kernels before the all-to-all buffers are freed.
        torch.cuda.synchronize()
        ata.destroy()

    return out


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    w1_s: torch.Tensor | None = None,
    w2_s: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    use_internode: bool = False,
    shared_experts: torch.nn.Module | None = None,
):
    """Per-process worker checking pplx MoE against torch and naive batched MoE.

    Compares ``naive_batched_moe`` with ``torch_experts``, then compares this
    rank's chunk of the batched result with ``pplx_moe``. The shared-expert
    output is compared too when ``shared_experts`` is given.
    """
    try:
        if use_internode:
            # Rank 0 creates the nvshmem unique id and broadcasts it.
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        m, k = a.shape
        e, _, n = w2.shape

        moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

        # Use this process's own GPU. The global rank is not a valid CUDA
        # ordinal once ranks span several nodes.
        device = pgi.device
        rank = pgi.rank
        world_size = pgi.world_size

        a = a.to(device)
        w1 = w1.to(device)
        w2 = w2.to(device)
        w1_s = w1_s.to(device) if w1_s is not None else None
        w2_s = w2_s.to(device) if w2_s is not None else None

        # Per-tensor quantization gets explicit unit activation scales;
        # per-token and block quantization pass None.
        if quant_dtype is not None and not per_act_token_quant and block_shape is None:
            a1_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
            a2_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
        else:
            a1_scale = None
            a2_scale = None

        with set_current_vllm_config(vllm_config), override_config(moe_config):
            topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

            shared_output = shared_experts(a) if shared_experts is not None else None

            torch_output = torch_experts(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            batched_output = naive_batched_moe(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            pplx_outputs = pplx_moe(
                group_name,
                rank,
                world_size,
                dp_size,
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                shared_experts=shared_experts,
            )

        if shared_experts is None:
            pplx_shared_output = None
            pplx_output = pplx_outputs
            assert isinstance(pplx_output, torch.Tensor)
        else:
            pplx_shared_output, pplx_output = pplx_outputs

        if shared_output is not None:
            assert pplx_shared_output is not None
            chunked_shared_output = chunk_by_rank(
                shared_output, pgi.rank, pgi.world_size
            ).to(pplx_shared_output.device)
        else:
            chunked_shared_output = None

        chunked_batch_output = chunk_by_rank(
            batched_output, pgi.rank, pgi.world_size
        ).to(pplx_output.device)

        torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)

        torch.testing.assert_close(
            pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2
        )

        if shared_experts is not None:
            assert chunked_shared_output is not None
            assert pplx_shared_output is not None
            torch.testing.assert_close(
                pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2
            )
    finally:
        if use_internode:
            nvshmem_finalize()


@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: list[int] | None,
    use_internode: bool,
):
    """MoE test with one launch per parameter combination."""
    set_random_seed(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size

    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    # Activations are always bf16; quant_dtype controls weight/activation quant.
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )

    parallel_launch(
        world_size,
        _pplx_moe,
        dp_size,
        a,
        w1,
        w2,
        score,
        topk,
        e,
        w1_s,
        w2_s,
        quant_dtype,
        per_act_token_quant,
        block_shape,
        use_internode,
    )


def _pplx_test_loop(
    pgi: ProcessGroupInfo,
    dp_size: int,
    use_internode: bool,
    use_shared_experts: bool,
    make_weights: bool,
    test_fn: Callable,
):
    """Per-process driver that runs ``test_fn`` over every parameter combination.

    Each result is printed in a pytest-like PASSED/FAILED format. Failures are
    collected, and a single RuntimeError summarising them is raised at the
    end, so one launch covers the whole sweep. Skipped combinations are not
    counted.
    """
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    # Name used in the reports, e.g. "test_pplx_moe" for _pplx_moe.
    test_name = "test_" + getattr(test_fn, "__name__", "pplx_moe").lstrip("_")

    def format_result(msg: str, ex: Exception | None = None) -> None:
        # Must be called from inside an `except` block when `ex` is given,
        # so that traceback.format_exc() sees the active exception.
        if ex is not None:
            full = str(ex).strip(" \n\t")
            summary = full[:16]
            if len(summary) < len(full):
                summary = summary + " ..."

            prefix = "E\t"
            print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
            print(f"FAILED {msg} - {summary}\n")
        else:
            print(f"PASSED {msg}")

    if use_shared_experts:
        # Note: this config is only needed for the non-naive shared experts.
        new_vllm_config = copy.deepcopy(vllm_config)
        new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
        new_vllm_config.parallel_config.enable_expert_parallel = True
        _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank)

    set_random_seed(7)
    combos = itertools.product(
        PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]
    )
    exceptions = []
    count = 0
    for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
        m, n, k = mnk

        use_fp8_w8a8 = dtype == torch.float8_e4m3fn
        quant_dtype = dtype if use_fp8_w8a8 else None

        test_desc = (
            f"{test_name}[mnk={mnk}, e={e}, topk={topk}, "
            f"dtype={dtype}, per_act_token={per_act_token_quant}, "
            f"block_shape={block_shape}, use_internode={use_internode}, "
            f"use_shared_experts={use_shared_experts}"
        )

        if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
            print(f"{test_desc} - Skip quantization test for non-quantized type.")
            continue

        if per_act_token_quant and block_shape is not None:
            print(f"{test_desc} - Skip illegal quantization combination.")
            continue

        # Only configurations that actually run count toward the totals.
        count = count + 1

        a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

        args = {}
        if make_weights:
            (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
                e,
                n,
                k,
                quant_dtype=quant_dtype,
                block_shape=block_shape,
                per_out_ch_quant=per_act_token_quant,
            )
            args["w1"] = w1
            args["w2"] = w2
            args["w1_s"] = w1_s
            args["w2_s"] = w2_s

        if use_shared_experts:
            args["shared_experts"] = make_shared_experts(
                n,
                k,
                in_dtype=a.dtype,
                quant_dtype=quant_dtype,
            )

        try:
            test_fn(
                pgi=pgi,
                dp_size=dp_size,
                a=a,
                score=score,
                topk=topk,
                num_experts=e,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                use_internode=use_internode,
                **args,
            )
            format_result(test_desc)
        except Exception as ex:
            format_result(test_desc, ex)
            exceptions.append(ex)

    if exceptions:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        )
    print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize(
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Sweep all pplx prepare/finalize configurations in one multi-process launch."""
    set_random_seed(7)
    world_size, dp_size = world_dp_size
    # One process per rank, as in test_pplx_moe and the _slow variant.
    # Launching world_size * dp_size would oversubscribe whenever dp_size > 1.
    parallel_launch(
        world_size,
        _pplx_test_loop,
        dp_size,
        use_internode,
        False,  # use_shared_experts
        False,  # make_weights
        _pplx_prepare_finalize,
    )


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.parametrize("use_shared_experts", [False, True])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe(
    world_dp_size: tuple[int, int],
    use_internode: bool,
    use_shared_experts: bool,
):
    """Sweep all pplx MoE configurations in one multi-process launch."""
    set_random_seed(7)
    world_size, dp_size = world_dp_size
    parallel_launch(
        world_size,
        _pplx_test_loop,
        dp_size,
        use_internode,
        use_shared_experts,
        True,  # make_weights
        _pplx_moe,
    )