test_cutlass_moe.py 18.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
import dataclasses
5
from math import prod
6

7
8
9
import pytest
import torch

10
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
11
from tests.kernels.moe.utils import make_dummy_moe_config
12
13
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
14
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
15
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
16
17
18
from vllm.model_executor.layers.fused_moe.all2all_utils import (
    maybe_make_prepare_finalize,
)
19
from vllm.model_executor.layers.fused_moe.config import (
20
    FUSED_MOE_UNQUANTIZED_CONFIG,
21
    FusedMoEQuantConfig,
22
23
    fp8_w8a8_moe_quant_config,
)
24
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
25
    CutlassExpertsFp8,
26
27
28
    run_cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
29
from vllm.platforms import current_platform
30
from vllm.utils.torch_utils import set_random_seed
31
32
33
34

# Values swept by the parametrized tests below for the `e` (number of experts)
# and `topk` parameters respectively.
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]

35
36
37
38
# (m, n, k) problem sizes fed to the "m,n,k" parametrization of the tests
# below. The tensor factories in this file build activations of shape (m, k)
# and expert weights of shape (e, 2 * n, k) and (e, k, n) from them.
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (7, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (224, 1024, 1024),
    (224, 3072, 1024),
    (224, 3072, 1536),
    (32768, 1024, 1024),
    # These sizes trigger wrong answers.
    # (7232, 2048, 5120),
    # (40000, 2048, 5120),
]

52
# Module-wide config that the tests install through set_current_vllm_config();
# pipeline parallelism is fixed at 1.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
53

54
55
56
57
58
59
60
61
62
63
64
65

@dataclasses.dataclass
class MOETensors:
    """Random MoE inputs together with the per-expert stride vectors.

    `a` is the (m, k) activation matrix, `w1` is (e, 2 * n, k) and `w2` is
    (e, k, n). Each `*_strides*` field is an int64 vector of length `e`
    filled with a single dimension value (see `make_moe_tensors`).
    """

    a: torch.Tensor
    w1: torch.Tensor
    w2: torch.Tensor
    ab_strides1: torch.Tensor
    c_strides1: torch.Tensor
    ab_strides2: torch.Tensor
    c_strides2: torch.Tensor

    @staticmethod
    def make_moe_tensors(
        m: int, k: int, n: int, e: int, dtype: torch.dtype
    ) -> "MOETensors":
        """Build a randomly initialised `MOETensors` on the CUDA device.

        Args:
            m: Number of activation rows.
            k: Trailing dimension of `a` and `w1`; middle dimension of `w2`.
            n: `w1` has `2 * n` rows per expert, `w2` has `n` columns.
            e: Number of experts.
            dtype: Element type of `a`, `w1` and `w2`.

        Returns:
            A populated `MOETensors` instance.
        """

        def scaled_randn(*shape: int) -> torch.Tensor:
            # Standard-normal samples scaled down by 10.
            return torch.randn(shape, device="cuda", dtype=dtype) / 10

        def per_expert(value: int) -> torch.Tensor:
            # One int64 entry per expert, all equal to `value`.
            return torch.full((e,), value, device="cuda", dtype=torch.int64)

        # The draw order (a, w1, w2) is kept so that a fixed seed yields the
        # same tensors as before.
        activations = scaled_randn(m, k)
        first_weights = scaled_randn(e, 2 * n, k)
        second_weights = scaled_randn(e, k, n)

        return MOETensors(
            a=activations,
            w1=first_weights,
            w2=second_weights,
            ab_strides1=per_expert(k),
            c_strides1=per_expert(2 * n),
            ab_strides2=per_expert(n),
            c_strides2=per_expert(k),
        )
85
86
87
88
89


