# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MOE layers.

Run `pytest tests/kernels/test_pplx_moe.py`.
"""
7
8
9
10
import itertools
import textwrap
import traceback
from typing import Callable, Optional

import pytest
import torch

try:
    from pplx_kernels import AllToAll
    from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                      nvshmem_finalize, nvshmem_get_unique_id,
                                      nvshmem_init)
    has_pplx = True
except ImportError:
    has_pplx = False

from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate)
from vllm.platforms import current_platform
from vllm.utils import round_up

from .parallel_utils import ProcessGroupInfo, parallel_launch
41

42
43
44
45
46
# Skip marker for tests that need the optional `pplx_kernels` package;
# `has_pplx` records whether that import succeeded.
requires_pplx = pytest.mark.skipif(not has_pplx,
                                   reason="Requires PPLX kernels")

47
48
49
# (m, n, k) problem sizes for the PPLX tests. The tests build activations of
# shape (m, k) and pass n and k on to the weight constructors.
PPLX_COMBOS = [
    # TODO: figure out why this fails, seems to be test problem
    #(1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (4, 128, 128),
    (32, 1024, 512),
    (45, 512, 2048),
    (64, 1024, 512),
    (222, 2048, 1024),
    (256, 1408, 2048),
]

NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
62
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

# Module-level vLLM config shared by the tests; it is installed with
# `set_current_vllm_config(vllm_config)` around the kernel calls below.
vllm_config = VllmConfig()
# Scheduler limits are overridden on the default config.
# NOTE(review): why these particular values are needed is not visible here.
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def torch_prepare(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    max_num_tokens: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scatter token rows of `a` into a per-expert batched tensor.

    Args:
        a: 2-D activations, one row per token.
        topk_ids: 2-D tensor of expert ids, one row per token.
        num_experts: Number of experts (size of the leading output dim).
        max_num_tokens: Rows reserved per expert. Defaults to the largest
            number of tokens routed to any single expert.

    Returns:
        `(b_a, tokens_per_expert)` where `b_a` has shape
        `(num_experts, max_num_tokens, hidden_dim)` with unused rows left at
        zero, and `tokens_per_expert` is the bincount of `topk_ids`.
    """
    assert topk_ids.dim() == 2
    assert topk_ids.shape[0] == a.shape[0]

    num_tokens, hidden_dim = a.shape
    topk = topk_ids.shape[1]

    tokens_per_expert = torch.bincount(topk_ids.view(-1),
                                       minlength=num_experts)
    assert tokens_per_expert.numel() == num_experts

    if max_num_tokens is None:
        max_num_tokens = int(tokens_per_expert.max().item())

    batched_shape = (num_experts, max_num_tokens, hidden_dim)
    b_a = torch.zeros(batched_shape, dtype=a.dtype, device=a.device)

    # Next free row for each expert; tokens are placed in token order.
    next_slot = torch.zeros(num_experts, dtype=torch.int, device=a.device)

    for token in range(num_tokens):
        row = a[token, :]
        for j in range(topk):
            expert_id = topk_ids[token, j]
            slot = next_slot[expert_id]
            b_a[expert_id, slot:slot + 1, :] = row
            next_slot[expert_id] = next_slot[expert_id] + 1

    return b_a, tokens_per_expert


def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
                   topk_ids: torch.Tensor) -> torch.Tensor:
    """Gather per-expert batched outputs back into per-token rows.

    Walks tokens in the same order as `torch_prepare`, so the i-th use of an
    expert reads row i of that expert's slice of `b_out`. Each gathered row is
    scaled by the matching entry of `topk_weight` and summed into the token's
    output row.

    Returns:
        Tensor of shape `(num_tokens, K)` with `b_out`'s dtype and device,
        where `K` is the last dimension of `b_out`.
    """
    num_tokens = topk_ids.shape[0]
    num_experts = b_out.shape[0]
    K = b_out.shape[-1]

    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
    # Next unread row for each expert.
    next_slot = torch.zeros(num_experts,
                            dtype=torch.int,
                            device=b_out.device)

    for token in range(num_tokens):
        token_experts = topk_ids[token]
        for i in range(token_experts.numel()):
            expert_id = token_experts[i]
            slot = next_slot[expert_id]
            weighted = b_out[expert_id, slot:slot + 1, :] * topk_weight[token,
                                                                        i]
            out[token, :] = out[token, :] + weighted
            next_slot[expert_id] = next_slot[expert_id] + 1

    return out


def torch_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
) -> torch.Tensor:
    """Reference MoE that works on the batched (per-expert) layout.

    Tokens are scattered per expert with `torch_prepare`, each expert applies
    `w1`, the `silu_and_mul` custom op, then `w2` to its own rows, and the
    results are combined per token with `torch_finalize`.
    """
    num_experts = w1.shape[0]
    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
    assert b_a.dim() == 3

    _, max_num_tokens, K = b_a.shape
    assert num_experts == b_a.shape[0] and w2.shape[1] == K

    out = torch.zeros((num_experts, max_num_tokens, K),
                      dtype=b_a.dtype,
                      device=b_a.device)
    # Scratch buffer for the activation output; silu_and_mul writes into it.
    tmp = torch.empty((max_num_tokens, w1.shape[1] // 2),
                      dtype=b_a.dtype,
                      device=b_a.device)

    for expert in range(num_experts):
        num = tokens_per_expert[expert]
        if not num > 0:
            # No tokens were routed to this expert.
            continue
        gate_up = b_a[expert, :num, :] @ w1[expert].transpose(0, 1)
        torch.ops._C.silu_and_mul(tmp[:num], gate_up)
        out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

    return torch_finalize(out, topk_weight, topk_ids)


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    """Check both batched reference MoE implementations against torch_experts.

    `torch_batched_moe` (defined in this file) and `naive_batched_moe` must
    each match the `torch_experts` baseline within an absolute tolerance of
    2e-2 on random inputs.
    """
    current_platform.seed_everything(7)

    # Inputs are scaled down by 10 to keep magnitudes small.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
    score = torch.randn((m, e), device="cuda", dtype=dtype)

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        baseline_output = torch_experts(a, w1, w2, topk_weight,
                                        topk_ids)  # only for baseline
        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
        batched_output = naive_batched_moe(
            a, w1, w2, topk_weight, topk_ids)  # pick torch_experts or this

    torch.testing.assert_close(baseline_output,
                               torch_output,
                               atol=2e-2,
                               rtol=0)
    torch.testing.assert_close(baseline_output,
                               batched_output,
                               atol=2e-2,
                               rtol=0)


194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def create_pplx_prepare_finalize(
    num_tokens: int,
    hidden_dim: int,
    topk: int,
    num_experts: int,
    rank: int,
    dp_size: int,
    world_size: int,
    in_dtype: torch.dtype,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    group_name: Optional[str],
):
    """Build a `PplxPrepareAndFinalize` together with its `AllToAll` object.

    An intranode `AllToAll` is created when `group_name` is given, otherwise
    an internode one. Buffer sizes are derived from rank 0's share of the
    tokens (at least one token).

    Returns:
        `(prepare_finalize, ata)`; the caller is responsible for calling
        `ata.destroy()`.
    """
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)

    # Rank 0 always holds the largest chunk, so size everything from it.
    max_num_tokens = max(rank_chunk(num_tokens, 0, world_size), 1)
    num_local_experts = rank_chunk(num_experts, 0, world_size)

    hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
        max_num_tokens,
        hidden_dim,
        in_dtype,
        quant_dtype,
        per_act_token_quant=per_act_token_quant,
        block_shape=block_shape,
    )

    ata_kwargs = {
        "max_num_tokens": max_num_tokens,
        "num_experts": num_experts,
        "experts_per_token": topk,
        "rank": rank,
        "world_size": world_size,
        "dp_size": dp_size,
        "hidden_dim": hidden_dim,
        "hidden_dim_bytes": hidden_dim_bytes,
        "hidden_dim_scale_bytes": scale_bytes,
    }

    if group_name is not None:
        ata = AllToAll.intranode(**ata_kwargs, group_name=group_name)
    else:
        ata = AllToAll.internode(**ata_kwargs)

    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens=max_num_tokens,
        num_local_experts=num_local_experts,
        num_dispatchers=world_size // dp_size,
    )

    return prepare_finalize, ata


