# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test DeepEP + DeepGEMM integration.

DeepGEMM provides GEMM kernels specialized for the
fp8 block-quantized case.
"""

import dataclasses
10
from contextlib import contextmanager
11
12
13
14
15
16
17

import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
18
from vllm.forward_context import set_forward_context
19
from vllm.model_executor.layers.fused_moe.config import (
20
21
22
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
23
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
24
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
25
26
27
28
29
from vllm.utils.deep_gemm import (
    get_mk_alignment_for_contiguous_layout,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
30
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
31
from vllm.utils.torch_utils import set_random_seed
32
from vllm.v1.worker.workspace import init_workspace_manager
33

34
from ...utils import multi_gpu_test
bnellnm's avatar
bnellnm committed
35
36
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
37

38
if has_deep_ep():
39
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
40
41
        DeepEPHTPrepareAndFinalize,
    )
42
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
43
44
        DeepEPLLPrepareAndFinalize,
    )
45

bnellnm's avatar
bnellnm committed
46
    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
47

48
if has_deep_gemm():
49
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
50
51
52
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
53
54

# Skip markers for tests that need the optional DeepEP / DeepGEMM kernels.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(), reason="Requires deep_ep kernels"
)
requires_deep_gemm = pytest.mark.skipif(
    not is_deep_gemm_supported(), reason="Requires deep_gemm kernels"
)

# Module-level ParamSpec; not referenced by the code visible in this file.
P = ParamSpec("P")


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@contextmanager
def with_dp_metadata(M: int, world_size: int):
    """
    Enter a vLLM forward context for ``world_size`` data-parallel ranks
    that each contribute ``M`` tokens, with expert parallelism enabled.
    """
    tokens_per_rank = torch.full((world_size,), M, device="cpu", dtype=torch.int)

    config = VllmConfig()
    parallel_config = config.parallel_config
    parallel_config.data_parallel_size = world_size
    parallel_config.enable_expert_parallel = True

    with set_forward_context(
        None, config, num_tokens=M, num_tokens_across_dp=tokens_per_rank
    ):
        yield


84
85
def next_power_of_2(x):
    """
    Return the smallest power of two that is >= ``x`` (``0`` maps to ``1``).

    Uses exact integer arithmetic via ``int.bit_length``. The previous
    ``math.ceil(math.log2(x))`` formulation rounds in floating point and can
    return a power of two *smaller* than ``x`` for large ints
    (e.g. ``2**60 + 1``).

    Raises:
        ValueError: if ``x`` is negative.
    """
    if x < 0:
        raise ValueError(f"next_power_of_2 expects a non-negative int, got {x}")
    if x == 0:
        return 1
    # (x - 1).bit_length() is the exponent of the smallest power of two >= x.
    return 1 << (x - 1).bit_length()
90
91


92
93
94
95
96
97
98
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build block-quantized fp8 MoE weights for ``e`` experts.

    Returns:
        ``(w1q, w2q, w1_scale, w2_scale)``: the quantized w1/w2 tensors and
        their block scales, taken from ``make_test_weights``.
    """
    w1_parts, w2_parts = make_test_weights(
        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
    )
    # Each part is a 4-tuple; only the quantized weight (index 1) and its
    # scale (index 2) are needed here.
    _, w1q, w1_scale, _ = w1_parts
    _, w2q, w2_scale, _ = w2_parts
    return w1q, w2q, w1_scale, w2_scale
105
106
107
108
109
110
111
112
113


@dataclasses.dataclass
class TestConfig:
    topk: int
    m: int
    k: int
    n: int
    num_experts: int
bnellnm's avatar
bnellnm committed
114
    per_act_token_quant: bool
115
    block_size: list[int]
116
117
    # configs for testing low-latency kernels
    low_latency: bool
118
    use_fp8_dispatch: bool | None = False
119
120
121
122
123


@dataclasses.dataclass
class TestTensors:
    """Per-rank MoE inputs: tokens, optional token scales, routing ids/weights."""

    # The class name matches pytest's ``Test*`` collection pattern, but as a
    # dataclass it has an ``__init__``; opt out of collection to avoid a
    # PytestCollectionWarning. Unannotated, so not a dataclass field.
    __test__ = False

    rank_tokens: torch.Tensor  # all ranks make this many tokens
    rank_token_scales: torch.Tensor | None
    topk: torch.Tensor
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        """
        Build random inputs for ``config`` on the current CUDA device.

        ``rank`` is not used here; the worker in this file seeds the RNG per
        rank (``set_random_seed(pgi.rank)``) before calling this.
        """
        dtype = torch.bfloat16
        topk, m, k = (config.topk, config.m, config.k)
        device = torch.cuda.current_device()

        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        # Small-magnitude bf16 activations, clamped to the fp8 e4m3 range.
        rank_tokens = torch.randn((m, k), device=device, dtype=dtype) / 10.0
        rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
        # No activation scales are supplied.
        rank_token_scales = None

        # (m, topk) expert ids drawn uniformly from all global experts.
        topk_ids = torch.randint(
            low=0,
            high=config.num_experts,
            size=(m, topk),
            device=device,
        ).to(dtype=torch.int64)

        topk_weights = torch.randn(topk_ids.shape, dtype=torch.float32, device=device)

        return TestTensors(
            rank_tokens=rank_tokens,
            rank_token_scales=rank_token_scales,
            topk=topk_ids,
            topk_weights=topk_weights,
            config=config,
        )
