test_deepep_deepgemm_moe.py 14.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
Test DeepEP + DeepGEMM integration.

DeepGEMM provides GEMM kernels specialized for the
fp8 block-quantized case.
"""

import dataclasses
10
from contextlib import contextmanager
11
12
13
14
15
16
17

import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
18
from vllm.forward_context import set_forward_context
19
from vllm.model_executor.layers.fused_moe.config import (
20
21
22
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
23
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
24
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
25
from vllm.platforms import current_platform
26
27
28
29
30
from vllm.utils.deep_gemm import (
    get_mk_alignment_for_contiguous_layout,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
31
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
32

33
from ...utils import multi_gpu_test
bnellnm's avatar
bnellnm committed
34
35
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
36

37
if has_deep_ep():
38
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
39
40
        DeepEPHTPrepareAndFinalize,
    )
41
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
42
43
        DeepEPLLPrepareAndFinalize,
    )
44

bnellnm's avatar
bnellnm committed
45
    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
46

47
if has_deep_gemm():
48
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
49
50
51
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
52
53

# Skip marker: applied to tests that cannot run unless `has_deep_ep()` is true.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(), reason="Requires deep_ep kernels"
)

# Skip marker: applied to tests that cannot run unless
# `is_deep_gemm_supported()` is true.
requires_deep_gemm = pytest.mark.skipif(
    not is_deep_gemm_supported(), reason="Requires deep_gemm kernels"
)

# ParamSpec handle. NOTE(review): nothing in the visible part of this file
# refers to it -- confirm whether it is still needed.
P = ParamSpec("P")


66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@contextmanager
def with_dp_metadata(M: int, world_size: int):
    """
    Run the enclosed code under a forward context that describes a
    data-parallel setup in which each of `world_size` ranks has `M` tokens.

    A default `VllmConfig` is created, its data-parallel size is set to
    `world_size` and expert parallelism is enabled; that config is then
    handed to `set_forward_context` together with the token counts.
    """
    # One entry per rank, every rank reporting the same token count.
    tokens_per_rank = torch.full(
        (world_size,), M, dtype=torch.int, device="cpu"
    )

    config = VllmConfig()
    parallel = config.parallel_config
    parallel.data_parallel_size = world_size
    parallel.enable_expert_parallel = True

    forward_ctx = set_forward_context(
        None,
        config,
        num_tokens=M,
        num_tokens_across_dp=tokens_per_rank,
    )
    with forward_ctx:
        yield


83
84
def next_power_of_2(x):
    """
    Return the smallest power of two that is greater than or equal to `x`.

    `x == 0` maps to 1. Integer inputs are handled with exact integer
    arithmetic, so the result is correct even for values too large to be
    represented exactly as a float (the float-based `log2` route rounds
    e.g. `2**53 + 1` down and would return a power of two below `x`).
    Non-integer inputs keep the original `log2`-based computation.

    Raises:
        ValueError: if `x` is negative.
    """
    if x == 0:
        return 1

    if not isinstance(x, int):
        # Preserve the historical behaviour for float-like inputs.
        import math

        return 2 ** math.ceil(math.log2(x))

    if x < 0:
        # math.log2 raised ValueError for negative values; keep that contract.
        raise ValueError(f"next_power_of_2 expects a non-negative value, got {x}")

    # (x - 1).bit_length() is the exponent of the first power of two >= x.
    return 1 << (x - 1).bit_length()
89
90


91
92
93
94
95
96
97
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build fp8 block-quantized test weights through `make_test_weights`.

    `e`, `n`, `k` and `block_size` are forwarded unchanged (`block_size` as
    the `block_shape` keyword), with bfloat16 and float8_e4m3fn as the two
    dtype arguments.

    Returns the tuple (w1q, w2q, w1_scale, w2_scale).
    """
    w1_bundle, w2_bundle = make_test_weights(
        e,
        n,
        k,
        torch.bfloat16,
        torch.float8_e4m3fn,
        block_shape=block_size,
    )
    # Each bundle has four entries; only the second and third are used here.
    _, w1q, w1_scale, _ = w1_bundle
    _, w2q, w2_scale, _ = w2_bundle
    return w1q, w2q, w1_scale, w2_scale
104
105
106
107
108
109
110
111
112


@dataclasses.dataclass
class TestConfig:
    topk: int
    m: int
    k: int
    n: int
    num_experts: int
bnellnm's avatar
bnellnm committed
113
    per_act_token_quant: bool
114
    block_size: list[int]
115
116
    # configs for testing low-latency kernels
    low_latency: bool
117
    use_fp8_dispatch: bool | None = False
118
119
120
121
122


@dataclasses.dataclass
class TestTensors:
    """
    Randomly generated per-rank inputs for one MoE test case.
    """

    # Not annotated, so this is a plain class attribute and NOT a dataclass
    # field. It keeps pytest from trying to collect this `Test*`-named helper.
    __test__ = False

    rank_tokens: torch.Tensor  # all ranks make this many tokens
    # Activation scales; `make` always leaves this as None.
    rank_token_scales: torch.Tensor | None
    # Selected expert ids, shape (m, topk), int64.
    topk: torch.Tensor
    # Router weights, same shape as `topk`, float32.
    topk_weights: torch.Tensor
    # The configuration these tensors were generated from.
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        """
        Generate random tokens, expert ids and router weights on the current
        CUDA device according to `config`.

        `rank` is accepted for interface compatibility but is not used.
        """
        device = torch.cuda.current_device()
        fp8_info = torch.finfo(torch.float8_e4m3fn)

        # bfloat16 activations of shape (m, k), scaled down and clamped to
        # the range representable in float8_e4m3fn.
        rank_tokens = (
            torch.randn(
                (config.m, config.k), device=device, dtype=torch.bfloat16
            )
            / 10.0
        )
        rank_tokens = rank_tokens.clamp(min=fp8_info.min, max=fp8_info.max)

        # Expert ids drawn uniformly from [0, num_experts).
        topk_ids = torch.randint(
            low=0,
            high=config.num_experts,
            size=(config.m, config.topk),
            dtype=torch.int64,
            device=device,
        )

        topk_weights = torch.randn(
            topk_ids.shape, dtype=torch.float32, device=device
        )

        return TestTensors(
            rank_tokens=rank_tokens,
            rank_token_scales=None,
            topk=topk_ids,
            topk_weights=topk_weights,
            config=config,
        )
160
161


162
def make_ll_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    max_tokens_per_rank: int,
    dp_size: int,
    hidden_size: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Assemble the low-latency modular kernel: a DeepEP low-latency
    prepare/finalize object combined with `BatchedDeepGemmExperts`.

    Requires `test_config.low_latency` to be set and
    `test_config.use_fp8_dispatch` to be non-None (both are asserted).
    """
    assert test_config.low_latency
    assert test_config.use_fp8_dispatch is not None

    ll_args = DeepEPLLArgs(
        max_tokens_per_rank=max_tokens_per_rank,
        hidden_size=hidden_size,
        num_experts=test_config.num_experts,
        use_fp8_dispatch=test_config.use_fp8_dispatch,
    )
    prepare_finalize: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=None,
        deepep_ll_args=ll_args,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        num_dispatchers=pgi.world_size // dp_size,
        quant_config=quant_config,
    )

    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=experts
    )


199
def make_ht_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Assemble the high-throughput modular kernel: a DeepEP high-throughput
    prepare/finalize object combined with `DeepGemmExperts`.

    Requires `test_config.low_latency` to be unset and
    `test_config.use_fp8_dispatch` to be None (both are asserted).
    """
    assert not test_config.low_latency
    assert test_config.use_fp8_dispatch is None

    ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
    prepare_finalize: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=ht_args,
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = DeepGemmExperts(quant_config)

    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=experts
    )


