test_deepep_deepgemm_moe.py 15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
bnellnm's avatar
bnellnm committed
4
Test DeepEP + DeepGEMM integration
5
6
DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case.
7
8
9
"""

import dataclasses
10
from contextlib import contextmanager
11
12
13
14
15
16
17

import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
18
from vllm.forward_context import set_forward_context
19
from vllm.model_executor.layers.fused_moe.config import (
20
21
22
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
23
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
24
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
25
26
27
28
29
from vllm.utils.deep_gemm import (
    get_mk_alignment_for_contiguous_layout,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
30
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
31
from vllm.utils.torch_utils import set_random_seed
32
from vllm.v1.worker.workspace import init_workspace_manager
33

34
from ...utils import multi_gpu_test
bnellnm's avatar
bnellnm committed
35
from .parallel_utils import ProcessGroupInfo, parallel_launch
36
from .utils import make_dummy_moe_config, make_test_weights
37

38
if has_deep_ep():
39
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
40
41
        DeepEPHTPrepareAndFinalize,
    )
42
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
43
44
        DeepEPLLPrepareAndFinalize,
    )
45

bnellnm's avatar
bnellnm committed
46
    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
47

48
if has_deep_gemm():
49
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
50
51
52
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
53
54

# Skip marker: applied to tests that cannot run unless has_deep_ep() is true.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(), reason="Requires deep_ep kernels"
)

# Skip marker: applied to tests that cannot run unless
# is_deep_gemm_supported() is true.
requires_deep_gemm = pytest.mark.skipif(
    not is_deep_gemm_supported(), reason="Requires deep_gemm kernels"
)

P = ParamSpec("P")


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@contextmanager
def with_dp_metadata(M: int, world_size: int):
    """Run the enclosed code inside a forward context with DP token counts.

    A default ``VllmConfig`` is created, its ``data_parallel_size`` is set to
    ``world_size`` and expert parallelism is enabled. Every rank is reported
    as holding ``M`` tokens.

    Args:
        M: Number of tokens, passed as ``num_tokens`` and repeated once per
            rank in ``num_tokens_across_dp``.
        world_size: Number of entries in ``num_tokens_across_dp``.
    """
    config = VllmConfig()
    parallel = config.parallel_config
    parallel.data_parallel_size = world_size
    parallel.enable_expert_parallel = True

    # One CPU int32 entry per rank, all equal to M.
    tokens_per_rank = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)

    forward_ctx = set_forward_context(
        None,
        config,
        num_tokens=M,
        num_tokens_across_dp=tokens_per_rank,
    )
    with forward_ctx:
        yield


84
85
def next_power_of_2(x):
    """Return the smallest power of two that is greater than or equal to ``x``.

    The result is computed with integer arithmetic only, so it stays exact
    for arbitrarily large values (a ``log2``-based computation rounds through
    float and can return a value smaller than ``x``).

    Args:
        x: A non-negative integer.

    Returns:
        ``1`` for ``x`` equal to 0 or 1, otherwise the next power of two.

    Raises:
        ValueError: If ``x`` is negative.
    """
    if x < 0:
        # log2 is undefined for negative input; keep signalling that with
        # a ValueError rather than returning a meaningless value.
        raise ValueError("math domain error")
    if x <= 1:
        return 1
    # (x - 1).bit_length() is the exponent of the smallest power of two >= x.
    return 1 << (x - 1).bit_length()
90
91


92
93
94
95
96
97
98
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build block-quantized test weights via make_test_weights.

    make_test_weights is called with ``e``, ``n``, ``k``, the dtypes
    torch.bfloat16 and torch.float8_e4m3fn, and ``block_shape=block_size``.
    It yields one 4-tuple per weight; only the second and third entries of
    each are kept.

    Returns weights w1q, w2q, w1_scale, w2_scale
    """
    w1_pack, w2_pack = make_test_weights(
        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
    )
    _, w1q, w1_scale, _ = w1_pack
    _, w2q, w2_scale, _ = w2_pack
    return w1q, w2q, w1_scale, w2_scale