161
162


163
def make_ll_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    max_tokens_per_rank: int,
    dp_size: int,
    hidden_size: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Assemble a low-latency DeepEP prepare/finalize stage with
    ``BatchedDeepGemmExperts`` into a modular kernel.

    Requires ``test_config.low_latency`` to be true and
    ``test_config.use_fp8_dispatch`` to be non-None.
    """
    assert test_config.low_latency
    assert test_config.use_fp8_dispatch is not None

    ll_args = DeepEPLLArgs(
        max_tokens_per_rank=max_tokens_per_rank,
        hidden_size=hidden_size,
        num_experts=test_config.num_experts,
        use_fp8_dispatch=test_config.use_fp8_dispatch,
    )
    prepare_finalize: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=None,
        deepep_ll_args=ll_args,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        num_dispatchers=pgi.world_size // dp_size,
        quant_config=quant_config,
    )
    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=experts
    )


200
def make_ht_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Assemble a high-throughput DeepEP prepare/finalize stage with
    ``DeepGemmExperts`` into a modular kernel.

    Requires ``test_config.low_latency`` to be false and
    ``test_config.use_fp8_dispatch`` to be None.
    """
    assert not test_config.low_latency
    assert test_config.use_fp8_dispatch is None

    ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
    prepare_finalize: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=ht_args,
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )
    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize,
        fused_experts=DeepGemmExperts(quant_config),
    )


227
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    test_tensors: TestTensors,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Build the DeepEP + DeepGEMM modular kernel for ``test_tensors``.

    Dispatches to the low-latency or high-throughput builder according to
    ``test_tensors.config.low_latency``; activations are quantized to fp8
    e4m3 in both cases.
    """
    test_config = test_tensors.config
    q_dtype = torch.float8_e4m3fn

    if not test_config.low_latency:
        return make_ht_modular_kernel(
            pg=pg,
            pgi=pgi,
            dp_size=dp_size,
            num_local_experts=num_local_experts,
            q_dtype=q_dtype,
            test_config=test_config,
            quant_config=quant_config,
        )

    tokens = test_tensors.rank_tokens
    return make_ll_modular_kernel(
        pg=pg,
        pgi=pgi,
        # Per-rank token capacity: rounded up to a power of two, at least 64.
        max_tokens_per_rank=max(64, next_power_of_2(tokens.size(0))),
        dp_size=dp_size,
        hidden_size=tokens.size(-1),
        q_dtype=q_dtype,
        test_config=test_config,
        quant_config=quant_config,
    )


268
269
270
271
272
273
274
def deepep_deepgemm_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
) -> torch.Tensor:
    """
    Run this rank's tokens through the DeepEP + DeepGEMM modular kernel.

    ``w1``/``w2`` (and their scales) hold only this rank's local experts.
    Rank ``r`` owns the contiguous global expert range
    ``[r * E_local, (r + 1) * E_local)``, where ``E_local = w1.size(0)``.

    Returns:
        The MoE output for ``test_tensors.rank_tokens``.
    """
    test_config = test_tensors.config
    num_experts = test_config.num_experts
    # w1 only contains this rank's experts, so its leading dim is the local count.
    num_local_experts = w1.size(0)

    def build_expert_map() -> torch.Tensor:
        # Global -> local expert id: -1 for experts owned by other ranks,
        # 0..num_local_experts-1 for this rank's contiguous slice.
        expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
        start = pgi.rank * num_local_experts
        end = start + num_local_experts
        expert_map[start:end] = torch.arange(num_local_experts, dtype=torch.int32)
        return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        # Low-Latency kernels can't dispatch scales.
        a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
        block_shape=test_config.block_size,
    )

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        num_local_experts=num_local_experts,
        test_tensors=test_tensors,
        quant_config=quant_config,
    )

    with with_dp_metadata(
        M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
    ):
        out = mk.forward(
            hidden_states=test_tensors.rank_tokens,
            w1=w1,
            w2=w2,
            topk_weights=test_tensors.topk_weights,
            topk_ids=test_tensors.topk,
            inplace=False,
            activation="silu",
            global_num_experts=num_experts,
            expert_map=build_expert_map(),
            apply_router_weight_on_input=False,
        )
    return out
324
325


326
327
328
329
330
331
332
333
334
335
336
def triton_impl(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor,
    block_shape: list[int],
):
    """
    Reference MoE forward through the Triton ``fused_experts`` path, using
    block-quantized fp8 weights and all experts held locally.
    """
    return fused_experts(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            block_shape=block_shape,
        ),
        # Keep DeepGEMM off here so the reference is not the same
        # implementation as the one under test.
        allow_deep_gemm=False,
    )
356
357


358
def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    """
    Per-rank worker: compare the DeepEP + DeepGEMM MoE output against the
    Triton ``fused_experts`` reference for this rank's tokens.

    ``w1``/``w2`` and their scales hold all ``config.num_experts`` experts.
    The reference uses all of them, while the DeepEP path receives only this
    rank's contiguous expert slice.
    """
    device = torch.device(f"cuda:{pgi.local_rank}")
    init_workspace_manager(device)

    # Seed per rank so each rank generates different tokens.
    set_random_seed(pgi.rank)

    cur_device = torch.cuda.current_device()
    w1 = w1.to(device=cur_device)
    w2 = w2.to(device=cur_device)
    w1_scale = w1_scale.to(device=cur_device)
    w2_scale = w2_scale.to(device=cur_device)

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)
    # Use the configured block size: it is what the weights were quantized
    # with and what the DeepEP path uses. Re-deriving it by floor-dividing
    # weight dims by scale dims only recovers it when the dims divide evenly.
    block_shape = config.block_size

    with set_current_vllm_config(VllmConfig()):
        # Reference
        triton_moe = triton_impl(
            a=test_tensors.rank_tokens,
            topk_ids=test_tensors.topk,
            topk_weights=test_tensors.topk_weights,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=test_tensors.rank_token_scales,
            block_shape=block_shape,
        )

        # Slice experts for this rank.
        num_local_experts = config.num_experts // pgi.world_size
        e_start = num_local_experts * pgi.rank
        e_end = e_start + num_local_experts
        w1_ep = w1[e_start:e_end]
        w2_ep = w2[e_start:e_end]
        w1_scale_ep = w1_scale[e_start:e_end]
        w2_scale_ep = w2_scale[e_start:e_end]

        deepep_moe = deepep_deepgemm_moe_impl(
            pg,
            pgi,
            dp_size,
            test_tensors,
            w1_ep,
            w2_ep,
            w1_scale_ep,
            w2_scale_ep,
        )

    torch.testing.assert_close(
        triton_moe,
        deepep_moe,
        atol=6e-2,
        rtol=6e-2,
    )


# (m, n, k) problem shapes for the high-throughput test, unpacked as
# ``m, n, k = mnk``: m is the token count per rank, k the hidden size, and n
# is passed to make_block_quant_fp8_weights.
MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]

435
436
437
# Experts selected per token, and global expert counts; used by both tests.
TOPKS = [2, 6]
NUM_EXPERTS = [32]

438
439

@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ht_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
    workspace_init,
):
    """
    Tests for High-Throughput DeepEP + DeepGemm integration.

    ``disable_deepgemm_ue8m0`` and ``workspace_init`` are pytest fixtures
    that are not referenced directly in the body.
    """
    m, n, k = mnk
    set_random_seed(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    # Square weight blocks sized to DeepGEMM's contiguous-layout alignment.
    block_m = get_mk_alignment_for_contiguous_layout()[0]
    block_size = [block_m] * 2

    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=False,
        use_fp8_dispatch=None,
    )

    # (w1, w2, w1_scale, w2_scale), forwarded positionally to the worker.
    weights = make_block_quant_fp8_weights(num_experts, n, k, block_size)

    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, *weights)
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514


# (m, n, k) shapes for the low-latency test. This list has its own name so
# it does not rebind ``MNKs``, which the high-throughput test is
# parametrized over.
LL_MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# TODO: fix the USE_FP8_DISPATCH=True cases and add True back to this list.
USE_FP8_DISPATCH = [False]


@pytest.mark.parametrize("mnk", LL_MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    use_fp8_dispatch: bool,
    block_size: list[int],
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
    workspace_init,
):
    """
    Tests for Low-Latency DeepEP + DeepGemm integration.

    ``disable_deepgemm_ue8m0`` and ``workspace_init`` are pytest fixtures
    that are not referenced directly in the body.
    """
    assert not is_deep_gemm_e8m0_used()

    m, n, k = mnk
    set_random_seed(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    world_size, dp_size = world_dp_size
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=True,
        use_fp8_dispatch=use_fp8_dispatch,
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        num_experts, n, k, block_size
    )

    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )