# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MOE layers.

Run `pytest tests/kernels/moe/test_pplx_moe.py`.
"""

import copy
import itertools
import textwrap
import traceback
from typing import Callable, Optional, Union

import pytest
import torch

try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (
        nvshmem_alloc_empty_unique_id,
        nvshmem_finalize,
        nvshmem_get_unique_id,
        nvshmem_init,
    )

    has_pplx = True
except ImportError:
    has_pplx = False

from tests.kernels.moe.modular_kernel_tools.parallel_utils import _set_vllm_config
from tests.kernels.moe.utils import (
    make_shared_experts,
    make_test_weights,
    naive_batched_moe,
)
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
)
from vllm.platforms import current_platform
from vllm.utils import round_up

from ...utils import multi_gpu_test
from .parallel_utils import ProcessGroupInfo, parallel_launch

# Skip marker for tests that need the optional pplx_kernels package.
requires_pplx = pytest.mark.skipif(
    not has_pplx,
    reason="Requires PPLX kernels",
)

# (m, n, k) = (num tokens, intermediate size, hidden size): the tests build
# a as (m, k), w1 as (e, 2 * n, k) and w2 as (e, k, n).
BATCHED_MOE_MNK_FACTORS = [
    (1, 128, 128),
    (33, 2048, 128),
    (64, 128, 2048),
    (222, 128, 128),
    (222, 2048, 1024),
]

# (m, n, k) problem sizes for the multi-GPU PPLX tests.
PPLX_COMBOS = [
    # TODO(bnell): figure out why this fails, seems to be test problem
    # (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (4, 128, 128),
    (32, 1024, 512),
    (45, 512, 2048),
    (64, 1024, 512),
    (222, 2048, 1024),
    (256, 1408, 2048),
]

NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
# torch.float8_e4m3fn means "bfloat16 activations quantized to fp8" in the
# tests below; torch.bfloat16 means unquantized.
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]

# Shared config installed via set_current_vllm_config() around MoE calls.
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def torch_prepare(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    max_num_tokens: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scatter tokens into a per-expert batched layout (reference impl).

    Every (token, expert) assignment in ``topk_ids`` copies row ``a[token]``
    into the next free slot of ``b_a[expert]``. Slots are filled in token
    order, and in top-k order within a token.

    Args:
        a: Activations of shape ``(num_tokens, hidden_dim)``.
        topk_ids: Non-negative integer expert ids, ``(num_tokens, topk)``.
        num_experts: Total number of experts.
        max_num_tokens: Slots per expert. Defaults to the largest number of
            tokens routed to a single expert; must not be smaller than that.

    Returns:
        ``(b_a, tokens_per_expert)``. ``b_a`` has shape
        ``(num_experts, max_num_tokens, hidden_dim)`` with unused slots left
        at zero; ``tokens_per_expert[e]`` is the number of tokens routed to
        expert ``e``.
    """
    assert topk_ids.dim() == 2
    assert topk_ids.shape[0] == a.shape[0]

    hidden_dim = a.shape[1]

    tokens_per_expert = torch.bincount(topk_ids.reshape(-1), minlength=num_experts)
    assert tokens_per_expert.numel() == num_experts

    busiest = (
        int(tokens_per_expert.max().item()) if tokens_per_expert.numel() > 0 else 0
    )
    if max_num_tokens is None:
        max_num_tokens = busiest
    # Out-of-range slot slices are empty, so a too-small buffer would silently
    # drop tokens instead of failing; reject it explicitly.
    assert max_num_tokens >= busiest, (
        f"max_num_tokens={max_num_tokens} < {busiest} tokens routed to one expert"
    )

    b_a = torch.zeros(
        (num_experts, max_num_tokens, hidden_dim), dtype=a.dtype, device=a.device
    )

    # Next free slot per expert, tracked as plain Python ints on the host.
    next_slot = [0] * num_experts
    for token, experts in enumerate(topk_ids.tolist()):
        for expert_id in experts:
            b_a[expert_id, next_slot[expert_id]] = a[token]
            next_slot[expert_id] += 1

    return b_a, tokens_per_expert