105
106
107
108
109
110
111
112
113


@dataclasses.dataclass
class TestConfig:
    topk: int
    m: int
    k: int
    n: int
    num_experts: int
bnellnm's avatar
bnellnm committed
114
    per_act_token_quant: bool
115
    block_size: list[int]
116
117
    # configs for testing low-latency kernels
    low_latency: bool
118
    use_fp8_dispatch: bool | None = False
119
120
121
122
123


@dataclasses.dataclass
class TestTensors:
    """Randomly generated per-rank inputs for one MoE test case."""

    # The class name starts with "Test"; tell pytest not to try to collect
    # it as a test class. Unannotated, so it is not a dataclass field.
    __test__ = False

    rank_tokens: torch.Tensor  # all ranks make this many tokens
    # Activation scales; ``make`` always sets this to None.
    rank_token_scales: torch.Tensor | None
    # Expert ids per token, shape (m, topk), int64.
    topk: torch.Tensor
    # Router weights, same shape as ``topk``, float32.
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        """Create random inputs on the current CUDA device.

        Args:
            config: Supplies ``m``, ``k``, ``topk`` and ``num_experts``.
            rank: Unused here; callers seed the RNG per rank beforehand.

        Returns:
            A ``TestTensors`` holding bfloat16 tokens of shape (m, k), int64
            expert ids in ``[0, num_experts)`` and float32 weights, both of
            shape (m, topk). ``rank_token_scales`` is None.
        """
        device = torch.cuda.current_device()
        fp8_info = torch.finfo(torch.float8_e4m3fn)

        # Scale the activations down, then keep them inside the finite
        # float8_e4m3fn range.
        rank_tokens = (
            torch.randn((config.m, config.k), device=device, dtype=torch.bfloat16)
            / 10.0
        )
        rank_tokens = rank_tokens.clamp(min=fp8_info.min, max=fp8_info.max)

        # randint already produces int64 by default; state it explicitly
        # instead of casting afterwards.
        topk_ids = torch.randint(
            low=0,
            high=config.num_experts,
            size=(config.m, config.topk),
            dtype=torch.int64,
            device=device,
        )

        topk_weights = torch.randn(topk_ids.shape, dtype=torch.float32, device=device)

        return TestTensors(
            rank_tokens=rank_tokens,
            rank_token_scales=None,
            topk=topk_ids,
            topk_weights=topk_weights,
            config=config,
        )
161
162


163
def make_ll_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    max_tokens_per_rank: int,
    dp_size: int,
    hidden_size: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Build the low-latency DeepEP + batched DeepGEMM modular kernel.

    The prepare/finalize stage comes from ``make_deepep_a2a`` configured with
    ``DeepEPLLArgs``; the expert stage is ``BatchedDeepGemmExperts`` with
    ``world_size // dp_size`` dispatchers.

    Requires ``test_config.low_latency`` to be true and
    ``test_config.use_fp8_dispatch`` to be set (not None).
    """
    assert test_config.low_latency
    assert test_config.use_fp8_dispatch is not None

    ll_args = DeepEPLLArgs(
        max_tokens_per_rank=max_tokens_per_rank,
        hidden_size=hidden_size,
        num_experts=test_config.num_experts,
        use_fp8_dispatch=test_config.use_fp8_dispatch,
    )
    prepare_finalize: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=None,
        deepep_ll_args=ll_args,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        num_dispatchers=pgi.world_size // dp_size,
        quant_config=quant_config,
        moe_config=make_dummy_moe_config(),
    )

    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=experts,
        inplace=False,
    )
202
203


204
def make_ht_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Build the high-throughput DeepEP + DeepGEMM modular kernel.

    The prepare/finalize stage comes from ``make_deepep_a2a`` configured with
    ``DeepEPHTArgs``; the expert stage is ``DeepGemmExperts``.

    Requires ``test_config.low_latency`` to be false and
    ``test_config.use_fp8_dispatch`` to be None.
    """
    assert not test_config.low_latency
    assert test_config.use_fp8_dispatch is None

    ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
    prepare_finalize: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=ht_args,
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = DeepGemmExperts(
        moe_config=make_dummy_moe_config(),
        quant_config=quant_config,
    )

    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=experts,
        inplace=False,
    )
235
236


237
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    test_tensors: TestTensors,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Pick and build the modular kernel matching the test configuration.

    ``test_tensors.config.low_latency`` selects between the low-latency and
    the high-throughput factory. Both are given float8_e4m3fn as ``q_dtype``.
    """
    q_dtype = torch.float8_e4m3fn
    test_config = test_tensors.config

    if not test_config.low_latency:
        return make_ht_modular_kernel(
            pg,
            pgi,
            dp_size,
            num_local_experts,
            q_dtype,
            test_config,
            quant_config=quant_config,
        )

    tokens = test_tensors.rank_tokens
    # Per-rank token capacity: token count rounded up to a power of two,
    # with a floor of 64.
    ll_token_capacity = max(64, next_power_of_2(tokens.size(0)))
    return make_ll_modular_kernel(
        pg=pg,
        pgi=pgi,
        max_tokens_per_rank=ll_token_capacity,
        dp_size=dp_size,
        hidden_size=tokens.size(-1),
        q_dtype=q_dtype,
        test_config=test_config,
        quant_config=quant_config,
    )


278
279
280
281
282
283
284
def deepep_deepgemm_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
) -> torch.Tensor:
    """Run the MoE layer through the DeepEP + DeepGEMM modular kernel.

    Args:
        pg: Process group handed to the kernel factory.
        pgi: Rank and world-size information for this process.
        dp_size: Data-parallel size handed to the kernel factory.
        test_tensors: Input tokens, top-k ids/weights and the test config.
        w1, w2: This rank's expert weights; ``w1.size(0)`` is taken as the
            number of local experts.
        w1_scale, w2_scale: Weight scales matching ``w1`` and ``w2``.

    Returns:
        The output of ``FusedMoEModularKernel.forward``.
    """
    test_config = test_tensors.config
    num_experts = test_config.num_experts
    num_local_experts = w1.size(0)

    def build_expert_map():
        # Global expert id -> local expert index. Experts owned by other
        # ranks stay at -1; this rank owns one contiguous slice.
        expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
        start = pgi.rank * num_local_experts
        expert_map[start : start + num_local_experts] = torch.arange(
            num_local_experts, dtype=torch.int32
        )
        return expert_map.to(device=torch.cuda.current_device())

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        # Low-Latency kernels can't dispatch scales.
        a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
        block_shape=test_config.block_size,
    )

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        num_local_experts=num_local_experts,
        test_tensors=test_tensors,
        quant_config=quant_config,
    )

    with with_dp_metadata(
        M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
    ):
        out = mk.forward(
            hidden_states=test_tensors.rank_tokens,
            w1=w1,
            w2=w2,
            topk_weights=test_tensors.topk_weights,
            topk_ids=test_tensors.topk,
            activation="silu",
            global_num_experts=num_experts,
            expert_map=build_expert_map(),
            apply_router_weight_on_input=False,
        )
    return out
333
334


335
336
337
338
339
340
341
342
343
344
345
def triton_impl(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None,
    block_shape: list[int],
):
    """Compute the reference MoE output with ``fused_experts``.

    Args:
        a: Input activations.
        topk_ids: Expert ids selected per token.
        topk_weights: Router weights matching ``topk_ids``.
        w1, w2: Expert weights for all experts.
        w1_scale, w2_scale: Weight scales matching ``w1`` and ``w2``.
        a1_scale: Activation scale, or None (the caller in this file
            passes None).
        block_shape: Quantization block shape.

    Returns:
        Whatever ``fused_experts`` returns, called with ``inplace=False``.
    """
    fp8_quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        block_shape=block_shape,
    )

    return fused_experts(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=fp8_quant_config,
    )
362
363


364
def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    """Per-process test body: compare DeepEP + DeepGEMM with the reference.

    The reference output is computed from the full set of expert weights.
    The DeepEP path gets only this rank's contiguous slice of experts. Both
    outputs must agree within ``atol=rtol=6e-2``.
    """
    init_workspace_manager(torch.device(f"cuda:{pgi.local_rank}"))

    # Different seed per rank, so each rank generates different inputs.
    set_random_seed(pgi.rank)

    w1, w2, w1_scale, w2_scale = (
        t.to(device=torch.cuda.current_device()) for t in (w1, w2, w1_scale, w2_scale)
    )

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)
    # Recover the block shape from the weight-to-scale size ratio on
    # dims 1 and 2.
    block_shape = [w1.size(dim) // w1_scale.size(dim) for dim in (1, 2)]

    with set_current_vllm_config(VllmConfig()):
        # Reference
        reference_out = triton_impl(
            a=test_tensors.rank_tokens,
            topk_ids=test_tensors.topk,
            topk_weights=test_tensors.topk_weights,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=test_tensors.rank_token_scales,
            block_shape=block_shape,
        )

        # Slice experts for this rank.
        experts_per_rank = config.num_experts // pgi.world_size
        first_expert = experts_per_rank * pgi.rank
        local = slice(first_expert, first_expert + experts_per_rank)

        deepep_out = deepep_deepgemm_moe_impl(
            pg,
            pgi,
            dp_size,
            test_tensors,
            w1[local],
            w2[local],
            w1_scale[local],
            w2_scale[local],
        )

    torch.testing.assert_close(
        reference_out,
        deepep_out,
        atol=6e-2,
        rtol=6e-2,
    )


# (m, n, k) problem sizes for the high-throughput test, which unpacks each
# entry as ``m, n, k = mnk``.
MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]

441
442
443
# top-k values and total expert counts used to parametrize both tests.
TOPKS = [2, 6]
NUM_EXPERTS = [32]

444
445

@pytest.mark.parametrize("mnk", MNKs)
446
447
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
448
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
449
@multi_gpu_test(num_gpus=2)
450
451
@requires_deep_ep
@requires_deep_gemm
452
453
454
455
456
def test_ht_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
457
    disable_deepgemm_ue8m0,
458
    workspace_init,
459
):
460
461
462
    """
    Tests for High-Throughput DeepEP + DeepGemm integration.
    """
463
464

    m, n, k = mnk
465
    set_random_seed(7)
466
467
468
469

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

470
    block_m = get_mk_alignment_for_contiguous_layout()[0]
471
472
    block_size = [block_m, block_m]

473
    world_size, dp_size = world_dp_size
474
475
476
477
478
479
480
481
482
483
484
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=False,
        use_fp8_dispatch=None,
    )
485
486

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
487
488
        num_experts, n, k, block_size
    )
489

490
491
492
493
494
495
496
497
498
499
    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520


# (m, n, k) problem sizes for the low-latency test; k is 2560 in every case.
# This rebinds MNKs: the high-throughput test above already captured the
# earlier list when its parametrize decorator ran.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# TODO: fix the tests for USE_FP8_DISPATCH=True, then add True to this list.
USE_FP8_DISPATCH = [False]


@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
521
@multi_gpu_test(num_gpus=2)
522
523
@requires_deep_ep
@requires_deep_gemm
bnellnm's avatar
bnellnm committed
524
525
526
527
528
529
530
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    use_fp8_dispatch: bool,
    block_size: list[int],
    world_dp_size: tuple[int, int],
531
    disable_deepgemm_ue8m0,
532
    workspace_init,
bnellnm's avatar
bnellnm committed
533
):
534
535
536
    """
    Tests for Low-Latency DeepEP + DeepGemm integration.
    """
537
    assert not is_deep_gemm_e8m0_used()
538
539

    m, n, k = mnk
540
    set_random_seed(7)
541
542
543
544

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

545
546
547
548
549
550
551
    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
bnellnm's avatar
bnellnm committed
552
        per_act_token_quant=False,
553
        block_size=block_size,
554
555
        low_latency=True,
        use_fp8_dispatch=use_fp8_dispatch,
556
557
558
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
559
560
        num_experts, n, k, block_size
    )
561

562
563
564
565
566
567
568
569
570
571
    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )