# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test DeepEP + DeepGEMM integration.

DeepGEMM provides GEMM kernels specialized for the
fp8 block-quantized case.
"""

import dataclasses

import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
17
from vllm.model_executor.layers.fused_moe.config import (
18
19
20
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
)
21
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
22
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
23
from vllm.platforms import current_platform
24
25
26
27
28
from vllm.utils.deep_gemm import (
    get_mk_alignment_for_contiguous_layout,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
29
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
30

31
from ...utils import multi_gpu_test
bnellnm's avatar
bnellnm committed
32
33
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
34

35
# DeepEP is optional: its prepare/finalize classes and the helpers built on
# top of them are only imported when has_deep_ep() reports it is available.
if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
        DeepEPHTPrepareAndFinalize,
    )
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
        DeepEPLLPrepareAndFinalize,
    )

    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
44

45
# DeepGEMM is optional as well: the expert implementations are only imported
# when has_deep_gemm() reports the package is available.
if has_deep_gemm():
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
50
51

# Skip marker: applied to tests that cannot run unless has_deep_ep() is true.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(), reason="Requires deep_ep kernels"
)

# Skip marker: applied to tests that cannot run unless
# is_deep_gemm_supported() is true.
requires_deep_gemm = pytest.mark.skipif(
    not is_deep_gemm_supported(), reason="Requires deep_gemm kernels"
)

# NOTE(review): P is not referenced anywhere in the visible part of this file.
P = ParamSpec("P")


64
65
def next_power_of_2(x):
    """
    Return the smallest power of two that is greater than or equal to x.

    Zero maps to 1. Fractional inputs are rounded up to the next integer
    before the power of two is computed.

    Raises:
        ValueError: if x is negative.
    """
    import math

    if x == 0:
        return 1
    if x < 0:
        raise ValueError("next_power_of_2 expects a non-negative value")
    # Integer bit arithmetic is exact for arbitrarily large inputs. Going
    # through math.log2 would round to a float first and can land on a power
    # of two that is smaller than x once x exceeds float precision.
    return 1 << (math.ceil(x) - 1).bit_length()
70
71


72
73
74
75
76
77
78
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build fp8 block-quantized expert weights through make_test_weights.

    e, n and k are forwarded unchanged to make_test_weights together with
    block_size as its block_shape.

    Returns the tuple (w1q, w2q, w1_scale, w2_scale).
    """
    w1_pack, w2_pack = make_test_weights(
        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_shape=block_size
    )
    # Each pack has four entries; only the second (quantized weight) and the
    # third (its scale) are used here.
    _, w1q, w1_scale, _ = w1_pack
    _, w2q, w2_scale, _ = w2_pack
    return w1q, w2q, w1_scale, w2_scale
85
86
87
88
89
90
91
92
93


@dataclasses.dataclass
class TestConfig:
    topk: int
    m: int
    k: int
    n: int
    num_experts: int
bnellnm's avatar
bnellnm committed
94
    per_act_token_quant: bool
95
    block_size: list[int]
96
97
    # configs for testing low-latency kernels
    low_latency: bool
98
    use_fp8_dispatch: bool | None = False
99
100
101
102
103


@dataclasses.dataclass
class TestTensors:
    """
    Per-rank input tensors for one DeepEP + DeepGEMM MoE test case.
    """

    # This is a helper, not a test class. Without this, pytest tries to
    # collect it because of the "Test" name prefix and emits a collection
    # warning.
    __test__ = False

    rank_tokens: torch.Tensor  # all ranks make this many tokens
    # Always None when built through make().
    rank_token_scales: torch.Tensor | None
    # Expert ids chosen for each token, shape (m, topk), int64.
    topk: torch.Tensor
    # Router weights, same shape as topk, float32.
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        """
        Create random inputs on the current CUDA device for `config`.

        `rank` is accepted but not used by this method; the values drawn
        depend only on the global RNG state at call time.
        """
        device = torch.cuda.current_device()
        fp8_info = torch.finfo(torch.float8_e4m3fn)

        # Scale the activations down and clamp them to the representable
        # range of float8_e4m3fn.
        rank_tokens = (
            torch.randn((config.m, config.k), device=device, dtype=torch.bfloat16)
            / 10.0
        )
        rank_tokens = rank_tokens.clamp(min=fp8_info.min, max=fp8_info.max)

        topk_ids = torch.randint(
            low=0,
            high=config.num_experts,
            size=(config.m, config.topk),
            device=device,
        ).to(dtype=torch.int64)

        topk_weights = torch.randn(
            topk_ids.shape, dtype=torch.float32, device=device
        )

        return TestTensors(
            rank_tokens=rank_tokens,
            rank_token_scales=None,
            topk=topk_ids,
            topk_weights=topk_weights,
            config=config,
        )
141
142


143
def make_ll_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    max_tokens_per_rank: int,
    dp_size: int,
    hidden_size: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Build a modular kernel that pairs the DeepEP low-latency
    prepare/finalize with BatchedDeepGemmExperts.

    Requires test_config.low_latency to be set and
    test_config.use_fp8_dispatch to be non-None.
    """
    assert test_config.low_latency
    assert test_config.use_fp8_dispatch is not None

    ll_args = DeepEPLLArgs(
        max_tokens_per_rank=max_tokens_per_rank,
        hidden_size=hidden_size,
        num_experts=test_config.num_experts,
        use_fp8_dispatch=test_config.use_fp8_dispatch,
    )
    prepare_finalize: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=None,
        deepep_ll_args=ll_args,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        # One dispatcher per data-parallel group of ranks.
        num_dispatchers=pgi.world_size // dp_size,
        quant_config=quant_config,
    )
    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=experts
    )


180
def make_ht_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    test_config: TestConfig,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Build a modular kernel that pairs the DeepEP high-throughput
    prepare/finalize with DeepGemmExperts.

    Requires test_config.low_latency to be unset and
    test_config.use_fp8_dispatch to be None.
    """
    assert not test_config.low_latency
    assert test_config.use_fp8_dispatch is None

    ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)
    prepare_finalize: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=ht_args,
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=test_config.block_size,
    )

    experts = DeepGemmExperts(quant_config)
    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=experts
    )


207
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    num_local_experts: int,
    test_tensors: TestTensors,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Build the low-latency or high-throughput modular kernel, chosen by
    test_tensors.config.low_latency. Both variants quantize to
    torch.float8_e4m3fn.
    """
    test_config = test_tensors.config
    q_dtype = torch.float8_e4m3fn

    if not test_config.low_latency:
        return make_ht_modular_kernel(
            pg=pg,
            pgi=pgi,
            dp_size=dp_size,
            num_local_experts=num_local_experts,
            q_dtype=q_dtype,
            test_config=test_config,
            quant_config=quant_config,
        )

    tokens = test_tensors.rank_tokens
    return make_ll_modular_kernel(
        pg=pg,
        pgi=pgi,
        # Token count rounded up to a power of two, with a floor of 64.
        max_tokens_per_rank=max(64, next_power_of_2(tokens.size(0))),
        dp_size=dp_size,
        hidden_size=tokens.size(-1),
        q_dtype=q_dtype,
        test_config=test_config,
        quant_config=quant_config,
    )


248
249
250
251
252
253
254
def deepep_deepgemm_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
) -> torch.Tensor:
    """
    Run the MoE forward pass through the DeepEP + DeepGEMM modular kernel.

    w1, w2 and their scales are this rank's expert shard; the number of
    local experts is taken from w1.size(0).

    Returns the output of FusedMoEModularKernel.forward for
    test_tensors.rank_tokens.
    """
    test_config = test_tensors.config
    num_experts = test_config.num_experts
    num_local_experts = w1.size(0)

    def build_expert_map() -> torch.Tensor:
        """Map global expert id -> local index, with -1 for remote experts."""
        expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
        # This rank owns the contiguous id range starting at
        # rank * num_local_experts.
        start = pgi.rank * num_local_experts
        expert_map[start : start + num_local_experts] = torch.arange(
            num_local_experts, dtype=torch.int32
        )
        return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        # Low-Latency kernels can't dispatch scales.
        a1_scale=(None if test_config.low_latency else test_tensors.rank_token_scales),
        block_shape=test_config.block_size,
    )

    # Make modular kernel
    mk: FusedMoEModularKernel = make_modular_kernel(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        num_local_experts=num_local_experts,
        test_tensors=test_tensors,
        quant_config=quant_config,
    )

    return mk.forward(
        hidden_states=test_tensors.rank_tokens,
        w1=w1,
        w2=w2,
        topk_weights=test_tensors.topk_weights,
        topk_ids=test_tensors.topk,
        inplace=False,
        activation="silu",
        global_num_experts=num_experts,
        expert_map=build_expert_map(),
        apply_router_weight_on_input=False,
    )
301
302


303
304
305
306
307
308
309
310
311
312
313
def triton_impl(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None,
    block_shape: list[int],
) -> torch.Tensor:
    """
    Compute the reference MoE output with fused_experts, with DeepGEMM
    explicitly disabled.

    a1_scale may be None; the caller in this file passes the (unset)
    activation scales of the generated test tensors.
    """
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        block_shape=block_shape,
    )

    return fused_experts(
        hidden_states=a,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        quant_config=quant_config,
        # Make sure this is set to False so we
        # don't end up comparing the same implementation.
        allow_deep_gemm=False,
    )
333
334


335
def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    """
    Per-rank worker: compare the DeepEP + DeepGEMM output for this rank's
    tokens against the fused_experts reference.

    The reference sees all experts; the DeepEP path only gets this rank's
    contiguous slice of the expert weights and scales.
    """
    # Seed with the rank so each rank draws different inputs.
    current_platform.seed_everything(pgi.rank)

    device = torch.cuda.current_device()
    w1, w2, w1_scale, w2_scale = (
        t.to(device=device) for t in (w1, w2, w1_scale, w2_scale)
    )

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)
    # Recover the quantization block shape from the weight/scale dimensions.
    block_shape = [w1.size(dim) // w1_scale.size(dim) for dim in (1, 2)]

    with set_current_vllm_config(VllmConfig()):
        # Reference
        triton_moe = triton_impl(
            a=test_tensors.rank_tokens,
            topk_ids=test_tensors.topk,
            topk_weights=test_tensors.topk_weights,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=test_tensors.rank_token_scales,
            block_shape=block_shape,
        )

        # Slice experts for this rank.
        num_local_experts = config.num_experts // pgi.world_size
        first_expert = num_local_experts * pgi.rank
        local = slice(first_expert, first_expert + num_local_experts)

        deepep_moe = deepep_deepgemm_moe_impl(
            pg,
            pgi,
            dp_size,
            test_tensors,
            w1[local],
            w2[local],
            w1_scale[local],
            w2_scale[local],
        )

    torch.testing.assert_close(triton_moe, deepep_moe, atol=6e-2, rtol=6e-2)


# (m, n, k) problem sizes for the high-throughput test below.
# NOTE: MNKs is rebound further down for the low-latency test; the
# parametrize decorator that follows captures this list before that happens.
MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]

# Parametrization values shared by the high-throughput and low-latency tests.
TOPKS = [2, 6]
NUM_EXPERTS = [32]

412
413

@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ht_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
):
    """
    Tests for High-Throughput DeepEP + DeepGemm integration.
    """
    m, n, k = mnk
    world_size, dp_size = world_dp_size
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    # Use a square quantization block sized by the first entry of DeepGEMM's
    # contiguous-layout alignment.
    alignment = get_mk_alignment_for_contiguous_layout()[0]
    block_size = [alignment, alignment]

    # The high-throughput kernel builder requires use_fp8_dispatch to be None.
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=False,
        use_fp8_dispatch=None,
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        num_experts, n, k, block_size
    )

    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487


# (m, n, k) problem sizes for the low-latency test below; every case uses
# k = 2560.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# TODO: fix the tests for USE_FP8_DISPATCH=True; only False is exercised.
USE_FP8_DISPATCH = [False]


@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    use_fp8_dispatch: bool,
    block_size: list[int],
    world_dp_size: tuple[int, int],
    disable_deepgemm_ue8m0,
):
    """
    Tests for Low-Latency DeepEP + DeepGemm integration.
    """
    # The disable_deepgemm_ue8m0 fixture is expected to have turned e8m0 off.
    assert not is_deep_gemm_e8m0_used()

    m, n, k = mnk
    world_size, dp_size = world_dp_size
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=True,
        use_fp8_dispatch=use_fp8_dispatch,
    )

    w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
        num_experts, n, k, block_size
    )

    parallel_launch(
        world_size,
        _test_deepep_deepgemm_moe,
        dp_size,
        config,
        w1,
        w2,
        w1_scale,
        w2_scale,
    )