def torch_finalize(
    b_out: torch.Tensor, topk_weight: torch.Tensor, topk_ids: torch.Tensor
) -> torch.Tensor:
    """Gather per-expert outputs back into token order, applying top-k weights.

    This walks the same slot order that ``torch_prepare`` fills: for each
    token in order, and each of its top-k experts in order, it takes that
    expert's next unread slot.

    Args:
        b_out: Batched expert outputs, ``(num_experts, max_num_tokens, K)``.
        topk_weight: Routing weights, ``(num_tokens, topk)``.
        topk_ids: Expert ids, ``(num_tokens, topk)``.

    Returns:
        ``(num_tokens, K)`` tensor (dtype of ``b_out``) where row ``t`` is
        ``sum_i topk_weight[t, i] * b_out[topk_ids[t, i], slot]``.
    """
    num_tokens = topk_ids.shape[0]
    num_experts = b_out.shape[0]
    hidden_dim = b_out.shape[-1]
    out = torch.zeros(
        (num_tokens, hidden_dim), dtype=b_out.dtype, device=b_out.device
    )

    # Next unread slot per expert, tracked as plain Python ints on the host.
    next_slot = [0] * num_experts
    for token, experts in enumerate(topk_ids.tolist()):
        for i, expert_id in enumerate(experts):
            out[token] += b_out[expert_id, next_slot[expert_id]] * topk_weight[token, i]
            next_slot[expert_id] += 1

    return out


def torch_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference MoE forward using the batched (per-expert) layout.

    Tokens are scattered with ``torch_prepare``. Each expert then applies
    ``silu_and_mul(x @ w1[e].T) @ w2[e].T`` to its tokens, and the results
    are weighted and gathered back with ``torch_finalize``.

    Weights are expected as in ``test_fused_moe_batched_experts``:
    ``w1`` is ``(num_experts, 2 * N, K)`` and ``w2`` is ``(num_experts, K, N)``;
    the intermediate buffer holds ``w1.shape[1] // 2`` columns.
    """
    num_experts = w1.shape[0]
    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
    assert b_a.dim() == 3
    _, max_num_tokens, K = b_a.shape
    assert num_experts == b_a.shape[0] and w2.shape[1] == K

    out = torch.zeros(
        (num_experts, max_num_tokens, K), dtype=b_a.dtype, device=b_a.device
    )
    tmp = torch.empty(
        (max_num_tokens, w1.shape[1] // 2), dtype=b_a.dtype, device=b_a.device
    )

    for expert in range(num_experts):
        num = tokens_per_expert[expert]
        if num > 0:
            # silu_and_mul writes its result into tmp[:num] (first argument).
            torch.ops._C.silu_and_mul(
                tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
            )
            out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

    return torch_finalize(out, topk_weight, topk_ids)


@pytest.mark.parametrize("m,n,k", BATCHED_MOE_MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    """Check both batched MoE references against ``torch_experts``.

    ``torch_experts`` is the baseline; ``torch_batched_moe`` and
    ``naive_batched_moe`` must each match it within ``atol=2e-2``.
    """
    current_platform.seed_everything(7)

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)

    torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)
def create_pplx_prepare_finalize(
    num_tokens: int,
    hidden_dim: int,
    topk: int,
    num_experts: int,
    rank: int,
    dp_size: int,
    world_size: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    group_name: Optional[str],
):
    """Build a ``PplxPrepareAndFinalize`` together with its ``AllToAll``.

    Buffers are sized for rank 0's token chunk, which ``rank_chunk`` makes the
    largest. ``group_name=None`` selects ``AllToAll.internode``; otherwise
    ``AllToAll.intranode`` is created for that process group.

    Returns:
        ``(prepare_finalize, ata)``. The callers in this file call
        ``ata.destroy()`` when they are done with it.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize,
        pplx_hidden_dim_scale_bytes,
    )

    # Keep at least one slot so buffers are never zero-sized.
    max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
    num_local_experts = rank_chunk(num_experts, 0, world_size)

    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
        max_num_tokens,
        hidden_dim,
        in_dtype,
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    args = dict(
        max_num_tokens=max_num_tokens,
        num_experts=num_experts,
        experts_per_token=topk,
        rank=rank,
        world_size=world_size,
        dp_size=dp_size,
        hidden_dim=hidden_dim,
        hidden_dim_bytes=hidden_dim_bytes,
        hidden_dim_scale_bytes=scale_bytes,
    )

    if group_name is None:
        ata = AllToAll.internode(**args)
    else:
        args["group_name"] = group_name
        ata = AllToAll.intranode(**args)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens=max_num_tokens,
        num_local_experts=num_local_experts,
        num_dispatchers=world_size // dp_size,
    )

    return prepare_finalize, ata


