test_deepep_deepgemm_moe.py 14.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration.

DeepGEMM provides GEMM kernels specialized for the
fp8 block-quantized case.
"""

import dataclasses
from typing import Optional

import pytest
import torch.distributed
from torch.distributed import ProcessGroup
from typing_extensions import ParamSpec

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.platforms import current_platform
21
from vllm.utils import has_deep_ep, has_deep_gemm
22

bnellnm's avatar
bnellnm committed
23
24
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
25

26
if has_deep_ep():
27
28
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
29
30
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)
31

bnellnm's avatar
bnellnm committed
32
    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
33

34
35
if has_deep_gemm():

36
37
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts)
38
39
40
41
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
        DeepGemmExperts)

# Skip marker: applied to tests that can only run when `has_deep_ep()`
# reports that the deep_ep kernels are available.
requires_deep_ep = pytest.mark.skipif(not has_deep_ep(),
                                      reason="Requires deep_ep kernels")

# Skip marker: applied to tests that can only run when `has_deep_gemm()`
# reports that the deep_gemm kernels are available.
requires_deep_gemm = pytest.mark.skipif(not has_deep_gemm(),
                                        reason="Requires deep_gemm kernels")

# Parameter specification variable for typing callables.
# NOTE(review): not referenced anywhere in the visible part of this file;
# confirm it is unused before removing it.
P = ParamSpec("P")


54
55
56
57
58
59
60
def next_power_of_2(x):
    """
    Return the smallest power of two that is greater than or equal to `x`.

    `0` maps to `1`. Integer inputs are handled with exact integer
    arithmetic: going through `math.log2` converts the value to a float
    first, which rounds large integers (e.g. `2**53 + 1`) and can produce a
    result that is smaller than `x`. Non-integer inputs keep the
    logarithm-based computation.

    Raises:
        ValueError: if `x` is negative.
    """
    if x == 0:
        return 1
    if isinstance(x, int):
        if x < 0:
            # Same exception type (and message) math.log2 raises here.
            raise ValueError("math domain error")
        # (x - 1).bit_length() is the exponent of the next power of two.
        return 1 << (x - 1).bit_length()

    import math
    return 2**math.ceil(math.log2(x))


61
62
63
64
65
66
67
def make_block_quant_fp8_weights(
    e: int,
    n: int,
    k: int,
    block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build test weights with `make_test_weights` and return only the
    quantized tensors and their scales, as (w1q, w2q, w1_scale, w2_scale).

    The weights are requested with torch.bfloat16 as the input dtype and
    torch.float8_e4m3fn as the quantized dtype.
    """
    weights = make_test_weights(e, n, k, torch.bfloat16, torch.float8_e4m3fn,
                                block_size)
    # The unquantized w1 / w2 (positions 0 and 3) are not needed here.
    _, w1q, w1_scale, _, w2q, w2_scale = weights
    return w1q, w2q, w1_scale, w2_scale
73
74
75
76
77
78
79
80
81


@dataclasses.dataclass
class TestConfig:
    """
    Parameters describing a single DeepEP + DeepGEMM MoE test case.
    """
    # The class name starts with "Test", so pytest would try to collect it
    # and warn because it has an __init__. Opt out of collection explicitly.
    # (Un-annotated, so this is a plain class attribute, not a dataclass
    # field.)
    __test__ = False

    # Number of experts selected per token (second dim of the topk ids).
    topk: int
    # Number of tokens each rank generates (rows of the token tensor).
    m: int
    # Size of the last dimension of the token tensor.
    k: int
    # Forwarded as `n` to the test-weight factory.
    n: int
    # Total number of experts; exclusive upper bound for random expert ids.
    num_experts: int
    # Forwarded to the batched DeepGEMM experts on the low-latency path.
    per_act_token_quant: bool
    # Quantization block shape, forwarded as `block_shape`.
    block_size: list[int]
    # configs for testing low-latency kernels
    # Selects the low-latency kernels (True) or the high-throughput ones.
    low_latency: bool
    # Must be None for high-throughput runs and non-None for low-latency
    # runs (both kernel factories assert this).
    use_fp8_dispatch: Optional[bool] = False
87
88
89
90
91
92
93
94
95
96
97
98
99
100


@dataclasses.dataclass
class TestTensors:
    """
    Randomly generated inputs that one rank feeds to the MoE kernels.
    """
    # The class name starts with "Test", so pytest would try to collect it
    # and warn because it has an __init__. Opt out of collection explicitly.
    # (Un-annotated, so this is a plain class attribute, not a dataclass
    # field.)
    __test__ = False

    rank_tokens: torch.Tensor  # all ranks make this many tokens
    # Always None when built through `make`.
    rank_token_scales: Optional[torch.Tensor]
    # Expert ids chosen for each token, shape (m, topk), int64.
    topk: torch.Tensor
    # Random float32 weights with the same shape as `topk`.
    topk_weights: torch.Tensor
    # The configuration these tensors were generated from.
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, rank) -> "TestTensors":
        """
        Generate random tokens, expert ids and routing weights for `config`
        on the current CUDA device.

        `rank` is accepted for the caller's convenience but is not used
        here; the draws depend on the global RNG state at call time.
        """
        device = torch.cuda.current_device()
        fp8_info = torch.finfo(torch.float8_e4m3fn)

        # Scale the bf16 activations down, then clamp them to the range
        # representable in float8_e4m3fn.
        rank_tokens = torch.randn(
            (config.m, config.k), device=device, dtype=torch.bfloat16) / 10.0
        rank_tokens = rank_tokens.clamp(min=fp8_info.min, max=fp8_info.max)

        topk_ids = torch.randint(low=0,
                                 high=config.num_experts,
                                 size=(config.m, config.topk),
                                 device=device).to(dtype=torch.int64)

        topk_weights = torch.randn(topk_ids.shape,
                                   dtype=torch.float32,
                                   device=device)

        return TestTensors(rank_tokens=rank_tokens,
                           rank_token_scales=None,
                           topk=topk_ids,
                           topk_weights=topk_weights,
                           config=config)


128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                           max_tokens_per_rank: int, dp_size: int,
                           hidden_size: int, q_dtype: Optional[torch.dtype],
                           test_config: TestConfig) -> FusedMoEModularKernel:
    """
    Build the modular kernel for the low-latency path: a DeepEP low-latency
    prepare/finalize object combined with `BatchedDeepGemmExperts`.

    `test_config` must describe a low-latency run with `use_fp8_dispatch`
    set (not None).
    """
    assert test_config.low_latency
    assert test_config.use_fp8_dispatch is not None

    ll_args = DeepEPLLArgs(max_tokens_per_rank=max_tokens_per_rank,
                           hidden_size=hidden_size,
                           num_experts=test_config.num_experts,
                           use_fp8_dispatch=test_config.use_fp8_dispatch)

    prepare_finalize: DeepEPLLPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=None,
        deepep_ll_args=ll_args,
        q_dtype=q_dtype,
        block_shape=test_config.block_size)

    experts = BatchedDeepGemmExperts(
        max_num_tokens=max_tokens_per_rank,
        world_size=pgi.world_size,
        dp_size=dp_size,
        block_shape=test_config.block_size,
        per_act_token_quant=test_config.per_act_token_quant)

    return FusedMoEModularKernel(prepare_finalize=prepare_finalize,
                                 fused_experts=experts)


def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
                           dp_size: int, num_local_experts: int,
                           q_dtype: Optional[torch.dtype],
                           test_config: TestConfig) -> FusedMoEModularKernel:
    """
    Build the modular kernel for the high-throughput path: a DeepEP
    high-throughput prepare/finalize object combined with `DeepGemmExperts`.

    `test_config` must describe a non-low-latency run with
    `use_fp8_dispatch` left as None.
    """
    assert not test_config.low_latency
    assert test_config.use_fp8_dispatch is None

    ht_args = DeepEPHTArgs(num_local_experts=num_local_experts)

    prepare_finalize: DeepEPHTPrepareAndFinalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        deepep_ht_args=ht_args,
        deepep_ll_args=None,
        q_dtype=q_dtype,
        block_shape=test_config.block_size)

    return FusedMoEModularKernel(prepare_finalize=prepare_finalize,
                                 fused_experts=DeepGemmExperts())


183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
                        num_local_experts: int,
                        test_tensors: TestTensors) -> FusedMoEModularKernel:
    """
    Create the modular kernel matching `test_tensors.config`: the
    low-latency variant when `low_latency` is set, otherwise the
    high-throughput variant. Both are built with float8_e4m3fn as the
    quantization dtype.
    """
    cfg = test_tensors.config
    fp8_dtype = torch.float8_e4m3fn

    if not cfg.low_latency:
        return make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
                                      fp8_dtype, cfg)

    tokens = test_tensors.rank_tokens
    # Token capacity per rank: the token count rounded up to a power of two,
    # but never below 64.
    token_capacity = max(64, next_power_of_2(tokens.size(0)))

    return make_ll_modular_kernel(pg=pg,
                                  pgi=pgi,
                                  max_tokens_per_rank=token_capacity,
                                  dp_size=dp_size,
                                  hidden_size=tokens.size(-1),
                                  q_dtype=fp8_dtype,
                                  test_config=cfg)


def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
                             dp_size: int, test_tensors: TestTensors,
                             w1: torch.Tensor, w2: torch.Tensor,
                             w1_scale: Optional[torch.Tensor],
                             w2_scale: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Run the DeepEP + DeepGEMM modular kernel on this rank's tokens and
    return its output.

    `w1` / `w2` (and their scales) are expected to already be sliced down to
    the experts owned by this rank; the number of local experts is taken
    from `w1.size(0)`.
    """
    cfg = test_tensors.config
    total_experts = cfg.num_experts
    local_experts = w1.size(0)

    # Make modular kernel
    kernel: FusedMoEModularKernel = make_modular_kernel(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        num_local_experts=local_experts,
        test_tensors=test_tensors)

    # Low-Latency kernels can't dispatch scales.
    if cfg.low_latency:
        a1_scale = None
    else:
        a1_scale = test_tensors.rank_token_scales

    # Global expert id -> local expert index. Experts that do not belong to
    # this rank stay at -1; this rank's contiguous range maps to
    # 0..local_experts-1.
    first = pgi.rank * local_experts
    expert_map = torch.full((total_experts, ),
                            fill_value=-1,
                            dtype=torch.int32)
    expert_map[first:first + local_experts] = torch.arange(local_experts,
                                                           dtype=torch.int32)
    expert_map = expert_map.to(device=torch.cuda.current_device(),
                               dtype=torch.int32)

    return kernel.forward(hidden_states=test_tensors.rank_tokens,
                          w1=w1,
                          w2=w2,
                          topk_weights=test_tensors.topk_weights,
                          topk_ids=test_tensors.topk,
                          inplace=False,
                          activation="silu",
                          global_num_experts=total_experts,
                          expert_map=expert_map,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          w1_zp=None,
                          w2_zp=None,
                          a1_scale=a1_scale,
                          a2_scale=None,
                          apply_router_weight_on_input=False)


def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
                topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                a1_scale: torch.Tensor, block_shape: list[int]):
    """
    Compute the reference MoE output by calling `fused_experts` in fp8 w8a8
    mode with DeepGEMM explicitly disabled.
    """
    # allow_deep_gemm must stay False: otherwise the reference could end up
    # running the very implementation it is being compared against.
    reference_out = fused_experts(hidden_states=a,
                                  w1=w1,
                                  w2=w2,
                                  topk_weights=topk_weights,
                                  topk_ids=topk_ids,
                                  inplace=False,
                                  use_fp8_w8a8=True,
                                  w1_scale=w1_scale,
                                  w2_scale=w2_scale,
                                  a1_scale=a1_scale,
                                  block_shape=block_shape,
                                  allow_deep_gemm=False)
    return reference_out


285
def _test_deepep_deepgemm_moe(
    pgi: ProcessGroupInfo,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
):
    """
    Per-rank worker: compare the DeepEP + DeepGEMM kernel against the
    reference implementation.

    The reference is computed with the full set of expert weights; the
    kernel under test only receives this rank's slice of the experts. The
    two outputs must agree within atol = rtol = 6e-2.
    """
    # Seed per rank so that each rank draws its own tokens / routing.
    current_platform.seed_everything(pgi.rank)

    device = torch.cuda.current_device()
    w1, w2 = w1.to(device=device), w2.to(device=device)
    w1_scale = w1_scale.to(device=device)
    w2_scale = w2_scale.to(device=device)

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, pgi.rank)

    # Recover the quantization block shape from the weight / scale sizes
    # along dims 1 and 2.
    block_shape = [w1.size(dim) // w1_scale.size(dim) for dim in (1, 2)]

    with set_current_vllm_config(VllmConfig()):
        # Reference
        expected = triton_impl(a=test_tensors.rank_tokens,
                               topk_ids=test_tensors.topk,
                               topk_weights=test_tensors.topk_weights,
                               w1=w1,
                               w2=w2,
                               w1_scale=w1_scale,
                               w2_scale=w2_scale,
                               a1_scale=test_tensors.rank_token_scales,
                               block_shape=block_shape)

        # Slice experts for this rank.
        experts_per_rank = config.num_experts // pgi.world_size
        first_expert = experts_per_rank * pgi.rank
        local = slice(first_expert, first_expert + experts_per_rank)

        actual = deepep_deepgemm_moe_impl(pg=pg,
                                          pgi=pgi,
                                          dp_size=dp_size,
                                          test_tensors=test_tensors,
                                          w1=w1[local],
                                          w2=w2[local],
                                          w1_scale=w1_scale[local],
                                          w2_scale=w2_scale[local])

    torch.testing.assert_close(expected, actual, atol=6e-2, rtol=6e-2)


# (m, n, k) problem sizes for the high-throughput test; each tuple is
# unpacked there as `m, n, k = mnk`.
MNKs = [
    (8, 128, 128),
    (8, 128, 512),
    (8, 512, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (129, 128, 256),
    (129, 1024, 2048),
    (222, 1024, 2048),
]

361
362
363
# Parametrization values for the `topk` and `num_experts` test arguments.
TOPKS = [2, 6]
NUM_EXPERTS = [32]

364
365

@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
                                topk: int, world_dp_size: tuple[int, int]):
    """
    Tests for High-Throughput DeepEP + DeepGemm integration.
    """
    # Imported lazily: deep_gemm is an optional dependency and this test is
    # skipped when it is unavailable.
    import deep_gemm

    m, n, k = mnk
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    # Use DeepGEMM's M alignment for both block dimensions.
    alignment = deep_gemm.get_m_alignment_for_contiguous_layout()
    block_size = [alignment, alignment]

    world_size, dp_size = world_dp_size

    # High-throughput runs require use_fp8_dispatch to be None.
    config = TestConfig(
        topk=topk,
        m=m,
        k=k,
        n=n,
        num_experts=num_experts,
        per_act_token_quant=False,
        block_size=block_size,
        low_latency=False,
        use_fp8_dispatch=None,
    )

    weights = make_block_quant_fp8_weights(num_experts, n, k, block_size)
    w1, w2, w1_scale, w2_scale = weights

    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
                    w2, w1_scale, w2_scale)


# (m, n, k) problem sizes for the low-latency test. This rebinds the name
# `MNKs`; the high-throughput test above is unaffected because its
# parametrize decorator was already evaluated with the earlier list.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# TODO: fix the tests for USE_FP8_DISPATCH=True; only False is exercised
# for now.
USE_FP8_DISPATCH = [False]


@pytest.mark.parametrize("mnk", MNKs)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOPKS)
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@pytest.mark.parametrize("block_size", [[128, 128]])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(
    mnk: tuple[int, int, int],
    num_experts: int,
    topk: int,
    use_fp8_dispatch: bool,
    block_size: list[int],
    world_dp_size: tuple[int, int],
):
    """
    Tests for Low-Latency DeepEP + DeepGemm integration.
    """
    m, n, k = mnk
    current_platform.seed_everything(7)

    if topk > num_experts:
        pytest.skip(f"Skipping test: topk={topk} > E={num_experts}")

    world_size, dp_size = world_dp_size

    config = TestConfig(topk=topk,
                        m=m,
                        k=k,
                        n=n,
                        num_experts=num_experts,
                        per_act_token_quant=False,
                        block_size=block_size,
                        low_latency=True,
                        use_fp8_dispatch=use_fp8_dispatch)

    weights = make_block_quant_fp8_weights(num_experts, n, k, block_size)
    w1, w2, w1_scale, w2_scale = weights

    parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1,
                    w2, w1_scale, w2_scale)