test_deepep_moe.py 15.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""
Test deepep dispatch-combine logic
"""

import dataclasses

import pytest
import torch.distributed
from torch.distributed import ProcessGroup

from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
17
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
18
19
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
20
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
21
22
    per_token_group_quant_fp8,
)
23
from vllm.platforms import current_platform
24
from vllm.utils.import_utils import has_deep_ep
25

26
from ...utils import multi_gpu_test
bnellnm's avatar
bnellnm committed
27
from .parallel_utils import ProcessGroupInfo, parallel_launch
28

29
if has_deep_ep():
30
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
31
32
        DeepEPHTPrepareAndFinalize,
    )
33
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (
34
35
        DeepEPLLPrepareAndFinalize,
    )
36

bnellnm's avatar
bnellnm committed
37
    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
38
39

# Marker that skips a test when has_deep_ep() reports the kernels as missing.
requires_deep_ep = pytest.mark.skipif(
    not has_deep_ep(), reason="Requires deep_ep kernels"
)

# Token budget per rank: passed to DeepEPLLArgs / BatchedTritonExperts and
# used as the chunk size in low-latency mode.
MAX_TOKENS_PER_RANK = 64


def make_weights(
    e: int, n: int, k: int, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """
    Return weights w1, w2, w1_scale, w2_scale

    w1 has shape (e, 2 * n, k) and w2 has shape (e, k, n); both live on
    "cuda".

    For float16 / bfloat16 the weights are returned unquantized and both
    scales are None. For float8_e4m3fn, float16 weights are generated and
    quantized expert by expert with ops.scaled_fp8_quant; the scales are
    float32 tensors of shape (e, 2 * n, 1) and (e, k, 1).

    Any other dtype trips the assertion below.
    """
    if dtype in [torch.float16, torch.bfloat16]:
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
        return w1, w2, None, None

    # per-out-channel weight quantization
    assert dtype == torch.float8_e4m3fn
    # Generate real random values here. Quantizing torch.empty buffers would
    # feed uninitialized memory (possibly NaN/Inf, and not governed by the
    # test's seed) into the quantizer and the reference comparison.
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.float16) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.float16) / 10

    n_b_scales = 2 * n
    k_b_scales = k
    # The quantized tensors and scales are pure outputs, so empty() is fine.
    w1_q = torch.empty_like(w1, dtype=dtype)
    w2_q = torch.empty_like(w2, dtype=dtype)
    w1_scale = torch.empty((e, n_b_scales, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((e, k_b_scales, 1), device="cuda", dtype=torch.float32)
    for expert in range(e):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
            w1[expert], use_per_token_if_dynamic=True
        )
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
            w2[expert], use_per_token_if_dynamic=True
        )
    return w1_q, w2_q, w1_scale, w2_scale


@dataclasses.dataclass
class TestConfig:
    """Problem description shared by the reference and DeepEP code paths."""

    # The "Test" prefix matches pytest's default class-collection pattern.
    # Opt out explicitly so pytest does not try to collect this dataclass
    # (and warn about its __init__). No annotation, so it is not a field.
    __test__ = False

    dtype: torch.dtype  # weight dtype
    topk: int  # experts selected per token
    m: int  # tokens generated on each rank
    k: int  # hidden size
    n: int  # w1 is (e, 2 * n, k), w2 is (e, k, n)
    num_experts: int  # global number of experts


@dataclasses.dataclass
class TestTensors:
    """Per-rank activations and routing data consumed by both MoE paths."""

    # The "Test" prefix matches pytest's default class-collection pattern.
    # Opt out explicitly so pytest does not try to collect this dataclass
    # (and warn about its __init__). No annotation, so it is not a field.
    __test__ = False

    rank_tokens: torch.Tensor  # all ranks make this many tokens
    rank_token_scales: torch.Tensor | None
    topk: torch.Tensor
    topk_weights: torch.Tensor
    config: TestConfig

    @staticmethod
    def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors":
        """
        Build random inputs on "cuda" for one rank.

        Tokens have shape (config.m, config.k) and are bfloat16 when the
        configured dtype is float8_e4m3fn. Expert ids are int64 values in
        [0, config.num_experts) with shape (config.m, config.topk); router
        weights are float32 with the same shape. No activation scales are
        produced (rank_token_scales is None).

        low_latency_mode is accepted for call-site symmetry but is currently
        not consulted.
        """
        # TODO (varun) - check that float16 works ?
        assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn]
        if config.dtype == torch.float8_e4m3fn:
            token_dtype = torch.bfloat16
        else:
            token_dtype = config.dtype
        rank_tokens = (
            torch.randn((config.m, config.k), device="cuda", dtype=token_dtype) / 10
        )
        rank_token_scales = None

        # int64 is requested directly instead of via a no-op .to() conversion.
        topk = torch.randint(
            low=0,
            high=config.num_experts,
            size=(config.m, config.topk),
            device="cuda",
            dtype=torch.int64,
        )
        topk_weights = torch.randn(topk.shape, dtype=torch.float32, device="cuda")
        return TestTensors(
            rank_tokens=rank_tokens,
            rank_token_scales=rank_token_scales,
            topk=topk,
            topk_weights=topk_weights,
            config=config,
        )
120
121


bnellnm's avatar
bnellnm committed
122
123
124
125
126
127
128
129
def make_modular_kernel(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    hidden_size: int,
    dp_size: int,
    num_experts: int,
    num_local_experts: int,
    q_dtype: torch.dtype | None,
    use_fp8_dispatch: bool,
    quant_config: FusedMoEQuantConfig,
) -> FusedMoEModularKernel:
    """
    Assemble a FusedMoEModularKernel from a DeepEP prepare/finalize object
    and a Triton experts implementation.

    In low-latency mode the all-to-all is configured through DeepEPLLArgs and
    paired with BatchedTritonExperts; otherwise DeepEPHTArgs and
    TritonExperts are used. FP8 dispatch is rejected outside low-latency
    mode, and per-act-token quantization is rejected inside it.
    """
    # Exactly one of the two argument bundles is populated.
    if not low_latency_mode:
        assert not use_fp8_dispatch, (
            "FP8 Dispatch is valid only for low-latency kernels"
        )
        high_throughput_args: DeepEPHTArgs | None = DeepEPHTArgs(
            num_local_experts=num_local_experts
        )
        low_latency_args: DeepEPLLArgs | None = None
    else:
        high_throughput_args = None
        low_latency_args = DeepEPLLArgs(
            max_tokens_per_rank=MAX_TOKENS_PER_RANK,
            hidden_size=hidden_size,
            num_experts=num_experts,
            use_fp8_dispatch=use_fp8_dispatch,
        )

    prepare_finalize = make_deepep_a2a(
        pg=pg,
        pgi=pgi,
        dp_size=dp_size,
        q_dtype=q_dtype,
        block_shape=None,
        deepep_ht_args=high_throughput_args,
        deepep_ll_args=low_latency_args,
    )

    num_dispatchers = pgi.world_size // dp_size

    if not low_latency_mode:
        experts = TritonExperts(quant_config=quant_config)
    else:
        assert not quant_config.per_act_token_quant, "not supported in ll mode"
        experts = BatchedTritonExperts(
            max_num_tokens=MAX_TOKENS_PER_RANK,
            num_dispatchers=num_dispatchers,
            quant_config=quant_config,
        )

    return FusedMoEModularKernel(
        prepare_finalize=prepare_finalize, fused_experts=experts
    )


bnellnm's avatar
bnellnm committed
176
177
178
179
180
181
182
183
def deep_ep_moe_impl(
    pg: ProcessGroup,
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    dp_size: int,
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    num_experts: int,
    use_fp8_dispatch: bool,
    per_act_token_quant: bool,
) -> torch.Tensor:
    """
    Run this rank's tokens through the DeepEP modular kernel.

    w1 / w2 (and their scales) hold only this rank's experts; w1.size(0) is
    taken as the number of local experts. Tokens are processed in chunks of
    MAX_TOKENS_PER_RANK in low-latency mode and as a single chunk otherwise.

    Returns a tensor shaped like test_tensors.rank_tokens holding the
    kernel output for every token.
    """
    num_local_experts = w1.size(0)

    # Global -> local expert index map: -1 for experts owned by other ranks,
    # 0..num_local_experts-1 for the contiguous range owned by this rank.
    # It does not depend on the chunk, so it is built once.
    expert_map = torch.full((num_experts,), fill_value=-1, dtype=torch.int32)
    first_local = pgi.rank * num_local_experts
    expert_map[first_local : first_local + num_local_experts] = torch.arange(
        num_local_experts, dtype=torch.int32
    )
    expert_map = expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32)

    hidden_size = test_tensors.rank_tokens.size(1)
    is_quantized = w1.dtype == torch.float8_e4m3fn
    q_dtype = torch.float8_e4m3fn if is_quantized else None

    out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
    total_num_tokens = test_tensors.rank_tokens.size(0)

    chunk_size = MAX_TOKENS_PER_RANK if low_latency_mode else total_num_tokens

    # range() only yields starts below total_num_tokens, so the only
    # clamping needed is on the end of the final (possibly short) chunk.
    for chunk_start in range(0, total_num_tokens, chunk_size):
        chunk_end = min(chunk_start + chunk_size, total_num_tokens)

        rank_tokens_chunk = test_tensors.rank_tokens[chunk_start:chunk_end]
        topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end]
        topk_chunk = test_tensors.topk[chunk_start:chunk_end]

        rank_token_scales_chunk = test_tensors.rank_token_scales
        if (
            rank_token_scales_chunk is not None
            and rank_token_scales_chunk.size(0) == total_num_tokens
        ):
            # per act token
            rank_token_scales_chunk = rank_token_scales_chunk[chunk_start:chunk_end]

        # The quant config carries the chunk's activation scales, so it (and
        # the kernel built from it) is recreated for every chunk.
        quant_config = FusedMoEQuantConfig.make(
            q_dtype,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token_quant,
            a1_scale=rank_token_scales_chunk,
        )

        # Make modular kernel
        mk: FusedMoEModularKernel = make_modular_kernel(
            pg,
            pgi,
            low_latency_mode,
            hidden_size,
            dp_size,
            num_experts,
            num_local_experts,
            q_dtype,
            use_fp8_dispatch,
            quant_config,
        )

        out = mk.forward(
            hidden_states=rank_tokens_chunk,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights_chunk,
            topk_ids=topk_chunk,
            inplace=False,
            activation="silu",
            global_num_experts=num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=False,
        )

        out_hidden_states[chunk_start:chunk_end, :].copy_(out, non_blocking=True)

    return out_hidden_states


bnellnm's avatar
bnellnm committed
277
278
279
280
def torch_moe_impl(
    test_tensors: TestTensors,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    using_fp8_dispatch: bool,
    per_act_token_quant: bool,
) -> torch.Tensor:
    """
    Reference MoE: a plain per-token, per-expert loop in PyTorch.

    w1 / w2 hold all experts and are indexed by the ids in
    test_tensors.topk. float8_e4m3fn weights are dequantized to float32 with
    their scales, the computation runs in float32, and the result is cast
    back to the activation dtype. When using_fp8_dispatch is set, the
    activations are first round-tripped through group-wise FP8 quantization.

    Returns a tensor with the same shape as test_tensors.rank_tokens.
    """
    a, topk_ids, topk_weights = (
        test_tensors.rank_tokens,
        test_tensors.topk,
        test_tensors.topk_weights,
    )
    if using_fp8_dispatch:
        # The DeepEP implementation is requested to dispatch using FP8.
        # For numerical stability for testing, emulate the fp8 dispatch by
        # blockwise quant and de-quant.
        assert not per_act_token_quant
        group_size = 128
        aq, aq_scale = per_token_group_quant_fp8(a, group_size)
        a = (
            (aq.view(-1, group_size).to(torch.float32) * aq_scale.view(-1, 1))
            .view(a.shape)
            .to(a.dtype)
        )

    is_quantized = w1.dtype == torch.float8_e4m3fn
    a_dtype = a.dtype
    if is_quantized:
        w1 = w1.to(dtype=torch.float32) * w1_scale
        w2 = w2.to(dtype=torch.float32) * w2_scale
        a = a.to(dtype=torch.float32)

    m, _ = a.shape
    topk = topk_ids.size(1)
    out = torch.zeros_like(a)

    # Built once; the original constructed a new module per (token, expert).
    silu_and_mul = SiluAndMul()

    for i in range(m):
        a_i = a[i]
        o_i = out[i]  # view into `out`, accumulated in place
        for j in range(topk):
            e = topk_ids[i][j]
            e_w = topk_weights[i][j]
            w1_e = w1[e]
            w2_e = w2[e]
            o_i += (
                silu_and_mul(a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)
            ) * e_w

    if is_quantized:
        out = out.to(dtype=a_dtype)

    return out


def _deep_ep_moe(
    pgi: ProcessGroupInfo,
    low_latency_mode: bool,
    dp_size: int,
    config: TestConfig,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor | None,
    w2_scale: torch.Tensor | None,
    use_fp8_dispatch: bool,
    per_act_token_quant: bool,
):
    """
    Per-rank worker: compare the DeepEP MoE output with the torch reference.

    The reference is computed with the full weight set; the DeepEP path gets
    only this rank's contiguous slice of experts. The two results must agree
    within atol = rtol = 6e-2.
    """
    assert low_latency_mode or not use_fp8_dispatch, (
        "FP8 dispatch interface is available only in low-latency mode"
    )

    device = torch.cuda.current_device()
    is_quantized = w1.dtype == torch.float8_e4m3fn
    w1 = w1.to(device=device)
    w2 = w2.to(device=device)
    if is_quantized:
        w1_scale = w1_scale.to(device=device)  # type: ignore
        w2_scale = w2_scale.to(device=device)  # type: ignore

    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    test_tensors = TestTensors.make(config, low_latency_mode)

    with set_current_vllm_config(VllmConfig()):
        # Reference
        torch_combined = torch_moe_impl(
            test_tensors=test_tensors,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            using_fp8_dispatch=use_fp8_dispatch,
            per_act_token_quant=per_act_token_quant,
        )

        # Splice experts for this rank.
        num_local_experts = config.num_experts // pgi.world_size
        e_start = num_local_experts * pgi.rank
        local_experts = slice(e_start, e_start + num_local_experts)

        if is_quantized:
            w1_scale_ep = w1_scale[local_experts]  # type: ignore
            w2_scale_ep = w2_scale[local_experts]  # type: ignore
        else:
            w1_scale_ep, w2_scale_ep = None, None

        deepep_combined = deep_ep_moe_impl(
            pg=pg,
            pgi=pgi,
            low_latency_mode=low_latency_mode,
            dp_size=dp_size,
            test_tensors=test_tensors,
            w1=w1[local_experts],
            w2=w2[local_experts],
            w1_scale=w1_scale_ep,
            w2_scale=w2_scale_ep,
            num_experts=config.num_experts,
            use_fp8_dispatch=use_fp8_dispatch,
            per_act_token_quant=per_act_token_quant,
        )

    torch.testing.assert_close(
        torch_combined,
        deepep_combined,
        atol=6e-2,
        rtol=6e-2,
    )


# (m, n, k) shapes fed to the "m,n,k" parametrization of test_deep_ep_moe.
# m is the per-rank token count, k the hidden size, and n sizes the expert
# weights (w1 is (e, 2 * n, k), w2 is (e, k, n) in make_weights).
# NOTE: MNKs and DTYPES are rebound further down for the low-latency test;
# the decorators on test_deep_ep_moe read the values below at definition time.
MNKs = [
    (1, 128, 128),
    (2, 128, 512),
    (3, 1024, 2048),
    (32, 128, 1024),
    (45, 512, 2048),
    (64, 1024, 1024),
    (222, 1024, 2048),
]

# Weight dtypes covered by test_deep_ep_moe.
DTYPES = [torch.bfloat16, torch.float8_e4m3fn]


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
def test_deep_ep_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    per_act_token_quant: bool,
):
    """Check the DeepEP MoE path (low-latency mode off) against the reference."""
    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)

    # (w1, w2, w1_scale, w2_scale), forwarded to every rank unchanged.
    weights = make_weights(num_experts, n, k, dtype)

    parallel_launch(
        world_size,
        _deep_ep_moe,
        False,  # low_latency_mode
        dp_size,
        config,
        *weights,
        False,  # use_fp8_dispatch
        per_act_token_quant,
    )
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478


# (m, n, k) shapes for test_low_latency_deep_ep_moe. Every entry uses
# k = 2560; the test skips itself when k is not listed in
# DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES.
# NOTE: this rebinds the module-level MNKs / DTYPES defined earlier in the
# file. Decorators evaluated above this point already captured the old lists.
MNKs = [
    (1, 128, 2560),
    (2, 128, 2560),
    (3, 1024, 2560),
    (32, 128, 2560),
    (45, 512, 2560),
    (64, 1024, 2560),
    (222, 1024, 2560),
]
# Weight dtypes covered by the low-latency test.
DTYPES = [torch.float8_e4m3fn, torch.bfloat16]
# Run the low-latency test both with and without FP8 dispatch.
USE_FP8_DISPATCH = [True, False]


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("m,n,k", MNKs)
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
@multi_gpu_test(num_gpus=2)
@requires_deep_ep
def test_low_latency_deep_ep_moe(
    dtype: torch.dtype,
    m: int,
    n: int,
    k: int,
    num_experts: int,
    topk: int,
    world_dp_size: tuple[int, int],
    use_fp8_dispatch: bool,
):
    """Check the DeepEP MoE path in low-latency mode against the reference."""
    low_latency_mode = True

    # Skip hidden sizes that the low-latency prepare/finalize does not list.
    hidden_size_supported = k in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES
    if low_latency_mode and not hidden_size_supported:
        pytest.skip(
            f"Skipping test as hidden size {k} is not in list of supported "
            f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}"
        )

    current_platform.seed_everything(7)
    world_size, dp_size = world_dp_size
    config = TestConfig(dtype=dtype, topk=topk, m=m, k=k, n=n, num_experts=num_experts)

    # (w1, w2, w1_scale, w2_scale), forwarded to every rank unchanged.
    weights = make_weights(num_experts, n, k, dtype)

    parallel_launch(
        world_size,
        _deep_ep_moe,
        low_latency_mode,
        dp_size,
        config,
        *weights,
        use_fp8_dispatch,
        False,  # per_act_token_quant
    )