# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
bnellnm's avatar
bnellnm committed
4
Test DeepEP + DeepGEMM integration
5
6
DeepGEMM are gemm kernels specialized for the
fp8 block-quantized case.
7
8
9
"""

import dataclasses
10
from contextlib import contextmanager
11
12
13
14
15
16
17

import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
18
from vllm.forward_context import set_forward_context
19
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
20
from vllm.model_executor.layers.fused_moe.config import (
21
22
23
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
24
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
25
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
26
27
28
29
30
from vllm.utils.deep_gemm import (
    get_mk_alignment_for_contiguous_layout,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
31
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
32
from vllm.utils.torch_utils import set_random_seed
33
from vllm.v1.worker.workspace import init_workspace_manager
34

35
from ...utils import multi_gpu_test
bnellnm's avatar
bnellnm committed
36
from .parallel_utils import ProcessGroupInfo, parallel_launch
37
from .utils import make_dummy_moe_config, make_test_weights
38

39
if has_deep_ep():
40
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
41
42
        DeepEPHTPrepareAndFinalize,
    )
43
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
44
45
        DeepEPLLPrepareAndFinalize,
    )
46

bnellnm's avatar
bnellnm committed
47
    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
48

49
if has_deep_gemm():
50
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
51
52
53
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
54
55

# Skip markers for environments without the optional kernels. DeepGEMM is
# gated on `is_deep_gemm_supported()` rather than on `has_deep_gemm()` (which
# only guards the conditional imports).
requires_deep_ep = pytest.mark.skipif(
    condition=not has_deep_ep(),
    reason="Requires deep_ep kernels",
)

requires_deep_gemm = pytest.mark.skipif(
    condition=not is_deep_gemm_supported(),
    reason="Requires deep_gemm kernels",
)

# NOTE(review): `P` is not referenced in the visible code of this file.
P = ParamSpec("P")


68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
@contextmanager
def with_dp_metadata(M: int, world_size: int):
    """Enter a forward context describing ``world_size`` data-parallel ranks.

    Every rank is recorded as holding ``M`` tokens, and expert parallelism is
    switched on in a fresh ``VllmConfig``.
    """
    config = VllmConfig()
    parallel = config.parallel_config
    parallel.data_parallel_size = world_size
    parallel.enable_expert_parallel = True

    # One identical token count per DP rank, kept on the CPU.
    per_rank_tokens = torch.full((world_size,), M, device="cpu", dtype=torch.int)

    with set_forward_context(
        None,
        config,
        num_tokens=M,
        num_tokens_across_dp=per_rank_tokens,
    ):
        yield


85
86
def next_power_of_2(x):
    """Return the smallest power of two that is >= ``x``.

    ``0`` and ``1`` both map to ``1``. Non-integral ``x`` is rounded up first.
    The computation uses exact integer arithmetic (``int.bit_length``) rather
    than ``2 ** ceil(log2(x))``, whose float round-trip gives wrong answers for
    large inputs (e.g. ``2**53 + 1``).

    Raises:
        ValueError: if ``x`` is negative.
    """
    import math

    if x < 0:
        raise ValueError(f"next_power_of_2 expects a non-negative value, got {x}")
    n = math.ceil(x)
    if n <= 1:
        return 1
    # For n >= 2, (n - 1).bit_length() is the exponent of the next power of 2.
    return 1 << (n - 1).bit_length()
91
92


93
94
95
96
97
98
99
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create block-quantized fp8 (e4m3fn) MoE weights for ``e`` experts.

    Delegates to ``make_test_weights`` with a bf16 source dtype and keeps
    only the quantized weights and their block scales.

    Returns:
        ``(w1q, w2q, w1_scale, w2_scale)``.
    """
    w1_parts, w2_parts = make_test_weights(
        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
    )
    _, w1q, w1_scale, _ = w1_parts
    _, w2q, w2_scale, _ = w2_parts
    return w1q, w2q, w1_scale, w2_scale
106
107
108
109
110
111
112
113
114


@dataclasses.dataclass
class TestConfig:
    topk: int
    m: int
    k: int
    n: int
    num_experts: int
bnellnm's avatar
bnellnm committed
115
    per_act_token_quant: bool
116
    block_size: list[int]
117
118
    # configs for testing low-latency kernels
    low_latency: bool
119
    use_fp8_dispatch: bool | None = False
120
121
122
123
124