251
252
253
254
255
256
257
258
259
260
def rank_chunk(num: int, r: int, w: int) -> int:
    """Return how many of `num` items rank `r` gets when split over `w` ranks.

    Items are divided as evenly as possible; the first `num % w` ranks each
    receive one extra item.
    """
    base, extra = divmod(num, w)
    if r < extra:
        return base + 1
    return base


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    """Return the rows of `t` (along dim 0) that belong to rank `r` of `w`.

    Chunk sizes match `rank_chunk`: the first `t.shape[0] % w` ranks get one
    extra row. The start offset is the sum of the sizes of all preceding
    chunks, so the chunks of ranks `0..w-1` are disjoint and together cover
    every row, even when the row count is not divisible by `w`.
    """
    base, rem = divmod(t.shape[0], w)
    # Each of the first `rem` ranks holds one extra row; account for the ones
    # that precede rank `r`.
    start = r * base + min(r, rem)
    size = base + (1 if r < rem else 0)
    return t[start:start + size]


261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
def maybe_chunk_by_rank(t: Optional[torch.Tensor], r: int,
                        w: int) -> Optional[torch.Tensor]:
    """`chunk_by_rank` that passes `None` through unchanged."""
    if t is None:
        return t
    return chunk_by_rank(t, r, w)


def chunk_scales_by_rank(t: Optional[torch.Tensor], r: int,
                         w: int) -> Optional[torch.Tensor]:
    """Return rank `r`'s share of a scale tensor, split over `w` ranks.

    `None` and single-element tensors are returned unchanged. Otherwise the
    tensor is split along dim 0 with the same sizes as `rank_chunk`, and the
    start offset is the sum of the preceding chunk sizes so that chunks never
    overlap when the row count is not divisible by `w`.
    """
    if t is None or t.numel() <= 1:
        return t

    base, rem = divmod(t.shape[0], w)
    # The first `rem` ranks each hold one extra row.
    start = r * base + min(r, rem)
    size = base + (1 if r < rem else 0)
    return t[start:start + size]


def chunk_scales(t: Optional[torch.Tensor], start: int,
                 end: int) -> Optional[torch.Tensor]:
    """Slice a scale tensor to `[start:end]` along dim 0.

    `None` and single-element tensors are returned unchanged.
    """
    if t is None or t.numel() <= 1:
        return t
    return t[start:end]


def dummy_work(a: torch.Tensor) -> torch.Tensor:
    """Stand-in for expert computation: scale `a` by 1.1."""
    scaled = a * 1.1
    return scaled


290
291
292
293
294
295
296
def pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    group_name: Optional[str],
) -> torch.Tensor:
    """Run this rank's tokens through PPLX prepare and finalize.

    The rank's chunk of `a` is passed to `prepare`, the batched result is
    dequantized and scaled by `dummy_work` (standing in for the experts), and
    `finalize` writes the combined result into a NaN-initialised output.

    Args:
        pgi: Process-group info for the current worker.
        dp_size: Data-parallel size forwarded to the all-to-all setup.
        a: Full activations, one row per token; chunked by rank here.
        topk_weight: Per-token top-k weights; chunked by rank here.
        topk_ids: Per-token top-k expert ids; chunked by rank here.
        num_experts: Global number of experts.
        quant_dtype: Quantized activation dtype, or None for no quantization.
        block_shape: Block quantization shape, or None.
        per_act_token_quant: Whether per-token activation scales are used.
        group_name: Process group name for intranode all-to-all; None selects
            the internode variant.

    Returns:
        The finalized output for this rank's chunk of tokens.
    """
    assert torch.cuda.current_device() == pgi.local_rank

    topk = topk_ids.shape[1]
    num_tokens, hidden_dim = a.shape
    device = pgi.device
    rank = pgi.rank
    world_size = pgi.world_size

    topk_ids = topk_ids.to(dtype=torch.uint32)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    assert a.shape[0] == topk_ids.shape[0]

    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

    assert a_chunk.shape[0] == chunk_topk_ids.shape[0]

    # NaN-filled so that any row finalize fails to write shows up in the
    # comparison done by the caller.
    out = torch.full(
        a_chunk.shape,
        torch.nan,
        dtype=a.dtype,
        device=device,
    )

    # Only the per-tensor quantized case gets explicit (unit) activation
    # scales; every other configuration passes None.
    if (quant_dtype is not None and not per_act_token_quant
            and block_shape is None):
        a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    else:
        a1_scale = None
        a2_scale = None

    # NOTE(review): the positional `None, False` arguments follow the
    # signature of PplxPrepareAndFinalize.prepare, which is not visible here.
    b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
        a_chunk,
        a1_scale,
        a2_scale,
        chunk_topk_weight,
        chunk_topk_ids,
        num_experts,
        None,
        False,
        FusedMoEQuantConfig(
            quant_dtype,
            per_act_token_quant,
            False,
            block_shape,
        ),
    )

    b_a = dummy_work(
        dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))

    prepare_finalize.finalize(
        out,
        b_a,
        chunk_topk_weight,
        chunk_topk_ids,
        False,
        weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
    )

    torch.cuda.synchronize()

    ata.destroy()

    num_tokens = a_chunk.shape[0]

    return out[:num_tokens]


def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    quant_dtype: Optional[torch.dtype],
    block_shape: Optional[list[int]],
    per_act_token_quant: bool,
    use_internode: bool,
):
    """Per-worker body of the PPLX prepare/finalize test.

    Sets up NVSHMEM (internode) or a gloo group (intranode), computes a torch
    reference in which every token is scaled by `dummy_work` and combined with
    its top-k weights, and compares this rank's chunk of that reference with
    the output of `pplx_prepare_finalize`.

    Raises:
        AssertionError: If the PPLX output differs from the reference by more
            than atol=3e-2 / rtol=3e-2.
    """
    try:
        if use_internode:
            # Rank 0 creates the NVSHMEM unique id; the others receive it via
            # broadcast.
            uid = nvshmem_get_unique_id(
            ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks,
                                                    backend="gloo")
            group_name = cpu_group.group_name

        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        m, k = a.shape

        # Reference: each token contributes dummy_work(a) once per selected
        # expert, weighted by the corresponding top-k weight.
        a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)

        torch_output = (a_rep.view(m, topk, k) *
                        topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(
                            dim=1)

        pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight,
                                            topk_ids, num_experts, quant_dtype,
                                            block_shape, per_act_token_quant,
                                            group_name)

        torch_output = chunk_by_rank(torch_output, pgi.rank,
                                     pgi.world_size).to(pgi.device)

        torch.testing.assert_close(pplx_output,
                                   torch_output,
                                   atol=3e-2,
                                   rtol=3e-2)
    finally:
        if use_internode:
            nvshmem_finalize()
437
438


439
@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
def test_pplx_prepare_finalize_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    use_internode: bool,
):
    """Run the PPLX prepare/finalize check once per parameter combination.

    Each combination launches its own set of worker processes. Quantization
    options are skipped for the non-quantized dtype, as is the combination of
    per-token and block quantization.
    """
    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        # Activations are generated in bfloat16 and quantized to fp8.
        act_dtype = torch.bfloat16
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        act_dtype = dtype
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"

    a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
    score = torch.randn((m, e), device=device, dtype=act_dtype)

    parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
                    topk, e, quant_dtype, block_shape, per_act_token_quant,
                    use_internode)
485
486
487


def pplx_moe(
    group_name: Optional[str],
    rank: int,
    world_size: int,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    use_compile: bool = False,
    use_cudagraphs: bool = True,
) -> torch.Tensor:
    """Run this rank's share of a MoE layer through the PPLX modular kernel.

    Builds a `FusedMoEModularKernel` from a PPLX prepare/finalize object and
    `BatchedTritonExperts`, chunks activations, routing data, weights and
    scales by rank, and runs the kernel eagerly. When `use_cudagraphs` is set
    the same call is also captured in a CUDA graph and replayed, and the
    tensor produced under capture is returned.

    Args:
        group_name: Process group name for intranode all-to-all; None selects
            the internode variant.
        rank: This worker's rank.
        world_size: Number of ranks the inputs are split across.
        dp_size: Data-parallel size forwarded to the all-to-all setup.
        a: Full activations, one row per token.
        w1, w2: Expert weights, indexed by expert along dim 0.
        topk_weight, topk_ids: Routing weights and expert ids per token.
        w1_scale, w2_scale: Optional weight scales, chunked like the weights.
        a1_scale, a2_scale: Optional activation scales.
        quant_dtype: Quantized dtype; fp8 (e4m3fn) enables `use_fp8_w8a8`.
        per_act_token_quant: Whether per-token activation scales are used.
        block_shape: Block quantization shape, or None.
        use_compile: Wrap the kernel in `torch.compile` (see note below).
        use_cudagraphs: Additionally capture and replay a CUDA graph.

    Returns:
        The kernel output for this rank's chunk of tokens.
    """

    num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
    topk = topk_ids.shape[1]
    max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 16)

    prepare_finalize, ata = create_pplx_prepare_finalize(
        num_tokens,
        hidden_dim,
        topk,
        num_experts,
        rank,
        dp_size,
        world_size,
        a.dtype,
        quant_dtype,
        block_shape,
        per_act_token_quant,
        group_name,
    )

    topk_ids = topk_ids.to(dtype=torch.uint32)

    experts = BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
    )

    fused_experts = FusedMoEModularKernel(
        prepare_finalize,
        experts,
    )

    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size)

    # Chunking weights like this only works for batched format
    w1_chunk = chunk_by_rank(w1, rank, world_size)
    w2_chunk = chunk_by_rank(w2, rank, world_size)
    w1_scale_chunk = maybe_chunk_by_rank(w1_scale, rank, world_size)
    w2_scale_chunk = maybe_chunk_by_rank(w2_scale, rank, world_size)
    a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
    a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)

    # Note: for now use_compile will error out if the problem size is
    # large enough to trigger chunking. I'm leaving the flag and
    # setup code in case we are able to revisit this later.
    if use_compile:
        _fused_experts = torch.compile(fused_experts,
                                       backend='inductor',
                                       fullgraph=True)
        torch._dynamo.mark_dynamic(a_chunk, 0)
        torch._dynamo.mark_dynamic(chunk_topk_weight, 0)
        torch._dynamo.mark_dynamic(chunk_topk_ids, 0)
    else:
        _fused_experts = fused_experts

    # Eager run; it also precedes graph capture below.
    out = _fused_experts(a_chunk,
                         w1_chunk,
                         w2_chunk,
                         chunk_topk_weight,
                         chunk_topk_ids,
                         w1_scale=w1_scale_chunk,
                         w2_scale=w2_scale_chunk,
                         a1_scale=a1_scale_chunk,
                         a2_scale=a2_scale_chunk,
                         global_num_experts=num_experts)

    if use_cudagraphs:
        # NOTE(review): presumably zeroed so stale eager results cannot mask a
        # failed replay - TODO confirm.
        out.fill_(0)
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            out = _fused_experts(a_chunk,
                                 w1_chunk,
                                 w2_chunk,
                                 chunk_topk_weight,
                                 chunk_topk_ids,
                                 w1_scale=w1_scale_chunk,
                                 w2_scale=w2_scale_chunk,
                                 a1_scale=a1_scale_chunk,
                                 a2_scale=a2_scale_chunk,
                                 global_num_experts=num_experts)

        torch.cuda.synchronize()
        graph.replay()

    torch.cuda.synchronize()

    ata.destroy()

    return out


def _pplx_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    num_experts: int,
    w1_s: Optional[torch.Tensor] = None,
    w2_s: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
    use_internode: bool = False,
):
    """Per-worker body of the PPLX MoE test.

    Sets up NVSHMEM (internode) or a gloo group (intranode), computes the
    `torch_experts` and `naive_batched_moe` references on the full batch, runs
    `pplx_moe` for this rank, and checks that (a) the two references agree and
    (b) the PPLX output matches this rank's chunk of the batched reference.

    Raises:
        AssertionError: If either comparison exceeds atol=3e-2 / rtol=3e-2.
    """
    try:
        if use_internode:
            # Rank 0 creates the NVSHMEM unique id; the others receive it via
            # broadcast.
            uid = nvshmem_get_unique_id(
            ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
            torch.distributed.broadcast(uid, src=0)
            nvshmem_init(uid, pgi.rank, pgi.world_size)
            group_name = None
        else:
            group_ranks = list(range(pgi.world_size))
            cpu_group = torch.distributed.new_group(group_ranks,
                                                    backend="gloo")
            group_name = cpu_group.group_name

        m, k = a.shape
        e, _, n = w2.shape

        moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

        device = torch.device("cuda", pgi.rank)
        rank = pgi.rank
        world_size = pgi.world_size

        a = a.to(device)
        w1 = w1.to(device)
        w2 = w2.to(device)
        w1_s = w1_s.to(device) if w1_s is not None else None
        w2_s = w2_s.to(device) if w2_s is not None else None

        # Only the per-tensor quantized case gets explicit (unit) activation
        # scales; every other configuration passes None.
        if (quant_dtype is not None and not per_act_token_quant
                and block_shape is None):
            a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
            a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
        else:
            a1_scale = None
            a2_scale = None

        with set_current_vllm_config(vllm_config), override_config(moe_config):
            topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

            torch_output = torch_experts(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            batched_output = naive_batched_moe(
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

            pplx_output = pplx_moe(
                group_name,
                rank,
                world_size,
                dp_size,
                a,
                w1,
                w2,
                topk_weight,
                topk_ids,
                w1_scale=w1_s,
                w2_scale=w2_s,
                a1_scale=a1_scale,
                a2_scale=a2_scale,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
            )

        # pplx_moe only returns this rank's tokens, so compare against the
        # matching chunk of the full-batch reference.
        chunked_batch_output = chunk_by_rank(
            batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)

        torch.testing.assert_close(batched_output,
                                   torch_output,
                                   atol=3e-2,
                                   rtol=3e-2)

        torch.testing.assert_close(pplx_output,
                                   chunked_batch_output,
                                   atol=3e-2,
                                   rtol=3e-2)
    finally:
        if use_internode:
            nvshmem_finalize()


@pytest.mark.parametrize("mnk", PPLX_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.optional
@requires_pplx
def test_pplx_moe_slow(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
    block_shape: Optional[list[int]],
    use_internode: bool,
):
    """Run the PPLX MoE check once per parameter combination.

    Each combination launches its own set of worker processes. Quantization
    options are skipped for the non-quantized dtype, as is the combination of
    per-token and block quantization.
    """
    current_platform.seed_everything(7)
    m, n, k = mnk
    world_size, dp_size = world_dp_size

    if dtype == torch.float8_e4m3fn:
        use_fp8_w8a8 = True
        quant_dtype = dtype
    else:
        use_fp8_w8a8 = False
        quant_dtype = None

    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
        pytest.skip("Skip quantization test for non-quantized type")

    if per_act_token_quant and block_shape is not None:
        pytest.skip("Skip illegal quantization combination")

    # Activations are always generated in bfloat16; quantization (if any) is
    # selected through quant_dtype.
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    _, w1, w1_s, _, w2, w2_s = make_test_weights(
        e,
        n,
        k,
        quant_dtype=quant_dtype,
        block_shape=block_shape,
        per_act_token_quant=per_act_token_quant,
    )

    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
                    w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
                    use_internode)
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897


def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
                    make_weights: bool, test_fn: Callable):
    """Run `test_fn` over every parameter combination in one worker process.

    Iterates the product of PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES and the
    quantization options, skipping invalid combinations. A failure in one
    combination is printed and recorded and the loop continues; a single
    error is raised at the end if anything failed.

    Args:
        pgi: Process-group info for the current worker.
        dp_size: Data-parallel size forwarded to `test_fn`.
        use_internode: Forwarded to `test_fn`.
        make_weights: When true, expert weights and scales are created with
            `make_test_weights` and passed to `test_fn` as w1/w2/w1_s/w2_s.
        test_fn: Per-combination test body, called with keyword arguments.

    Raises:
        RuntimeError: If at least one combination failed.
    """

    def format_result(msg, ex=None):
        # Print a pytest-like PASSED/FAILED line; for failures also print the
        # active traceback and a shortened form of the exception message.
        if ex is not None:
            x = str(ex)
            newx = x.strip(" \n\t")[:16]
            if len(newx) < len(x):
                newx = newx + " ..."

            prefix = "E\t"
            print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
            print(f"FAILED {msg} - {newx}\n")
        else:
            print(f"PASSED {msg}")

    current_platform.seed_everything(7)
    combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
                               [False, True], [None, [128, 128]])
    exceptions = []
    # Number of combinations actually executed (skipped ones are excluded so
    # the summary below is accurate).
    count = 0
    for mnk, e, topk, dtype, per_act_token_quant, block_shape in combos:
        m, n, k = mnk

        if dtype == torch.float8_e4m3fn:
            use_fp8_w8a8 = True
            quant_dtype = dtype
        else:
            use_fp8_w8a8 = False
            quant_dtype = None

        test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
                     f"dtype={dtype}, per_act_token={per_act_token_quant}, "
                     f"block_shape={block_shape}]")

        if not use_fp8_w8a8 and (per_act_token_quant
                                 or block_shape is not None):
            print(
                f"{test_desc} - Skip quantization test for non-quantized type."
            )
            continue

        if per_act_token_quant and block_shape is not None:
            print(f"{test_desc} - Skip illegal quantization combination.")
            continue

        count += 1

        a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

        args = {}
        if make_weights:
            _, w1, w1_s, _, w2, w2_s = make_test_weights(
                e,
                n,
                k,
                quant_dtype=quant_dtype,
                block_shape=block_shape,
                per_act_token_quant=per_act_token_quant,
            )
            args["w1"] = w1
            args["w2"] = w2
            args["w1_s"] = w1_s
            args["w2_s"] = w2_s

        # Deliberately broad: one failing combination must not stop the
        # remaining ones; failures are reported together after the loop.
        try:
            test_fn(
                pgi=pgi,
                dp_size=dp_size,
                a=a,
                score=score,
                topk=topk,
                num_experts=e,
                quant_dtype=quant_dtype,
                per_act_token_quant=per_act_token_quant,
                block_shape=block_shape,
                use_internode=use_internode,
                **args,
            )
            format_result(test_desc)
        except Exception as ex:
            format_result(test_desc, ex)
            exceptions.append(ex)

    if exceptions:
        raise RuntimeError(
            f"{len(exceptions)} of {count} tests failed in child process, "
            f"rank={pgi.rank}.")
    else:
        print(f"{count} of {count} tests passed in child process, "
              f"rank={pgi.rank}.")


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_prepare_finalize(
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Run every prepare/finalize combination inside one set of workers.

    Unlike the `_slow` variant, worker processes are launched once and
    `_pplx_test_loop` iterates over all parameter combinations in-process.
    """
    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    # Launch `world_size` workers, matching the other launch sites in this
    # file (identical to world_size * dp_size for the parametrized dp_size=1).
    parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
                    False, _pplx_prepare_finalize)


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_moe(
    world_dp_size: tuple[int, int],
    use_internode: bool,
):
    """Run every PPLX MoE combination inside one set of worker processes.

    Workers are launched once and `_pplx_test_loop` iterates over all
    parameter combinations in-process, creating weights for each.
    """
    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size

    # Positional arguments forwarded to `_pplx_test_loop` after the
    # process-group info: dp_size, use_internode, make_weights, test_fn.
    loop_args = (dp_size, use_internode, True, _pplx_moe)
    parallel_launch(world_size, _pplx_test_loop, *loop_args)