# test_cutlass_moe.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
import dataclasses
5
from math import prod
6

7
8
9
import pytest
import torch

10
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
11
12
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
13
from vllm.model_executor.layers.fused_moe.config import (
14
    FUSED_MOE_UNQUANTIZED_CONFIG,
15
    FusedMoEQuantConfig,
16
17
    fp8_w8a8_moe_quant_config,
)
18
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
19
    CutlassExpertsFp8,
20
21
22
    run_cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
23
24
25
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP,
)
26
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
27
from vllm.platforms import current_platform
28
from vllm.utils.torch_utils import set_random_seed
29
30
31
32

# Expert counts swept by the parametrized tests.
NUM_EXPERTS: list[int] = [40, 64]
# Number of experts selected per token by fused_topk.
TOP_KS: list[int] = [6, 8]

# (m, n, k) problem sizes: m tokens, intermediate size n, hidden size k
# (see the a / w1 / w2 shapes built in MOETensors.make_moe_tensors).
MNK_FACTORS: list[tuple[int, int, int]] = [
    (2, 1024, 1024),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (7, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (224, 1024, 1024),
    (224, 3072, 1024),
    (224, 3072, 1536),
    (32768, 1024, 1024),
    # Known-bad sizes (they produce wrong answers); kept disabled on purpose.
    # (7232, 2048, 5120),
    # (40000, 2048, 5120),
]

# Config installed with set_current_vllm_config() by every test below;
# pipeline parallelism is fixed at a single stage.
vllm_config = VllmConfig(
    parallel_config=ParallelConfig(pipeline_parallel_size=1),
)


@dataclasses.dataclass
class MOETensors:
    a: torch.Tensor
    w1: torch.Tensor
    w2: torch.Tensor
    ab_strides1: torch.Tensor
    c_strides1: torch.Tensor
    ab_strides2: torch.Tensor
    c_strides2: torch.Tensor

    @staticmethod
64
65
66
    def make_moe_tensors(
        m: int, k: int, n: int, e: int, dtype: torch.dtype
    ) -> "MOETensors":
67
68
69
        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
70
71
72
73
74
75
76
77
78
79
80
81
82
        ab_strides1 = torch.full((e,), k, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e,), n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
        return MOETensors(
            a=a,
            w1=w1,
            w2=w2,
            ab_strides1=ab_strides1,
            c_strides1=c_strides1,
            ab_strides2=ab_strides2,
            c_strides2=c_strides2,
        )
83
84
85
86
87


@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
    """``MOETensors`` extended with FP8 (``float8_e4m3fn``) data.

    For each of ``a``, ``w1`` and ``w2`` the extra fields hold the quantized
    tensor, its scale, and the dequantized round trip (quantized values
    multiplied back by their scale). All extra fields default to ``None``.
    """

    # Quantized tensors (a -> a_q, w1 -> w1_q, w2 -> w2_q) and their scales.
    a_q: torch.Tensor | None = None
    w1_q: torch.Tensor | None = None
    w2_q: torch.Tensor | None = None
    a_scale: torch.Tensor | None = None
    w1_scale: torch.Tensor | None = None
    w2_scale: torch.Tensor | None = None
    # Dequantized round trips (x -> x_q -> x_d).
    a_d: torch.Tensor | None = None
    w1_d: torch.Tensor | None = None
    w2_d: torch.Tensor | None = None

    @staticmethod
    def make_moe_tensors_8bit(
        m: int, k: int, n: int, e: int, per_act_token: bool, per_out_channel: bool
    ) -> "MOETensors8Bit":
        """Build fp16 MoE tensors on CUDA plus their FP8-quantized versions.

        Args:
            m, k, n, e: as for ``MOETensors.make_moe_tensors``.
            per_act_token: passed as ``use_per_token_if_dynamic`` when
                quantizing ``a``.
            per_out_channel: passed as ``use_per_token_if_dynamic`` when
                quantizing each expert of ``w1`` / ``w2``; it also sizes the
                weight scales as ``(e, 2 * n, 1)`` / ``(e, k, 1)`` rather than
                ``(e, 1, 1)``.
        """
        half = torch.half
        fp8 = torch.float8_e4m3fn

        base = MOETensors.make_moe_tensors(m, k, n, e, half)

        a_q, a_scale = ops.scaled_fp8_quant(
            base.a, None, use_per_token_if_dynamic=per_act_token
        )
        a_d = a_q.float().mul(a_scale).to(half)

        w1_scale_rows = 2 * n if per_out_channel else 1
        w2_scale_rows = k if per_out_channel else 1
        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=fp8)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=fp8)
        w1_scale = torch.empty(
            (e, w1_scale_rows, 1), device="cuda", dtype=torch.float32
        )
        w2_scale = torch.empty(
            (e, w2_scale_rows, 1), device="cuda", dtype=torch.float32
        )
        w1_d = torch.empty_like(base.w1)
        w2_d = torch.empty_like(base.w2)

        # Quantize and dequantize one expert at a time: each expert's round
        # trip only needs its own q/scale slices, and working per expert keeps
        # the fp32 temporaries to a single expert's size.
        for idx in range(e):
            w1_q[idx], w1_scale[idx] = ops.scaled_fp8_quant(
                base.w1[idx], use_per_token_if_dynamic=per_out_channel
            )
            w2_q[idx], w2_scale[idx] = ops.scaled_fp8_quant(
                base.w2[idx], use_per_token_if_dynamic=per_out_channel
            )
            w1_d[idx] = (w1_q[idx].float() * w1_scale[idx]).to(half)
            w2_d[idx] = (w2_q[idx].float() * w2_scale[idx]).to(half)

        # Reuse the unquantized tensors/strides of ``base`` as-is.
        base_fields = {
            field.name: getattr(base, field.name)
            for field in dataclasses.fields(MOETensors)
        }
        return MOETensors8Bit(
            **base_fields,
            a_q=a_q,
            w1_q=w1_q,
            w2_q=w2_q,
            a_scale=a_scale,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a_d=a_d,
            w1_d=w1_d,
            w2_d=w2_d,
        )