@dataclasses.dataclass
class TestTensors:
    """Per-rank activations and top-k routing for one test case."""

    # Helper container, not a test class: without this pytest tries to collect
    # it (name starts with "Test") and warns because it defines __init__.
    # Unannotated, so it is a plain class attribute, not a dataclass field.
    __test__ = False

    rank_tokens: torch.Tensor  # (m, k) bf16 activations; every rank makes m tokens
    rank_token_scales: torch.Tensor | None
    topk: torch.Tensor  # (m, topk) int64 expert ids
    topk_weights: torch.Tensor  # (m, topk) float32 router weights
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        """Draw random tokens and routing on the current CUDA device.

        ``rank`` is not used here; per-rank variation comes from the RNG
        state at call time.
        """
        dtype = torch.bfloat16
        topk, m, k = (config.topk, config.m, config.k)

        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        rank_tokens = (
            torch.randn((m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
        )
        # Keep activations within the representable fp8 range.
        rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
        # No precomputed activation scales.
        rank_token_scales = None

        # NOTE(review): randint can pick the same expert more than once for a
        # single token; real routers produce distinct top-k ids.
        topk_ids = torch.randint(
            low=0,
            high=config.num_experts,
            size=(m, topk),
            device=torch.cuda.current_device(),
        ).to(dtype=torch.int64)

        topk_weights = torch.randn(
            topk_ids.shape, dtype=torch.float32, device=torch.cuda.current_device()
        )

        return TestTensors(
            rank_tokens=rank_tokens,
            rank_token_scales=rank_token_scales,
            topk=topk_ids,
            topk_weights=topk_weights,
            config=config,
        )
162
163


164
def make_ll_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    max_tokens_per_rank: int,
    dp_size: int,
    hidden_size: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Assemble a low-latency DeepEP dispatcher with batched DeepGEMM experts.

    Args:
        pg: Process group handed to ``make_deepep_a2a``.
        pgi: This rank's process-group info.
        max_tokens_per_rank: Token capacity given to both the DeepEP LL args
            and ``BatchedDeepGemmExperts``.
        dp_size: Data-parallel size; ``pgi.world_size // dp_size`` is used as
            the number of dispatchers.
        hidden_size: Activation hidden dimension.
        q_dtype: Quantization dtype forwarded to ``make_deepep_a2a``.
        test_config: Must have ``low_latency=True`` and a non-None
            ``use_fp8_dispatch``.
        quant_config: Quantization config for the experts.
    """
    assert test_config.low_latency
    assert test_config.use_fp8_dispatch is not None

    ll_args = DeepEPLLArgs(
        max_tokens_per_rank=max_tokens_per_rank,
        hidden_size=hidden_size,
        num_experts=test_config.num_experts,
        use_fp8_dispatch=test_config.use_fp8_dispatch,
    )
    prepare_finalize: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=None,
        deepep_ll_args=ll_args,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        num_dispatchers=pgi.world_size // dp_size,
        quant_config=quant_config,
        moe_config=make_dummy_moe_config(),
    )

    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=experts,
        inplace=False,
    )
203
204


205
def make_ht_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Assemble a high-throughput DeepEP dispatcher with DeepGEMM experts.

    ``test_config`` must have ``low_latency=False`` and
    ``use_fp8_dispatch=None``. ``num_local_experts`` is forwarded to
    ``DeepEPHTArgs``; ``q_dtype`` and the config's block size go to
    ``make_deepep_a2a``.
    """
    assert not test_config.low_latency
    assert test_config.use_fp8_dispatch is None

    ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
    prepare_finalize: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=ht_args,
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = DeepGemmExperts(
        moe_config=make_dummy_moe_config(),
        quant_config=quant_config,
    )

    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=experts,
        inplace=False,
    )
236
237


238
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    test_tensors: TestTensors,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """Build the LL or HT modular kernel selected by ``test_tensors.config``.

    Both variants are built with ``q_dtype=torch.float8_e4m3fn``. For the
    low-latency variant, the per-rank token capacity is the row count of
    ``test_tensors.rank_tokens`` rounded up to a power of two, with a floor
    of 64. The hidden size is that tensor's last dimension.
    """
    test_config = test_tensors.config
    q_dtype = torch.float8_e4m3fn

    if not test_config.low_latency:
        return make_ht_modular_kernel(
            pg=pg,
            pgi=pgi,
            dp_size=dp_size,
            num_local_experts=num_local_experts,
            q_dtype=q_dtype,
            test_config=test_config,
            quant_config=quant_config,
        )

    tokens = test_tensors.rank_tokens
    return make_ll_modular_kernel(
        pg=pg,
        pgi=pgi,
        max_tokens_per_rank=max(64, next_power_of_2(tokens.size(0))),
        dp_size=dp_size,
        hidden_size=tokens.size(-1),
        q_dtype=q_dtype,
        test_config=test_config,
        quant_config=quant_config,
    )


279
280
281
282
283
284
285
def deepep_deepgemm_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
) -> torch.Tensor:
    """Run the expert-parallel MoE forward through DeepEP + DeepGEMM.

    ``w1``/``w2`` and their scales hold only this rank's experts. The expert
    map built below assumes this rank owns the contiguous global expert range
    ``[rank * E_local, (rank + 1) * E_local)`` with ``E_local = w1.size(0)``.

    Returns:
        The modular kernel's output for ``test_tensors.rank_tokens``.
    """
    test_config = test_tensors.config
    num_experts = test_config.num_experts
    num_local_experts = w1.size(0)

    def build_expert_map() -> torch.Tensor:
        # Global expert id -> local expert id on this rank; -1 marks experts
        # that this rank does not own.
        expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
        start = pgi.rank * num_local_experts
        expert_map[start : start + num_local_experts] = torch.arange(
            num_local_experts, dtype=torch.int32
        )
        return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        # Low-Latency kernels can't dispatch scales.
        a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
        block_shape=test_config.block_size,
    )

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        num_local_experts=num_local_experts,
        test_tensors=test_tensors,
        quant_config=quant_config,
    )

    with with_dp_metadata(
        M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
    ):
        out = mk.forward(
            hidden_states=test_tensors.rank_tokens,
            w1=w1,
            w2=w2,
            topk_weights=test_tensors.topk_weights,
            topk_ids=test_tensors.topk,
            activation=MoEActivation.SILU,
            global_num_experts=num_experts,
            expert_map=build_expert_map(),
            apply_router_weight_on_input=False,
        )
    return out
334
335


336
337
338
339
340
341
342
343
344
345
346
def triton_impl(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor,
    block_shape: list[int],
):
    """Reference MoE forward via ``fused_experts`` with fp8 w8a8 block quant.

    Builds the quantization config from the given scales and ``block_shape``
    and runs out of place (``inplace=False``).
    """
    return fused_experts(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            block_shape=block_shape,
        ),
    )
363
364


365
def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    """Per-rank worker comparing DeepEP + DeepGEMM against a reference.

    The reference (``triton_impl``) uses all experts. The expert-parallel path
    uses this rank's contiguous slice of ``config.num_experts // world_size``
    experts. Outputs must match within ``atol=rtol=6e-2``.
    """
    init_workspace_manager(torch.device(f"cuda:{pgi.local_rank}"))
    # Seed with the rank so each rank draws different tokens and routing.
    set_random_seed(pgi.rank)

    cur_device = torch.cuda.current_device()
    w1, w2, w1_scale, w2_scale = (
        t.to(device=cur_device) for t in (w1, w2, w1_scale, w2_scale)
    )

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)
    # Block shape implied by the weight / scale tensor shapes.
    block_shape = [w1.size(d) // w1_scale.size(d) for d in (1, 2)]

    with set_current_vllm_config(VllmConfig()):
        # Reference: all experts, no dispatch.
        triton_moe = triton_impl(
            a=test_tensors.rank_tokens,
            topk_ids=test_tensors.topk,
            topk_weights=test_tensors.topk_weights,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=test_tensors.rank_token_scales,
            block_shape=block_shape,
        )

        # Expert-parallel path: only this rank's slice of experts.
        per_rank = config.num_experts // pgi.world_size
        mine = slice(per_rank * pgi.rank, per_rank * (pgi.rank + 1))
        deepep_moe = deepep_deepgemm_moe_impl(
            pg,
            pgi,
            dp_size,
            test_tensors,
            w1[mine],
            w2[mine],
            w1_scale[mine],
            w2_scale[mine],
        )

    torch.testing.assert_close(triton_moe, deepep_moe, atol=6e-2, rtol=6e-2)


# (m, n, k) problem sizes for the high-throughput tests.
HT_MNKS = [
    (8, 128, 128),
    (8, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]

TOPKS = [2, 6]
NUM_EXPERTS = [32]


@pytest.mark.parametrize("mnk", HT_MNKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ht_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
    workspace_init,
):
    """
    Tests for High-Throughput DeepEP + DeepGemm integration.

    ``disable_deepgemm_ue8m0`` and ``workspace_init`` are not referenced in
    the body; presumably they are requested for their fixture setup.
    """
    m, n, k = mnk
    set_random_seed(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    # Square quantization blocks sized by the first element of
    # get_mk_alignment_for_contiguous_layout().
    block_m = get_mk_alignment_for_contiguous_layout()[0]
    block_size = [block_m, block_m]

    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=False,
        use_fp8_dispatch=None,
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        num_experts, n, k, block_size
    )

    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521


# (m, n, k) problem sizes for the low-latency tests (k is fixed at 2560).
LL_MNKS = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# TODO: add True once the USE_FP8_DISPATCH=True test cases are fixed.
USE_FP8_DISPATCH = [False]


@pytest.mark.parametrize("mnk", LL_MNKS)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    use_fp8_dispatch: bool,
    block_size: list[int],
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
    workspace_init,
):
    """
    Tests for Low-Latency DeepEP + DeepGemm integration.

    ``disable_deepgemm_ue8m0`` and ``workspace_init`` are not referenced in
    the body apart from the e8m0 assertion below; presumably they are
    requested for their fixture setup.
    """
    assert not is_deep_gemm_e8m0_used()

    m, n, k = mnk
    set_random_seed(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=True,
        use_fp8_dispatch=use_fp8_dispatch,
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        num_experts, n, k, block_size
    )

    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )