test_pplx_moe.py 29.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_pplx_moe.py`.
"""
7

8
import copy
9
10
11
import itertools
import textwrap
import traceback
12
from collections.abc import Callable
13
14
15
16
17
18

import pytest
import torch

try:
    from pplx_kernels import AllToAll
19
20
21
22
23
24
25
    from pplx_kernels.nvshmem import (
        nvshmem_alloc_empty_unique_id,
        nvshmem_finalize,
        nvshmem_get_unique_id,
        nvshmem_init,
    )

26
27
28
29
    has_pplx = True
except ImportError:
    has_pplx = False

30
31
32
33
34
35
from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
from tests.kernels.moe.utils import (
    make_shared_experts,
    make_test_weights,
    naive_batched_moe,
)
36
from tests.kernels.quant_utils import dequant
37
from tests.kernels.utils import torch_experts
38
from vllm.config import VllmConfig, set_current_vllm_config
bnellnm's avatar
bnellnm committed
39
40
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
41
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
bnellnm's avatar
bnellnm committed
42
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
43
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
44
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
45
46
    TopKWeightAndReduceDelegate,
)
47
from vllm.utils.math_utils import round_up
48
from vllm.utils.torch_utils import set_random_seed
49
from vllm.v1.worker.workspace import init_workspace_manager
50

51
from ...utils import multi_gpu_test
bnellnm's avatar
bnellnm committed
52
from .parallel_utils import ProcessGroupInfo, parallel_launch
53

54
55
56
57
58
# Skip marker for tests that need the optional pplx_kernels package.
requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)

# (m, n, k) problem sizes for the single-process batched-experts test.
BATCHED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (33, 2048, 128),
    (64, 128, 2048),
    (222, 128, 128),
    (222, 2048, 1024),
]

# (m, n, k) problem sizes for the multi-GPU pplx tests.
PPLX_COMBOS = [
    # TODO(bnell): figure out why this fails, seems to be test problem
    # (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (4, 128, 128),
    (32, 1024, 512),
    (45, 512, 2048),
    (64, 1024, 512),
    (222, 2048, 1024),
    (256, 1408, 2048),
]

NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]

# Shared default config; tests enter it with set_current_vllm_config() and
# _pplx_test_loop deep-copies it when shared experts are enabled.
vllm_config = VllmConfig()


def torch_prepare(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
91
    max_num_tokens: int | None = None,
92
93
94
95
96
97
98
) -> tuple[torch.Tensor, torch.Tensor]:
    assert topk_ids.dim() == 2
    assert topk_ids.shape[0] == a.shape[0]

    num_tokens, hidden_dim = a.shape
    topk = topk_ids.shape[1]

99
    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
100
101
102
103
104
105

    assert tokens_per_expert.numel() == num_experts

    if max_num_tokens is None:
        max_num_tokens = int(tokens_per_expert.max().item())

106
107
108
    b_a = torch.zeros(
        (num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device
    )
109
110
111
112
113
114
115

    token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)

    for token in range(num_tokens):
        for j in range(topk):
            expert_id = topk_ids[token, j]
            idx = token_counts[expert_id]
116
            b_a[expert_id, idx : idx + 1, :] = a[token, :]
117
118
119
120
121
            token_counts[expert_id] = token_counts[expert_id] + 1

    return b_a, tokens_per_expert


122
123
124
def torch_finalize(
    b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor
) -> torch.Tensor:
    """Gather per-expert batched outputs back into token order (reference).

    Inverse of ``torch_prepare``. Walking ``topk_ids`` in (token, slot) order,
    the ``i``-th time expert ``e`` is seen, row ``i`` of ``b_out[e]`` is scaled
    by that slot's routing weight and accumulated into the token's output.

    Returns:
        A ``(num_tokens, K)`` tensor in ``b_out.dtype``, where ``K`` is the
        last dimension of ``b_out``.
    """
    num_tokens = topk_ids.shape[0]
    num_experts = b_out.shape[0]
    K = b_out.shape[-1]
    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)

    # Next unread row of each expert. It is kept as Python ints so the loop
    # does not round-trip to the device for every (token, slot) pair.
    expert_counts = [0] * num_experts
    for token, expert_ids in enumerate(topk_ids.tolist()):
        for i, expert_id in enumerate(expert_ids):
            idx = expert_counts[expert_id]
            out[token, :] = (
                out[token, :] + b_out[expert_id, idx, :] * topk_weight[token, i]
            )
            expert_counts[expert_id] = idx + 1

    return out


def torch_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference MoE built from ``torch_prepare`` / ``torch_finalize``.

    For each expert with routed tokens, computes
    ``silu_and_mul(x @ w1[e].T) @ w2[e].T`` over that expert's batch. The
    per-expert results are then recombined with the routing weights.
    """
    num_experts = w1.shape[0]
    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
    assert b_a.dim() == 3
    _, max_num_tokens, K = b_a.shape
    assert num_experts == b_a.shape[0] and w2.shape[1] == K

    out = torch.zeros(
        (num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device
    )
    # silu_and_mul writes into tmp, which is half as wide as w1's output.
    tmp = torch.empty(
        (max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device
    )

    for expert in range(num_experts):
        num = tokens_per_expert[expert]
        if num > 0:
            torch.ops._C.silu_and_mul(
                tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
            )
            out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

    return torch_finalize(out, topk_weight, topk_ids)


174
@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    workspace_init,  # pytest fixture; not referenced in the body
):
    """Check torch_batched_moe and naive_batched_moe against torch_experts."""
    set_random_seed(7)

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        baseline_output = torch_experts(
            a, w1, w2, topk_weight, topk_ids
        )  # only for baseline
        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
        batched_output = naive_batched_moe(
            a, w1, w2, topk_weight, topk_ids
        )  # pick torch_experts or this

    torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)
206
207


208
209
210
211
212
213
214
215
216
def create_pplx_prepare_finalize(
    num_tokens: int,
    hidden_dim: int,
    topk: int,
    num_experts: int,
    rank: int,
    dp_size: int,
    world_size: int,
    in_dtype: torch.dtype,
    quant_dtype: torch.dtype | None,
    block_shape: list[int] | None,
    per_act_token_quant: bool,
    group_name: str | None,
):
    """Build a ``PplxPrepareAndFinalize`` and the ``AllToAll`` it wraps.

    ``group_name=None`` selects ``AllToAll.internode``. Otherwise
    ``AllToAll.intranode`` is created on the named process group.

    Returns:
        ``(prepare_finalize, ata)``. Callers in this file release ``ata``
        with ``ata.destroy()`` once they are done.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
        pplx_hidden_dim_scale_bytes,
    )

    # rank_chunk hands the remainder to the lowest ranks, so rank 0's chunk
    # is the largest one and bounds every rank's token count.
    max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
    num_local_experts = rank_chunk(num_experts, 0, world_size)

    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
        max_num_tokens,
        hidden_dim,
        in_dtype,
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim_bytes,
        hidden_dim_scale_bytes=scale_bytes,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens=max_num_tokens,
        num_local_experts=num_local_experts,
        num_dispatchers=world_size // dp_size,
    )

    return prepare_finalize, ata


267
268
269
270
271
272
273
def rank_chunk(num: int, r: int, w: int) -> int:
    """Return how many of ``num`` items rank ``r`` of ``w`` ranks owns.

    Items are split as evenly as possible. The ``num % w`` leftover items go
    one each to the lowest-numbered ranks.
    """
    base, leftover = divmod(num, w)
    return base + 1 if r < leftover else base


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    """Return rank ``r``'s contiguous slice of ``t`` along dim 0.

    The ``w`` slices are disjoint and together cover every row. Each rank
    gets ``len // w`` rows, and the first ``len % w`` ranks get one extra,
    matching ``rank_chunk``.
    """
    base, leftover = divmod(t.shape[0], w)
    # Each lower rank r' < r owns `base` rows, plus one if r' < leftover.
    # The old `r * chunk` offset overlapped or dropped rows when uneven.
    start = r * base + min(r, leftover)
    size = base + (1 if r < leftover else 0)
    return t[start : start + size]
275
276


277
def maybe_chunk_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None:
278
279
280
281
282
283
    if t is not None:
        return chunk_by_rank(t, r, w)
    else:
        return t


284
def chunk_scales_by_rank(t: torch.Tensor | None, r: int, w: int) -> torch.Tensor | None:
285
286
    if t is not None and t.numel() > 1:
        chunk = rank_chunk(t.shape[0], r, w)
287
        return t[(r * chunk) : (r + 1) * chunk]
288
289
290
291
    else:
        return t


292
def chunk_scales(t: torch.Tensor | None, start: int, end: int) -> torch.Tensor | None:
293
294
295
296
297
298
299
300
301
302
    if t is not None and t.numel() > 1:
        return t[start:end]
    else:
        return t


def dummy_work(a: torch.Tensor) -> torch.Tensor:
    """Stand-in expert computation: scale every element by 1.1."""
    return torch.mul(a, 1.1)


303
304
305
306
307
308
309
def pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    quant_dtype: torch.dtype | None,
    block_shape: list[int] | None,
    per_act_token_quant: bool,
    group_name: str | None,
) -> torch.Tensor:
    """Round-trip this rank's tokens through pplx prepare/finalize.

    The dispatched activations are dequantized and passed through
    ``dummy_work`` in place of a real expert, then combined back with
    ``finalize``.

    Returns:
        The finalized output for this rank's chunk of ``a``.
    """
    assert torch.cuda.current_device() == pgi.local_rank

    topk = topk_ids.shape[1]
    num_tokens, hidden_dim = a.shape
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    # Release the all-to-all handle even if prepare/finalize raises; callers
    # catch the failure and move on to the next combination.
    try:
        assert a.shape[0] == topk_ids.shape[0]

        a_chunk = chunk_by_rank(a, rank, world_size).to(device)
        chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
        chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

        assert a_chunk.shape[0] == chunk_topk_ids.shape[0]

        # NaN-filled so any row that finalize does not write stays NaN.
        out = torch.full(
            a_chunk.shape,
            torch.nan,
            dtype=a.dtype,
            device=device,
        )

        # Quantized without per-token or block scales: pass unit scales.
        if quant_dtype is not None and not per_act_token_quant and block_shape is None:
            a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
            a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        else:
            a1_scale = None
            a2_scale = None

        # NOTE(review): the positional None/False are presumably expert_map and
        # apply_router_weight_on_input - confirm against prepare()'s signature.
        b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
            a_chunk,
            chunk_topk_weight,
            chunk_topk_ids,
            num_experts,
            None,
            False,
            FusedMoEQuantConfig.make(
                quant_dtype,
                per_act_token_quant=per_act_token_quant,
                per_out_ch_quant=False,
                block_shape=block_shape,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
            ),
        )

        b_a = dummy_work(
            dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype)
        )

        prepare_finalize.finalize(
            out,
            b_a,
            chunk_topk_weight,
            chunk_topk_ids,
            False,
            weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
        )

        torch.cuda.synchronize()
    finally:
        ata.destroy()

    return out[: a_chunk.shape[0]]


def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    quant_dtype: torch.dtype | None,
    block_shape: list[int] | None,
    per_act_token_quant: bool,
    use_internode: bool,
):
    """Per-rank worker: compare pplx prepare/finalize with a torch reference.

    Sets up either nvshmem (internode) or a gloo CPU group (intranode),
    runs ``pplx_prepare_finalize`` and checks it against this rank's chunk
    of the weighted ``dummy_work`` reference.
    """
    try:
        if use_internode:
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        m, k = a.shape

        # Reference: each token contributes dummy_work(token) once per chosen
        # expert, scaled by that slot's routing weight, summed over slots.
        a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)
        torch_output = (
            a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype)
        ).sum(dim=1)

        pplx_output = pplx_prepare_finalize(
            pgi,
            dp_size,
            a,
            topk_weight,
            topk_ids,
            num_experts,
            quant_dtype,
            block_shape,
            per_act_token_quant,
            group_name,
        )

        torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
            pgi.device
        )

        torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
    finally:
        if use_internode:
            nvshmem_finalize()
456
457


458
@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: list[int] | None,
    use_internode: bool,
):
    """One pytest case per combo for the pplx prepare/finalize round trip."""
    # fp8 activations are generated in bf16 and quantized inside prepare().
    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        act_dtype = dtype
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    set_random_seed(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"

    a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
    score = torch.randn((m, e), device=device, dtype=act_dtype)

    parallel_launch(
        world_size,
        _pplx_prepare_finalize,
        dp_size,
        a,
        score,
        topk,
        e,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        use_internode,
    )
515
516
517


def pplx_moe(
    group_name: str | None,
    rank: int,
    world_size: int,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: torch.Tensor | None = None,
    w2_scale: torch.Tensor | None = None,
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    use_compile: bool = False,
    use_cudagraphs: bool = True,
    shared_experts: torch.nn.Module | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Run this rank's chunk of a MoE layer through the pplx modular kernel.

    Activations, routing tensors, expert weights and their scales are chunked
    by rank. The experts are ``BatchedTritonExperts`` behind a
    ``PplxPrepareAndFinalize``. With ``use_cudagraphs`` the call is also
    captured in a CUDA graph and replayed, and the replayed result is returned.

    Returns:
        The kernel output for this rank's tokens. When ``shared_experts`` is
        given, ``_pplx_moe`` unpacks it as ``(shared_output, routed_output)``.
    """
    num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
    topk = topk_ids.shape[1]
    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 16)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    # Release the all-to-all handle even if the kernel or graph capture fails;
    # the test loop catches the failure and continues with the next combo.
    try:
        topk_ids = topk_ids.to(dtype=torch.uint32)

        # Note: workers with the same dp_rank must use the exact same inputs.
        a_chunk = chunk_by_rank(a, rank, world_size)
        chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
        chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size)

        # Chunking weights like this only works for batched format
        w1_chunk = chunk_by_rank(w1, rank, world_size)
        w2_chunk = chunk_by_rank(w2, rank, world_size)
        w1_scale_chunk = maybe_chunk_by_rank(w1_scale, rank, world_size)
        w2_scale_chunk = maybe_chunk_by_rank(w2_scale, rank, world_size)
        a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
        a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)

        quant_config = FusedMoEQuantConfig.make(
            quant_dtype,
            block_shape=block_shape,
            per_act_token_quant=per_act_token_quant,
            w1_scale=w1_scale_chunk,
            w2_scale=w2_scale_chunk,
            a1_scale=a1_scale_chunk,
            a2_scale=a2_scale_chunk,
        )

        experts = BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            quant_config=quant_config,
        )

        fused_experts = FusedMoEModularKernel(
            prepare_finalize,
            experts,
            shared_experts,
        )

        # Note: for now use_compile will error out if the problem size is
        # large enough to trigger chunking. I'm leaving the flag and
        # setup code in case we are able to revisit this later.
        if use_compile:
            _fused_experts = torch.compile(
                fused_experts, backend="inductor", fullgraph=True
            )
            torch._dynamo.mark_dynamic(a_chunk, 0)
            torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
            torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
        else:
            _fused_experts = fused_experts

        out = _fused_experts(
            a_chunk,
            w1_chunk,
            w2_chunk,
            chunk_topk_weight,
            chunk_topk_ids,
            global_num_experts=num_experts,
        )

        if use_cudagraphs:
            # The eager result is zeroed, then replaced by the output of the
            # captured graph after replay.
            if isinstance(out, tuple):
                out[0].fill_(0)
                out[1].fill_(0)
            else:
                out.fill_(0)
            stream = torch.cuda.Stream()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=stream):
                out = _fused_experts(
                    a_chunk,
                    w1_chunk,
                    w2_chunk,
                    chunk_topk_weight,
                    chunk_topk_ids,
                    global_num_experts=num_experts,
                )

            torch.cuda.synchronize()
            graph.replay()

        torch.cuda.synchronize()
    finally:
        ata.destroy()

    return out


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    w1_s: torch.Tensor | None = None,
    w2_s: torch.Tensor | None = None,
    quant_dtype: torch.dtype | None = None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
    use_internode: bool = False,
    shared_experts: torch.nn.Module | None = None,
):
    """Per-rank worker: check ``pplx_moe`` against torch and naive references.

    ``torch_experts`` and ``naive_batched_moe`` run on the full input and are
    compared with each other. ``pplx_moe`` output (and the shared-expert
    output, if any) is compared against this rank's chunk of the references.
    """
    try:
        if use_internode:
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        m, k = a.shape
        e, _, n = w2.shape

        moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

        # Use the same per-process device as pplx_prepare_finalize. pgi.rank
        # is the global rank and is not a valid GPU index across nodes.
        device = pgi.device
        rank = pgi.rank
        world_size = pgi.world_size

        a = a.to(device)
        score = score.to(device)
        w1 = w1.to(device)
        w2 = w2.to(device)
        w1_s = w1_s.to(device) if w1_s is not None else None
        w2_s = w2_s.to(device) if w2_s is not None else None

        # Quantized without per-token or block scales: pass unit scales.
        if quant_dtype is not None and not per_act_token_quant and block_shape is None:
            a1_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
            a2_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
        else:
            a1_scale = None
            a2_scale = None

        with set_current_vllm_config(vllm_config), override_config(moe_config):
            topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

            shared_output = shared_experts(a) if shared_experts is not None else None

            torch_output = torch_experts(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            batched_output = naive_batched_moe(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            pplx_outputs = pplx_moe(
                group_name,
                rank,
                world_size,
                dp_size,
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                shared_experts=shared_experts,
            )

        if shared_experts is None:
            pplx_shared_output = None
            pplx_output = pplx_outputs
            assert isinstance(pplx_output, torch.Tensor)
        else:
            pplx_shared_output, pplx_output = pplx_outputs

        if shared_output is not None:
            assert pplx_shared_output is not None
            chunked_shared_output = chunk_by_rank(
                shared_output, pgi.rank, pgi.world_size
            ).to(pplx_shared_output.device)
        else:
            chunked_shared_output = None

        chunked_batch_output = chunk_by_rank(
            batched_output, pgi.rank, pgi.world_size
        ).to(pplx_output.device)

        torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)

        torch.testing.assert_close(
            pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2
        )

        if shared_experts is not None:
            assert chunked_shared_output is not None
            assert pplx_shared_output is not None
            torch.testing.assert_close(
                pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2
            )

    finally:
        if use_internode:
            nvshmem_finalize()


@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: list[int] | None,
    use_internode: bool,
):
    """One pytest case per combo for the full pplx MoE layer."""
    set_random_seed(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size

    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )

    parallel_launch(
        world_size,
        _pplx_moe,
        dp_size,
        a,
        w1,
        w2,
        score,
        topk,
        e,
        w1_s,
        w2_s,
        quant_dtype,
        per_act_token_quant,
        block_shape,
        use_internode,
    )
858
859


860
861
862
863
864
865
866
867
def _pplx_test_loop(
    pgi: ProcessGroupInfo,
    dp_size: int,
    use_internode: bool,
    use_shared_experts: bool,
    make_weights: bool,
    test_fn: Callable,
):
    """Run ``test_fn`` over every test combination inside one worker process.

    Each combination is reported as PASSED/FAILED on stdout. Failures are
    collected, and a ``RuntimeError`` is raised at the end if any combination
    failed. Skipped (invalid) combinations are not counted as run.
    """
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    def format_result(msg, ex=None):
        if ex is not None:
            # Strip first so the " ..." marker means "truncated", not just
            # "had surrounding whitespace".
            x = str(ex).strip(" \n\t")
            newx = x[:16]
            if len(newx) < len(x):
                newx = newx + " ..."

            prefix = "E\t"
            print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
            print(f"FAILED {msg} - {newx}\n")
        else:
            print(f"PASSED {msg}")

    if use_shared_experts:
        # Note: this config is only needed for the non-naive shared experts.
        new_vllm_config = copy.deepcopy(vllm_config)
        new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
        new_vllm_config.parallel_config.enable_expert_parallel = True
        _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank)

    set_random_seed(7)

    combos = itertools.product(
        PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]
    )
    exceptions = []
    count = 0  # combinations actually executed (skips excluded)
    for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
        m, n, k = mnk

        if dtype == torch.float8_e4m3fn:
            use_fp8_w8a8 = True
            quant_dtype = dtype
        else:
            use_fp8_w8a8 = False
            quant_dtype = None

        test_desc = (
            f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
            f"dtype={dtype}, per_act_token={per_act_token_quant}, "
            f"block_shape={block_shape}, use_internode={use_internode}, "
            f"use_shared_experts={use_shared_experts}"
        )

        if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
            print(f"{test_desc} - Skip quantization test for non-quantized type.")
            continue

        if per_act_token_quant and block_shape is not None:
            print(f"{test_desc} - Skip illegal quantization combination.")
            continue

        count = count + 1

        a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

        args = dict()
        if make_weights:
            (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
                e,
                n,
                k,
                quant_dtype=quant_dtype,
                block_shape=block_shape,
                per_out_ch_quant=per_act_token_quant,
            )
            args["w1"] = w1
            args["w2"] = w2
            args["w1_s"] = w1_s
            args["w2_s"] = w2_s

        if use_shared_experts:
            args["shared_experts"] = make_shared_experts(
                n,
                k,
                in_dtype=a.dtype,
                quant_dtype=quant_dtype,
            )

        try:
            test_fn(
                pgi=pgi,
                dp_size=dp_size,
                a=a,
                score=score,
                topk=topk,
                num_experts=e,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                use_internode=use_internode,
                **args,
            )
            format_result(test_desc)
        except Exception as ex:
            format_result(test_desc, ex)
            exceptions.append(ex)

    if len(exceptions) > 0:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        )
    else:
        print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")
975
976
977
978
979


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize(
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Run every prepare/finalize combo inside one set of worker processes."""
    set_random_seed(7)
    world_size, dp_size = world_dp_size
    # world_size is the total process count (pplx uses world_size // dp_size
    # dispatchers), matching the other parallel_launch call sites.
    parallel_launch(
        world_size,
        _pplx_test_loop,
        dp_size,
        use_internode,
        False,
        False,
        _pplx_prepare_finalize,
    )
996
997
998
999


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.parametrize("use_shared_experts", [False, True])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe(
    world_dp_size: tuple[int, int],
    use_internode: bool,
    use_shared_experts: bool,
):
    """Run every pplx MoE combo inside one set of worker processes."""
    set_random_seed(7)
    world_size, dp_size = world_dp_size
    parallel_launch(
        world_size,
        _pplx_test_loop,
        dp_size,
        use_internode,
        use_shared_experts,
        True,
        _pplx_moe,
    )