def run_with_expert_maps(
    num_experts: int,
    num_local_experts: int,
    quant_config: FusedMoEQuantConfig,
    **cutlass_moe_kwargs,
):
    """Emulate expert parallelism by running one contiguous expert shard at a time.

    For each shard, ``w1`` / ``w2`` and their scales are sliced to the shard's
    experts. An ``expert_map`` is built that maps each global expert id in the
    shard to its local index (-1 for experts outside the shard). A CUTLASS
    FP8 modular kernel is then run, and the per-shard outputs are summed.

    Args:
        num_experts: total (global) number of experts.
        num_local_experts: experts per shard. It must be positive. If it does
            not divide ``num_experts``, the last shard is shorter.
        quant_config: quantization config for the full expert set. Each shard
            uses a deep copy whose ``w1`` / ``w2`` scales are sliced; the
            caller's config is left untouched.
        **cutlass_moe_kwargs: keyword arguments for the kernel call. These
            must include ``hidden_states``, ``w1`` and ``w2``. ``expert_map``
            is set per shard.

    Returns:
        The sum of the per-shard kernel outputs, shaped like
        ``hidden_states``.

    Raises:
        ValueError: if ``num_local_experts`` is not positive.
    """
    if num_local_experts <= 0:
        raise ValueError(
            f"num_local_experts must be positive, got {num_local_experts}"
        )

    # Per-expert weights that have to be cut down to each shard.
    full_weights = {
        name: cutlass_moe_kwargs[name]
        for name in ("w1", "w2")
        if name in cutlass_moe_kwargs
    }

    def slice_experts():
        # Yield (kernel kwargs, quant config) for each contiguous expert
        # shard. The same kwargs dict is updated in place on every step.
        for start in range(0, num_experts, num_local_experts):
            # Clamp the final shard. Without this, a short trailing shard
            # would make the slice assignment below grow expert_map past
            # num_experts.
            end = min(start + num_local_experts, num_experts)

            # Global expert id -> local id within this shard, -1 elsewhere.
            expert_map = [-1] * num_experts
            expert_map[start:end] = list(range(end - start))
            cutlass_moe_kwargs["expert_map"] = torch.tensor(
                expert_map, dtype=torch.int32, device="cuda"
            )
            for name, full in full_weights.items():
                cutlass_moe_kwargs[name] = full[start:end]

            # Deep copy so the caller's config keeps the full scale tensors.
            shard_quant_config = copy.deepcopy(quant_config)
            shard_quant_config._w1.scale = quant_config.w1_scale[start:end]
            shard_quant_config._w2.scale = quant_config.w2_scale[start:end]

            yield cutlass_moe_kwargs, shard_quant_config

    out_tensor = torch.zeros_like(cutlass_moe_kwargs["hidden_states"])
    for kwargs, shard_quant_config in slice_experts():
        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            CutlassExpertsFp8(
                out_dtype=kwargs["hidden_states"].dtype,
                # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
                e=kwargs["w2"].shape[0],  # type: ignore[union-attr]
                n=kwargs["w2"].shape[2],  # type: ignore[union-attr]
                k=kwargs["w2"].shape[1],  # type: ignore[union-attr]
                quant_config=shard_quant_config,
                device="cuda",
            ),
        )
        out_tensor = out_tensor + kernel(**kwargs)

    return out_tensor


def run_8_bit(
    moe_tensors: MOETensors8Bit,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    per_act_token: bool,
    per_out_ch: bool,
    num_local_experts: int | None = None,
) -> torch.Tensor:
    """Run the CUTLASS FP8 MoE kernel on the quantized weights in ``moe_tensors``.

    Args:
        moe_tensors: tensors from ``MOETensors8Bit.make_moe_tensors_8bit``.
            ``w1_q``, ``w2_q``, ``w1_scale``, ``w2_scale`` and ``a_scale``
            must be populated.
        topk_weights: routing weights (as returned by ``fused_topk``).
        topk_ids: routed expert ids (as returned by ``fused_topk``).
        per_act_token: forwarded as ``per_act_token_quant``.
        per_out_ch: forwarded as ``per_out_ch_quant``.
        num_local_experts: when ``None``, all experts run in a single kernel
            call. Otherwise experts run in shards of this size through
            ``run_with_expert_maps``. This also covers
            ``num_local_experts == num_experts`` (ep_size 1), which therefore
            still exercises the expert-map path with a single shard.

    Returns:
        The kernel output. ``hidden_states`` is the unquantized
        ``moe_tensors.a`` with ``a1_scale=None``, so activation scaling is
        presumably computed dynamically inside the kernel.
    """
    assert all(
        t is not None
        for t in (
            moe_tensors.w1_q,
            moe_tensors.w2_q,
            moe_tensors.w1_scale,
            moe_tensors.w2_scale,
            moe_tensors.a_scale,
        )
    )

    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=moe_tensors.w1_scale,
        w2_scale=moe_tensors.w2_scale,
        per_act_token_quant=per_act_token,
        per_out_ch_quant=per_out_ch,
        # Set to moe_tensors.a_scale iff static scales + per tensor.
        # This is not currently being tested.
        a1_scale=None,
    )

    kwargs = {
        "hidden_states": moe_tensors.a,
        "w1": moe_tensors.w1_q,  # type: ignore[union-attr]
        "w2": moe_tensors.w2_q,  # type: ignore[union-attr]
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
    }

    if num_local_experts is None:
        kernel = mk.FusedMoEModularKernel(
            MoEPrepareAndFinalizeNoEP(),
            CutlassExpertsFp8(
                out_dtype=moe_tensors.a.dtype,
                # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
                e=moe_tensors.w2_q.shape[0],  # type: ignore[union-attr]
                n=moe_tensors.w2_q.shape[2],  # type: ignore[union-attr]
                k=moe_tensors.w2_q.shape[1],  # type: ignore[union-attr]
                quant_config=quant_config,
                device="cuda",
            ),
        )
        return kernel(**kwargs)

    # Any explicit shard size, including num_local_experts == num_experts
    # (ep_size 1), goes through the expert-map path. A single full shard
    # therefore still exercises expert_map handling.
    return run_with_expert_maps(
        moe_tensors.w1.size(0),
        num_local_experts,
        quant_config,
        **kwargs,
    )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
    workspace_init,
    ep_size: int | None = None,
):
    """Compare the CUTLASS FP8 MoE output with the Triton ``fused_experts`` path.

    The EP tests call this function directly with ``ep_size``. When given,
    ``e`` must be divisible by it, and experts run in shards of
    ``e // ep_size``.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        tensors = MOETensors8Bit.make_moe_tensors_8bit(
            m, k, n, e, per_act_token, per_out_ch
        )
        scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids, _ = fused_topk(
            tensors.a, scores, topk, renormalize=False
        )

        # Reference: the unquantized Triton path fed the dequantized (FP8
        # round-tripped) tensors. Feeding a / w1 / w2 directly would add
        # small extra differences.
        reference = fused_experts(
            tensors.a_d,
            tensors.w1_d,
            tensors.w2_d,
            topk_weights,
            topk_ids,
            quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
        )

        local_experts = None
        if ep_size is not None:
            assert e % ep_size == 0, "Cannot distribute experts evenly"
            local_experts = e // ep_size

        actual = run_8_bit(
            tensors, topk_weights, topk_ids, per_act_token, per_out_ch, local_experts
        )

        # atol 5.5e-2 is only needed by the larger problem sizes; 5e-2 is
        # enough for the rest.
        torch.testing.assert_close(reference, actual, atol=5.5e-2, rtol=1e-2)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_cuda_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    monkeypatch,
    workspace_init,
):
    """Check CUTLASS FP8 MoE against Triton when run inside a CUDA graph.

    The CUTLASS call is captured on a side stream, the graph is replayed, and
    the captured output buffer is compared with the ``fused_experts``
    reference.
    """
    set_random_seed(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        half = torch.half
        tensors = MOETensors8Bit.make_moe_tensors_8bit(
            m, k, n, e, per_act_token, per_out_ch
        )
        scores = torch.randn((m, e), device="cuda", dtype=half)
        topk_weights, topk_ids, _ = fused_topk(
            tensors.a, scores, topk, renormalize=False
        )

        # Reference uses the dequantized tensors; using a / w1 / w2 directly
        # would add small extra differences.
        reference = fused_experts(
            tensors.a_d,
            tensors.w1_d,
            tensors.w2_d,
            topk_weights,
            topk_ids,
            quant_config=FUSED_MOE_UNQUANTIZED_CONFIG,
        )

        capture_stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=capture_stream):
            graph_output = run_8_bit(
                tensors, topk_weights, topk_ids, per_act_token, per_out_ch
            )

        # Replay the captured work, then read the buffer it writes.
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

        torch.testing.assert_close(reference, graph_output, atol=9e-2, rtol=1e-2)


@pytest.mark.parametrize("m", [64])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [16])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
    workspace_init,
):
    """Run the no-graph CUTLASS/Triton comparison with experts split into
    ``ep_size`` shards."""
    test_cutlass_moe_8_bit_no_graph(
        m=m,
        n=n,
        k=k,
        e=e,
        topk=topk,
        per_act_token=per_act_token,
        per_out_ch=per_out_channel,
        monkeypatch=monkeypatch,
        workspace_init=workspace_init,
        ep_size=ep_size,
    )


# (m, n, k, topk) cases for the large EP test: m tokens, intermediate size n,
# hidden size k, and topk experts selected per token.
LARGE_MNK_FACTORS: list[tuple[int, int, int, int]] = [
    (1, 8192, 5120, 31),
    (32768, 1024, 1024, 16),
    (65536, 512, 1024, 16),
]


@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_moe_8_bit_EP_large(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
    workspace_init,
):
    """Large-shape variant of the EP comparison (128 experts, ep_size 8)."""
    test_cutlass_moe_8_bit_no_graph(
        m=m,
        n=n,
        k=k,
        e=e,
        topk=topk,
        per_act_token=per_act_token,
        per_out_ch=per_out_channel,
        monkeypatch=monkeypatch,
        workspace_init=workspace_init,
        ep_size=ep_size,
    )


@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()
    ),
    reason="Grouped gemm is not supported on this GPU type.",
)
def test_run_cutlass_moe_fp8(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    workspace_init,
):
    """Check that ``run_cutlass_moe_fp8`` output does not depend on workspace13.

    The same call is made twice with an expert map covering only the first
    ``e // ep_size`` experts. The first run follows filling ``workspace13``
    with random values and writes into an uninitialized output buffer. The
    second follows zeroing ``workspace13`` and writes into a zeroed output
    buffer. The two outputs must match.
    """
    set_random_seed(7)
    with set_current_vllm_config(vllm_config):
        tensors = MOETensors8Bit.make_moe_tensors_8bit(
            m, k, n, e, per_act_token, per_out_channel
        )
        scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids, _ = fused_topk(
            tensors.a, scores, topk, renormalize=False
        )

        # Make sure at least one routed slot falls outside this expert shard
        # (id -1) and at least one falls inside it (expert 1, which is local
        # while e // ep_size > 1).
        topk_ids[0][0] = -1
        topk_ids[0][1] = 1

        dtype = tensors.a.dtype
        output_shape = (m, k)
        workspace13 = torch.empty(
            prod((m * topk, max(2 * n, k))), device="cuda", dtype=dtype
        )
        workspace2 = torch.empty(
            prod((m * topk, max(n, k))), device="cuda", dtype=dtype
        )

        # The first e // ep_size experts are local (ids 0..), the rest are -1.
        num_local_experts = e // ep_size
        expert_map = torch.full((e,), -1, dtype=torch.int32, device="cuda")
        expert_map[:num_local_experts] = torch.arange(
            num_local_experts, dtype=torch.int32, device="cuda"
        )

        def per_expert(value: int) -> torch.Tensor:
            # One int64 entry per (global) expert, all equal to ``value``.
            return torch.full((e,), value, device="cuda", dtype=torch.int64)

        ab_strides1, ab_strides2 = per_expert(k), per_expert(n)
        c_strides1, c_strides2 = per_expert(2 * n), per_expert(k)

        def activation(out, inp):
            return torch.ops._C.silu_and_mul(out, inp)

        a1q, a1q_scale = moe_kernel_quantize_input(
            tensors.a, tensors.a_scale, torch.float8_e4m3fn, per_act_token
        )
        global_num_experts = -1 if tensors.w1_q is None else tensors.w1_q.size(0)

        def run_into(output: torch.Tensor):
            # Positional argument order must match run_cutlass_moe_fp8.
            return run_cutlass_moe_fp8(
                output,
                a1q,
                tensors.w1_q,
                tensors.w2_q,
                topk_ids,
                activation,
                global_num_experts,
                expert_map,
                tensors.w1_scale,
                tensors.w2_scale,
                a1q_scale,
                None,
                ab_strides1,
                ab_strides2,
                c_strides1,
                c_strides2,
                workspace13,
                workspace2,
                None,
                dtype,
                per_act_token,
                per_out_channel,
                False,
                topk_weights,
            )

        workspace13.random_()
        out_after_random = torch.empty(output_shape, device="cuda", dtype=dtype)
        run_into(out_after_random)

        workspace13.fill_(0)
        out_after_zero = torch.zeros(output_shape, device="cuda", dtype=dtype)
        run_into(out_after_zero)

        torch.testing.assert_close(
            out_after_random, out_after_zero, atol=5e-3, rtol=1e-3
        )