def rank_chunk(num: int, r: int, w: int) -> int:
    """Return how many of ``num`` items rank ``r`` of ``w`` ranks owns.

    Items are split as evenly as possible: every rank gets ``num // w`` and
    the first ``num % w`` ranks get one extra.
    """
    base, extra = divmod(num, w)
    if r < extra:
        return base + 1
    return base


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    """Return rank ``r``'s contiguous slice of ``t`` along dim 0.

    Slice sizes match ``rank_chunk``: each rank gets ``n // w`` rows and the
    first ``n % w`` ranks get one extra. The slices of ranks ``0..w-1`` are
    disjoint and together cover every row exactly once.
    """
    base, extra = divmod(t.shape[0], w)
    size = base + (1 if r < extra else 0)
    # Every earlier rank owns `base` rows, plus one more for each earlier rank
    # below `extra`. (Using `r * size` here made slices overlap and skip rows
    # whenever the size was not divisible by `w`.)
    start = r * base + min(r, extra)
    return t[start : start + size]


def maybe_chunk_by_rank(
    t: Optional[torch.Tensor], r: int, w: int
) -> Optional[torch.Tensor]:
    """Like ``chunk_by_rank`` but passes ``None`` through unchanged."""
    if t is None:
        return None
    return chunk_by_rank(t, r, w)


def chunk_scales_by_rank(
    t: Optional[torch.Tensor], r: int, w: int
) -> Optional[torch.Tensor]:
    """Return rank ``r``'s rows of a per-token scale tensor.

    ``None`` and tensors with at most one element (a single shared scale) are
    returned unchanged. Anything else is split along dim 0 into contiguous,
    disjoint per-rank slices whose sizes match ``rank_chunk``.
    """
    if t is None or t.numel() <= 1:
        return t
    base, extra = divmod(t.shape[0], w)
    size = base + (1 if r < extra else 0)
    # Rows owned by all earlier ranks (the first `extra` ranks own one more).
    start = r * base + min(r, extra)
    return t[start : start + size]


def chunk_scales(
    t: Optional[torch.Tensor], start: int, end: int
) -> Optional[torch.Tensor]:
    """Slice rows ``[start, end)`` of a per-token scale tensor.

    ``None`` and tensors with at most one element (a single shared scale) are
    returned unchanged.
    """
    if t is None or t.numel() <= 1:
        return t
    return t[start:end]


def dummy_work(a: torch.Tensor) -> torch.Tensor:
    """Stand-in for expert computation: return ``a`` scaled by 1.1."""
    return torch.mul(a, 1.1)


def pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    group_name: Optional[str],
) -> torch.Tensor:
    """Round-trip this rank's chunk of ``a`` through PPLX prepare/finalize.

    The dispatched (possibly quantized) activations are dequantized and passed
    through ``dummy_work`` in place of real expert computation. ``finalize``
    then combines them back into token order.

    Returns:
        The finalized output rows for this rank's chunk of tokens.
    """
    assert torch.cuda.current_device() == pgi.local_rank

    topk = topk_ids.shape[1]
    num_tokens, hidden_dim = a.shape
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    assert a.shape[0] == topk_ids.shape[0]

    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

    assert a_chunk.shape[0] == chunk_topk_ids.shape[0]

    # NaN-filled, presumably so any row finalize fails to write makes the
    # caller's assert_close fail.
    out = torch.full(
        a_chunk.shape,
        torch.nan,
        dtype=a.dtype,
        device=device,
    )

    # Only quantization without per-token scales and without block_shape gets
    # a static unit activation scale; all other cases pass None.
    if quant_dtype is not None and not per_act_token_quant and block_shape is None:
        a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    else:
        a1_scale = None
        a2_scale = None

    b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
        a_chunk,
        chunk_topk_weight,
        chunk_topk_ids,
        num_experts,
        None,
        False,
        FusedMoEQuantConfig.make(
            quant_dtype,
            per_act_token_quant=per_act_token_quant,
            per_out_ch_quant=False,
            block_shape=block_shape,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
        ),
    )

    b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))

    prepare_finalize.finalize(
        out,
        b_a,
        chunk_topk_weight,
        chunk_topk_ids,
        False,
        weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
    )

    torch.cuda.synchronize()

    ata.destroy()

    return out[: a_chunk.shape[0]]


def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    use_internode: bool,
):
    """Per-process body of the prepare/finalize test.

    Sets up communication: NVSHMEM when ``use_internode``, otherwise a gloo
    group whose name selects the intranode AllToAll. It then computes the
    torch reference ``sum_j topk_weight[t, j] * dummy_work(a[t])`` and
    compares ``pplx_prepare_finalize`` against this rank's chunk of it.
    NVSHMEM is finalized even if the check fails.
    """
    try:
        if use_internode:
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        m, k = a.shape

        # One copy of dummy_work(a[t]) per top-k slot, weighted and summed.
        a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)

        torch_output = (
            a_rep.view(m, topk, k) * topk_weight.view(m, topk, 1).to(a_rep.dtype)
        ).sum(dim=1)

        pplx_output = pplx_prepare_finalize(
            pgi,
            dp_size,
            a,
            topk_weight,
            topk_ids,
            num_experts,
            quant_dtype,
            block_shape,
            per_act_token_quant,
            group_name,
        )

        torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(
            pgi.device
        )

        torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
    finally:
        if use_internode:
            nvshmem_finalize()


@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    use_internode: bool,
):
    """Prepare/finalize test with one process launch per parameter set.

    ``torch.float8_e4m3fn`` means bfloat16 activations quantized to fp8.
    Quantization-only options are skipped for unquantized dtypes, and
    per-token scales combined with ``block_shape`` are skipped as illegal.
    """
    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        act_dtype = dtype
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"

    a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
    score = torch.randn((m, e), device=device, dtype=act_dtype)

    parallel_launch(
        world_size,
        _pplx_prepare_finalize,
        dp_size,
        a,
        score,
        topk,
        e,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        use_internode,
    )


def pplx_moe(
    group_name: Optional[str],
    rank: int,
    world_size: int,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant=False,
    block_shape: Optional[list[int]] = None,
    use_compile: bool = False,
    use_cudagraphs: bool = True,
    shared_experts: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Run PPLX dispatch + ``BatchedTritonExperts`` on this rank's chunk.

    Tokens, routing and weights are chunked by rank. The chunked weights are
    the experts this rank owns, which only works for the batched format. The
    modular kernel is optionally ``torch.compile``d, and when
    ``use_cudagraphs`` is set the call is captured in a CUDA graph and
    replayed.

    Returns:
        The MoE output for this rank's tokens or, when ``shared_experts`` is
        given, whatever tuple the modular kernel returns (``_pplx_moe``
        unpacks it as ``(shared_output, fused_output)``).
    """
    num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
    topk = topk_ids.shape[1]
    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 16)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    topk_ids = topk_ids.to(dtype=torch.uint32)

    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size)

    # Chunking weights like this only works for batched format
    w1_chunk = chunk_by_rank(w1, rank, world_size)
    w2_chunk = chunk_by_rank(w2, rank, world_size)
    w1_scale_chunk = maybe_chunk_by_rank(w1_scale, rank, world_size)
    w2_scale_chunk = maybe_chunk_by_rank(w2_scale, rank, world_size)
    a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
    a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)

    quant_config = FusedMoEQuantConfig.make(
        quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
        w1_scale=w1_scale_chunk,
        w2_scale=w2_scale_chunk,
        a1_scale=a1_scale_chunk,
        a2_scale=a2_scale_chunk,
    )

    experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=quant_config,
    )

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
        shared_experts,
    )

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    if use_compile:
        _fused_experts = torch.compile(
            fused_experts, backend="inductor", fullgraph=True
        )
        torch._dynamo.mark_dynamic(a_chunk, 0)
        torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
        torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
    else:
        _fused_experts = fused_experts

    out = _fused_experts(
        a_chunk,
        w1_chunk,
        w2_chunk,
        chunk_topk_weight,
        chunk_topk_ids,
        global_num_experts=num_experts,
    )

    if use_cudagraphs:
        # Zero the eager result, then capture the same call in a CUDA graph
        # and replay it, so the returned tensor(s) come from the graph replay.
        if isinstance(out, tuple):
            out[0].fill_(0)
            out[1].fill_(0)
        else:
            out.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            out = _fused_experts(
                a_chunk,
                w1_chunk,
                w2_chunk,
                chunk_topk_weight,
                chunk_topk_ids,
                global_num_experts=num_experts,
            )

        torch.cuda.synchronize()
        graph.replay()

    torch.cuda.synchronize()

    ata.destroy()

    return out


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    w1_s: Optional[torch.Tensor] = None,
    w2_s: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    use_internode: bool = False,
    shared_experts: Optional[torch.nn.Module] = None,
):
    """Per-process body of the PPLX MoE test.

    For one routing it computes three outputs: ``torch_experts`` (reference),
    ``naive_batched_moe`` and ``pplx_moe``. It checks that the batched output
    matches the reference, and that this rank's PPLX output matches its chunk
    of the batched output. When ``shared_experts`` is given, the PPLX
    shared-expert output is also checked against this rank's chunk of
    ``shared_experts(a)``. NVSHMEM is finalized even if a check fails.
    """
    try:
        if use_internode:
            uid = (
                nvshmem_get_unique_id()
                if pgi.rank == 0
                else nvshmem_alloc_empty_unique_id()
            )
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
            group_name = cpu_group.group_name

        m, k = a.shape
        e, _, n = w2.shape

        moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

        # Use this process's assigned device (as pplx_prepare_finalize does).
        # Indexing by the global rank is only the local device on one node.
        device = pgi.device
        rank = pgi.rank
        world_size = pgi.world_size

        a = a.to(device)
        w1 = w1.to(device)
        w2 = w2.to(device)
        w1_s = w1_s.to(device) if w1_s is not None else None
        w2_s = w2_s.to(device) if w2_s is not None else None

        if quant_dtype is not None and not per_act_token_quant and block_shape is None:
            a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
            a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        else:
            a1_scale = None
            a2_scale = None

        with set_current_vllm_config(vllm_config), override_config(moe_config):
            topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

            if shared_experts is not None:
                shared_output = shared_experts(a)
            else:
                shared_output = None

            torch_output = torch_experts(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            batched_output = naive_batched_moe(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            pplx_outputs = pplx_moe(
                group_name,
                rank,
                world_size,
                dp_size,
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                shared_experts=shared_experts,
            )

        if shared_experts is None:
            pplx_shared_output = None
            pplx_output = pplx_outputs
            assert isinstance(pplx_output, torch.Tensor)
        else:
            pplx_shared_output, pplx_output = pplx_outputs

        if shared_output is not None:
            assert pplx_shared_output is not None
            chunked_shared_output = chunk_by_rank(
                shared_output, pgi.rank, pgi.world_size
            ).to(pplx_shared_output.device)
        else:
            chunked_shared_output = None

        chunked_batch_output = chunk_by_rank(
            batched_output, pgi.rank, pgi.world_size
        ).to(pplx_output.device)

        torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)

        torch.testing.assert_close(
            pplx_output, chunked_batch_output, atol=3e-2, rtol=3e-2
        )

        if shared_experts is not None:
            assert chunked_shared_output is not None
            assert pplx_shared_output is not None
            torch.testing.assert_close(
                pplx_shared_output, chunked_shared_output, atol=3e-2, rtol=3e-2
            )

    finally:
        if use_internode:
            nvshmem_finalize()


@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    use_internode: bool,
):
    """PPLX MoE test with one process launch per parameter set.

    Activations are always bfloat16; ``torch.float8_e4m3fn`` selects fp8
    quantization for the weights built by ``make_test_weights``.
    """
    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size

    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        e,
        n,
        k,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_out_ch_quant=per_act_token_quant,
    )

    parallel_launch(
        world_size,
        _pplx_moe,
        dp_size,
        a,
        w1,
        w2,
        score,
        topk,
        e,
        w1_s,
        w2_s,
        quant_dtype,
        per_act_token_quant,
        block_shape,
        use_internode,
    )


def _pplx_test_loop(
    pgi: ProcessGroupInfo,
    dp_size: int,
    use_internode: bool,
    use_shared_experts: bool,
    make_weights: bool,
    test_fn: Callable,
):
    """Run ``test_fn`` over every parameter combination in one child process.

    Invalid combinations are printed as skipped and not counted. Each executed
    combination prints PASSED/FAILED, and failures are collected so that the
    remaining combinations still run.

    Raises:
        RuntimeError: if any executed combination raised.
    """

    def format_result(msg, ex=None):
        if ex is not None:
            x = str(ex).strip(" \n\t")
            newx = x[:16]
            # Only mark the message as truncated if it actually was.
            if len(newx) < len(x):
                newx = newx + " ..."

            prefix = "E\t"
            print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
            print(f"FAILED {msg} - {newx}\n")
        else:
            print(f"PASSED {msg}")

    if use_shared_experts:
        # Note: this config is only needed for the non-naive shared experts.
        new_vllm_config = copy.deepcopy(vllm_config)
        new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
        new_vllm_config.parallel_config.enable_expert_parallel = True
        _set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank, pgi.local_rank)

    current_platform.seed_everything(7)
    combos = itertools.product(
        PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES, [False, True], [None, [128, 128]]
    )
    exceptions = []
    count = 0
    for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
        m, n, k = mnk

        if dtype == torch.float8_e4m3fn:
            use_fp8_w8a8 = True
            quant_dtype = dtype
        else:
            use_fp8_w8a8 = False
            quant_dtype = None

        test_desc = (
            f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
            f"dtype={dtype}, per_act_token={per_act_token_quant}, "
            f"block_shape={block_shape}, use_internode={use_internode}, "
            f"use_shared_experts={use_shared_experts}]"
        )

        if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
            print(f"{test_desc} - Skip quantization test for non-quantized type.")
            continue

        if per_act_token_quant and block_shape is not None:
            print(f"{test_desc} - Skip illegal quantization combination.")
            continue

        # Count only combinations that are actually executed so the summary
        # below reports real totals.
        count += 1

        a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

        args = {}
        if make_weights:
            (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
                e,
                n,
                k,
                quant_dtype=quant_dtype,
                block_shape=block_shape,
                per_out_ch_quant=per_act_token_quant,
            )
            args["w1"] = w1
            args["w2"] = w2
            args["w1_s"] = w1_s
            args["w2_s"] = w2_s

        if use_shared_experts:
            args["shared_experts"] = make_shared_experts(
                n,
                k,
                in_dtype=a.dtype,
                quant_dtype=quant_dtype,
            )

        try:
            test_fn(
                pgi=pgi,
                dp_size=dp_size,
                a=a,
                score=score,
                topk=topk,
                num_experts=e,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                use_internode=use_internode,
                **args,
            )
            format_result(test_desc)
        except Exception as ex:
            format_result(test_desc, ex)
            exceptions.append(ex)

    if len(exceptions) > 0:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}."
        )
    else:
        print(f"{count} of {count} tests passed in child process, rank={pgi.rank}.")


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_prepare_finalize(
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Run every prepare/finalize combination in a single multi-process launch."""
    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    # world_size is already the total number of ranks (pplx_prepare_finalize
    # derives num_dispatchers as world_size // dp_size), so launch exactly that
    # many processes like the other launchers here.
    parallel_launch(
        world_size,
        _pplx_test_loop,
        dp_size,
        use_internode,
        False,
        False,
        _pplx_prepare_finalize,
    )


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.parametrize("use_shared_experts", [False, True])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe(
    world_dp_size: tuple[int, int],
    use_internode: bool,
    use_shared_experts: bool,
):
    """Run every PPLX MoE combination in a single multi-process launch.

    Weights are generated per combination (``make_weights=True``), and shared
    experts are optionally attached.
    """
    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    parallel_launch(
        world_size,
        _pplx_test_loop,
        dp_size,
        use_internode,
        use_shared_experts,
        True,
        _pplx_moe,
    )