@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
    """`MOETensors` plus fp8-quantized and round-tripped (dequantized) copies."""

    # quantized
    a_q: torch.Tensor | None = None  # a -> a_q
    w1_q: torch.Tensor | None = None  # w1 -> w1_q
    w2_q: torch.Tensor | None = None  # w2 -> w2_q
    a_scale: torch.Tensor | None = None
    w1_scale: torch.Tensor | None = None
    w2_scale: torch.Tensor | None = None
    # dequantized
    a_d: torch.Tensor | None = None  # a -> a_q -> a_d
    w1_d: torch.Tensor | None = None  # w1 -> w1_q -> w1_d
    w2_d: torch.Tensor | None = None  # w2 -> w2_q -> w2_d

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool
    ) -> "MOETensors8Bit":
        """Create half-precision MoE tensors and their fp8 quantized forms.

        Args:
            m, k, n, e: Problem dimensions, forwarded to
                `MOETensors.make_moe_tensors`.
            per_act_token: Passed as `use_per_token_if_dynamic` when
                quantizing the activations.
            per_out_channel: Passed as `use_per_token_if_dynamic` when
                quantizing each expert's weights; also selects whether the
                weight scale buffers hold one row per weight row or one row.

        Returns:
            A `MOETensors8Bit` carrying the original, quantized and
            dequantized tensors plus the scales.
        """
        half = torch.half
        fp8 = torch.float8_e4m3fn

        base = MOETensors.make_moe_tensors(m, k, n, e, half)

        # a -> a_q
        a_q, a_scale = ops.scaled_fp8_quant(
            base.a, None, use_per_token_if_dynamic=per_act_token
        )
        # a_q -> a_d
        a_d = a_q.float().mul(a_scale).to(half)

        # Get the right scale shape for tests.
        w1_scale_rows = 2 * n if per_out_channel else 1
        w2_scale_rows = k if per_out_channel else 1

        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=fp8)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=fp8)
        w1_scale = torch.empty(
            (e, w1_scale_rows, 1), device="cuda", dtype=torch.float32
        )
        w2_scale = torch.empty(
            (e, w2_scale_rows, 1), device="cuda", dtype=torch.float32
        )
        w1_d = torch.empty_like(base.w1)
        w2_d = torch.empty_like(base.w2)

        def roundtrip(full, q_out, scale_out, d_out) -> None:
            # w -> w_q -> w_d, one expert at a time so the float32
            # temporaries stay the size of a single expert.
            for idx in range(e):
                q_out[idx], scale_out[idx] = ops.scaled_fp8_quant(
                    full[idx], use_per_token_if_dynamic=per_out_channel
                )
                d_out[idx] = (q_out[idx].float() * scale_out[idx]).half()

        roundtrip(base.w1, w1_q, w1_scale, w1_d)
        roundtrip(base.w2, w2_q, w2_scale, w2_d)

        return MOETensors8Bit(
            a=base.a,
            w1=base.w1,
            w2=base.w2,
            ab_strides1=base.ab_strides1,
            c_strides1=base.c_strides1,
            ab_strides2=base.ab_strides2,
            c_strides2=base.c_strides2,
            a_q=a_q,
            w1_q=w1_q,
            w2_q=w2_q,
            a_scale=a_scale,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a_d=a_d,
            w1_d=w1_d,
            w2_d=w2_d,
        )


def run_with_expert_maps(
    num_experts: int,
    num_local_experts: int,
    quant_config: FusedMoEQuantConfig,
    **cutlass_moe_kwargs,
):
    """Run the CUTLASS fp8 MoE kernel shard by shard and sum the outputs.

    The experts are split into consecutive groups of `num_local_experts`.
    For each group an `expert_map` is built (local index for experts in the
    group, -1 elsewhere), `w1`/`w2` and the weight scales are sliced to the
    group, and a fresh kernel is applied. The per-group outputs are added up.

    Args:
        num_experts: Total number of experts.
        num_local_experts: Experts per shard.
        quant_config: Quantization config holding the full-size weight scales.
        **cutlass_moe_kwargs: Keyword arguments for `kernel.apply`; must
            contain `hidden_states` and `w2`.

    Returns:
        Tensor shaped like `hidden_states` holding the summed shard outputs.
    """
    # Only the expert weights are sliced per shard; everything else is shared.
    full_weights = {
        name: cutlass_moe_kwargs[name]
        for name in ("w1", "w2")
        if name in cutlass_moe_kwargs
    }

    result = torch.zeros_like(cutlass_moe_kwargs["hidden_states"])

    for first in range(0, num_experts, num_local_experts):
        last = first + num_local_experts

        # make expert map
        local_ids = [-1] * num_experts
        local_ids[first:last] = list(range(num_local_experts))
        cutlass_moe_kwargs["expert_map"] = torch.tensor(
            local_ids, dtype=torch.int32, device="cuda"
        )

        # restrict the weights to this shard
        for name, weights in full_weights.items():
            cutlass_moe_kwargs[name] = weights[first:last]

        # per-shard quant config with matching scale slices
        shard_quant = copy.deepcopy(quant_config)
        shard_quant._w1.scale = quant_config.w1_scale[first:last]
        shard_quant._w2.scale = quant_config.w2_scale[first:last]

        shard_w2 = cutlass_moe_kwargs["w2"]
        hidden = cutlass_moe_kwargs["hidden_states"]
        moe_config = make_dummy_moe_config(
            num_experts=shard_w2.shape[0],
            hidden_dim=shard_w2.shape[1],
            intermediate_size_per_partition=shard_w2.shape[2],
            in_dtype=hidden.dtype,
        )
        prepare_finalize = maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=shard_quant,
            allow_new_interface=True,
            use_monolithic=False,
        )
        experts = CutlassExpertsFp8(
            moe_config=moe_config,
            quant_config=shard_quant,
        )
        kernel = mk.FusedMoEKernel(prepare_finalize, experts, inplace=False)

        result = result + kernel.apply(**cutlass_moe_kwargs)

    return result


224
225
226
227
228
229
def run_8_bit(
    moe_tensors: MOETensors8Bit,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
    num_local_experts: int | None = None,
) -> torch.Tensor:
    """Run the CUTLASS fp8 MoE kernel on pre-quantized weights.

    Args:
        moe_tensors: Tensors from `MOETensors8Bit.make_moe_tensors_8bit`; the
            quantized weights and all scales must be populated.
        topk_weights: Routing weights from `fused_topk`.
        topk_ids: Selected expert ids from `fused_topk`.
        per_act_token: Forwarded as `per_act_token_quant`.
        per_out_ch: Forwarded as `per_out_ch_quant`.
        num_local_experts: When given (any value, including the full expert
            count), the experts are processed in shards of this size through
            `run_with_expert_maps`. When None, a single kernel handles all
            experts with no expert map.

    Returns:
        The kernel output tensor.
    """
    assert all(
        t is not None
        for t in (
            moe_tensors.w1_q,
            moe_tensors.w2_q,
            moe_tensors.w1_scale,
            moe_tensors.w2_scale,
            moe_tensors.a_scale,
        )
    )

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=moe_tensors.w1_scale,
        w2_scale=moe_tensors.w2_scale,
        per_act_token_quant=per_act_token,
        per_out_ch_quant=per_out_ch,
        # Set to moe_tensors.a_scale iff static scales + per tensor.
        # This is not currently being tested.
        a1_scale=None,
    )

    kwargs = {
        "hidden_states": moe_tensors.a,
        "w1": moe_tensors.w1_q,  # type: ignore[union-attr]
        "w2": moe_tensors.w2_q,  # type: ignore[union-attr]
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "global_num_experts": moe_tensors.w1_q.shape[0],  # type: ignore[union-attr]
        "activation": MoEActivation.SILU,
        "expert_map": None,
        "apply_router_weight_on_input": False,
    }

    # The previous condition was
    #   `num_local_experts is not None or num_local_experts == num_experts`;
    # the second operand is only evaluated when the value is None and is then
    # always False, so the effective test is simply "a shard size was given".
    if num_local_experts is not None:
        num_experts = moe_tensors.w1.size(0)  # type: ignore[attr-defined]
        return run_with_expert_maps(
            num_experts,
            num_local_experts,
            quant_config,
            **kwargs,
        )

    moe_config = make_dummy_moe_config(
        num_experts=moe_tensors.w2_q.shape[0],  # type: ignore[union-attr]
        hidden_dim=moe_tensors.w2_q.shape[1],  # type: ignore[union-attr]
        intermediate_size_per_partition=moe_tensors.w2_q.shape[2],  # type: ignore[union-attr]
        in_dtype=moe_tensors.a.dtype,
    )
    kernel = mk.FusedMoEKernel(
        maybe_make_prepare_finalize(
            moe=moe_config,
            quant_config=quant_config,
            allow_new_interface=True,
            use_monolithic=False,
        ),
        CutlassExpertsFp8(
            moe_config=moe_config,
            quant_config=quant_config,
        ),
        inplace=False,
    )
    return kernel.apply(**kwargs)