226
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    test_tensors: TestTensors,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Build the modular kernel matching `test_tensors.config`: the low-latency
    variant when `low_latency` is set, the high-throughput one otherwise.
    Both variants are built with float8_e4m3fn as the quantization dtype.
    """
    test_config = test_tensors.config
    q_dtype = torch.float8_e4m3fn

    if not test_config.low_latency:
        return make_ht_modular_kernel(
            pg,
            pgi,
            dp_size,
            num_local_experts,
            q_dtype,
            test_config,
            quant_config=quant_config,
        )

    tokens = test_tensors.rank_tokens
    # Token capacity per rank: the token count rounded up to a power of two,
    # with a floor of 64.
    capacity = max(64, next_power_of_2(tokens.size(0)))
    return make_ll_modular_kernel(
        pg=pg,
        pgi=pgi,
        max_tokens_per_rank=capacity,
        dp_size=dp_size,
        hidden_size=tokens.size(-1),
        q_dtype=q_dtype,
        test_config=test_config,
        quant_config=quant_config,
    )


267
268
269
270
271
272
273
def deepep_deepgemm_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
) -> torch.Tensor:
    """
    Run the DeepEP + DeepGEMM modular kernel on this rank's tokens and
    return its output.

    `w1` / `w2` (and their scales) are this rank's expert slice; the number
    of local experts is taken from `w1.size(0)`.
    """
    cfg = test_tensors.config
    total_experts = cfg.num_experts
    local_experts = w1.size(0)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        # Low-Latency kernels can't dispatch scales.
        a1_scale=None if cfg.low_latency else test_tensors.rank_token_scales,
        block_shape=cfg.block_size,
    )

    # Make modular kernel
    kernel: FusedMoEModularKernel = make_modular_kernel(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        num_local_experts=local_experts,
        test_tensors=test_tensors,
        quant_config=quant_config,
    )

    tokens = test_tensors.rank_tokens
    with with_dp_metadata(M=tokens.size(0), world_size=pgi.world_size):
        # Global expert id -> local expert index; -1 marks experts that are
        # not hosted on this rank.
        first = pgi.rank * local_experts
        expert_map = torch.full(
            (total_experts,), fill_value=-1, dtype=torch.int32
        )
        expert_map[first : first + local_experts] = torch.arange(
            local_experts, dtype=torch.int32
        )
        expert_map = expert_map.to(
            device=torch.cuda.current_device(), dtype=torch.int32
        )

        return kernel.forward(
            hidden_states=tokens,
            w1=w1,
            w2=w2,
            topk_weights=test_tensors.topk_weights,
            topk_ids=test_tensors.topk,
            inplace=False,
            activation="silu",
            global_num_experts=total_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=False,
        )
323
324


325
326
327
328
329
330
331
332
333
334
335
def triton_impl(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor,
    block_shape: list[int],
):
    """
    Reference implementation: run `fused_experts` with an fp8 w8a8 quant
    config and with DeepGEMM explicitly disabled.

    NOTE(review): the caller in this file passes `rank_token_scales` for
    `a1_scale`, which `TestTensors.make` sets to None -- the annotation is
    narrower than actual usage.
    """
    reference_quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        block_shape=block_shape,
    )

    reference_out = fused_experts(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=reference_quant_config,
        # Make sure this is set to False so we
        # don't end up comparing the same implementation.
        allow_deep_gemm=False,
    )
    return reference_out
355
356


357
def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    """
    Per-rank worker: compare the DeepEP + DeepGEMM output against the
    `triton_impl` reference for randomly generated tokens.

    The full weights are passed in; the reference runs on all experts while
    the DeepEP path receives only this rank's contiguous slice of experts.
    """
    current_platform.seed_everything(pgi.rank)

    device = torch.cuda.current_device()
    w1, w2, w1_scale, w2_scale = (
        t.to(device=device) for t in (w1, w2, w1_scale, w2_scale)
    )

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)
    # Recover the block shape from the weight / scale dims 1 and 2.
    block_shape = [w1.size(dim) // w1_scale.size(dim) for dim in (1, 2)]

    with set_current_vllm_config(VllmConfig()):
        # Reference
        reference_out = triton_impl(
            a=test_tensors.rank_tokens,
            topk_ids=test_tensors.topk,
            topk_weights=test_tensors.topk_weights,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=test_tensors.rank_token_scales,
            block_shape=block_shape,
        )

        # Slice experts for this rank.
        experts_per_rank = config.num_experts // pgi.world_size
        first_expert = experts_per_rank * pgi.rank
        local = slice(first_expert, first_expert + experts_per_rank)

        deepep_out = deepep_deepgemm_moe_impl(
            pg,
            pgi,
            dp_size,
            test_tensors,
            w1[local],
            w2[local],
            w1_scale[local],
            w2_scale[local],
        )

    torch.testing.assert_close(
        reference_out,
        deepep_out,
        atol=6e-2,
        rtol=6e-2,
    )


# (m, n, k) triples used to parametrize the high-throughput test below,
# which unpacks each entry as `m, n, k = mnk`.
MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]

431
432
433
# Parametrization values shared by both tests below: the `topk` values and
# the total expert count.
TOPKS = [2, 6]
NUM_EXPERTS = [32]

434
435

@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ht_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
):
    """
    Tests for High-Throughput DeepEP + DeepGemm integration.
    """
    m, n, k = mnk
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    # Square block whose side is the first entry of the alignment helper's
    # result.
    block_size = [get_mk_alignment_for_contiguous_layout()[0]] * 2

    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=False,
        use_fp8_dispatch=None,
    )

    weights = make_block_quant_fp8_weights(num_experts, n, k, block_size)

    # The launched worker receives: dp_size, config, w1, w2, w1_scale, w2_scale.
    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        *weights,
    )
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509


# (m, n, k) triples used to parametrize the low-latency test below; every
# entry uses k = 2560.
# NOTE: this rebinds the module-level name `MNKs`. The high-throughput test
# above is unaffected because its parametrize decorator was already evaluated
# with the earlier list.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# TODO: fix the tests for USE_FP8_DISPATCH=True; only False is exercised
# for now.
USE_FP8_DISPATCH = [False]


@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    use_fp8_dispatch: bool,
    block_size: list[int],
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
):
    """
    Tests for Low-Latency DeepEP + DeepGemm integration.
    """
    # This test requires that DeepGEMM's e8m0 mode is not in use.
    assert not is_deep_gemm_e8m0_used()

    m, n, k = mnk
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=True,
        use_fp8_dispatch=use_fp8_dispatch,
    )

    weights = make_block_quant_fp8_weights(num_experts, n, k, block_size)

    # The launched worker receives: dp_size, config, w1, w2, w1_scale, w2_scale.
    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        *weights,
    )