298
299
300


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
    workspace_init,
    ep_size: int | None = None,
):
    """Compare the CUTLASS fp8 MoE output with `fused_experts` (eager mode).

    `ep_size` is not parametrized here; the EP tests pass it explicitly to
    run the expert-sharded path with `e // ep_size` experts per shard.
    """
    set_random_seed(7)
    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)

        router_scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids, _ = fused_topk(
            mt.a, router_scores, topk, renormalize=False
        )

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
        triton_output = fused_experts(
            mt.a_d,
            mt.w1_d,
            mt.w2_d,
            topk_weights,
            topk_ids,
            quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
        )

        number_local_experts = None
        if ep_size is not None:
            assert e % ep_size == 0, "Cannot distribute experts evenly"
            number_local_experts = e // ep_size

        cutlass_output = run_8_bit(
            mt, topk_weights, topk_ids, per_act_token, per_out_ch, number_local_experts
        )

        # Note 5.5 only needed for larger problem sizes, 5 works ok for
        # the rest.
        torch.testing.assert_close(
            triton_output, cutlass_output, atol=5.5e-2, rtol=1e-2
        )
353
354


355
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_cuda_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
    workspace_init,
):
    """Compare the CUTLASS fp8 MoE output with `fused_experts` when the
    CUTLASS path is captured into a CUDA graph and then replayed."""
    set_random_seed(7)
    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch)

        router_scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids, _ = fused_topk(
            mt.a, router_scores, topk, renormalize=False
        )

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
        triton_output = fused_experts(
            mt.a_d,
            mt.w1_d,
            mt.w2_d,
            topk_weights,
            topk_ids,
            quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
        )

        # Capture the CUTLASS run on a side stream, then replay it.
        capture_stream = torch.cuda.Stream()
        cuda_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(cuda_graph, stream=capture_stream):
            cutlass_output = run_8_bit(
                mt, topk_weights, topk_ids, per_act_token, per_out_ch
            )

        torch.accelerator.synchronize()
        cuda_graph.replay()
        torch.accelerator.synchronize()

        torch.testing.assert_close(triton_output, cutlass_output, atol=9e-2, rtol=1e-2)
405
406
407
408
409
410
411
412
413
414
415
416


@pytest.mark.parametrize("m", [64])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [16])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
    workspace_init,
):
    """Run the eager comparison test with the experts split into `ep_size`
    shards by delegating to `test_cutlass_moe_8_bit_no_graph`."""
    test_cutlass_moe_8_bit_no_graph(
        m=m,
        n=n,
        k=k,
        e=e,
        topk=topk,
        per_act_token=per_act_token,
        per_out_ch=per_out_channel,
        monkeypatch=monkeypatch,
        workspace_init=workspace_init,
        ep_size=ep_size,
    )
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460


# (m, n, k, topk) combinations consumed by the "m,n,k,topk" parametrization
# of test_cutlass_moe_8_bit_EP_large.
LARGE_MNK_FACTORS = [
    (1, 8192, 5120, 31),
    (32768, 1024, 1024, 16),
    (65536, 512, 1024, 16),
]


@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP_large(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
    workspace_init,
):
    """Run the expert-sharded eager comparison on the LARGE_MNK_FACTORS sizes
    by delegating to `test_cutlass_moe_8_bit_no_graph`."""
    test_cutlass_moe_8_bit_no_graph(
        m=m,
        n=n,
        k=k,
        e=e,
        topk=topk,
        per_act_token=per_act_token,
        per_out_ch=per_out_channel,
        monkeypatch=monkeypatch,
        workspace_init=workspace_init,
        ep_size=ep_size,
    )
489
490
491
492
493
494
495
496
497


@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_run_cutlass_moe_fp8(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    workspace_init,
):
    """Check that `run_cutlass_moe_fp8` does not depend on workspace contents.

    The function is invoked twice with identical inputs: once after filling
    `workspace13` with random values and once after zeroing it. Both outputs
    must agree within tolerance.
    """
    set_random_seed(7)
    with set_current_vllm_config(vllm_config):
        mt = MOETensors8Bit.make_moe_tensors_8bit(
            m, k, n, e, per_act_token, per_out_channel
        )

        router_scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids, _ = fused_topk(
            mt.a, router_scores, topk, renormalize=False
        )

        # we want to make sure there is at least one token that's generated in
        # this expert shard and at least one token that's NOT generated in this
        # expert shard
        topk_ids[0][0] = -1
        topk_ids[0][1] = 1

        out_dtype = mt.a.dtype
        output_shape = (m, k)

        # Flat scratch buffers sized from the (rows, cols) workspace shapes.
        workspace13 = torch.empty(
            prod((m * topk, max(2 * n, k))), device="cuda", dtype=out_dtype
        )
        workspace2 = torch.empty(
            prod((m * topk, max(n, k))), device="cuda", dtype=out_dtype
        )

        # Only the first `num_local_experts` experts are local; the rest map
        # to -1.
        num_local_experts = e // ep_size
        expert_map = torch.tensor(
            [*range(num_local_experts), *([-1] * (e - num_local_experts))],
            dtype=torch.int32,
            device="cuda",
        )

        def per_expert(value: int) -> torch.Tensor:
            # One int64 entry per expert, all equal to `value`.
            return torch.full((e,), value, device="cuda", dtype=torch.int64)

        ab_strides1 = per_expert(k)
        ab_strides2 = per_expert(n)
        c_strides1 = per_expert(2 * n)
        c_strides2 = per_expert(k)

        activation = MoEActivation.SILU
        a1q, a1q_scale = moe_kernel_quantize_input(
            mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
        )
        global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)

        def run_into(output):
            # Same positional argument list for both invocations; only the
            # destination tensor differs.
            return run_cutlass_moe_fp8(
                output,
                a1q,
                mt.w1_q,
                mt.w2_q,
                topk_ids,
                activation,
                global_num_experts,
                expert_map,
                mt.w1_scale,
                mt.w2_scale,
                a1q_scale,
                None,
                ab_strides1,
                ab_strides2,
                c_strides1,
                c_strides2,
                workspace13,
                workspace2,
                None,
                mt.a.dtype,
                per_act_token,
                per_out_channel,
                False,
                topk_weights,
            )

        # First run: workspace13 holds random values.
        workspace13.random_()
        output_random_workspace = torch.empty(
            output_shape, device="cuda", dtype=out_dtype
        )
        run_into(output_random_workspace)

        # Second run: workspace13 zeroed.
        workspace13.fill_(0)
        output_zero_workspace = torch.zeros(
            output_shape, device="cuda", dtype=out_dtype
        )
        run_into(output_zero_workspace)

        torch.testing.assert_close(
            output_random_workspace, output_zero_workspace, atol=5e-3, rtol=1e-3
        )