"tests/tokenizers_/test_detokenize.py" did not exist on "49782fcb769eb4f04a3cf5179c1e6c13ab633ce1"
test_cutlass_moe.py 19 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
import dataclasses
5
from math import prod
6
7
from typing import Optional

8
9
10
11
12
import pytest
import torch

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
13
14
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
15
16
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    cutlass_moe_fp8, run_cutlass_moe_fp8)
17
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
18
                                                            fused_topk)
19
20
from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)
21
22
23
24
25
from vllm.platforms import current_platform

NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]

26
27
28
29
30
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
31
    (7, 3072, 1536),
32
33
34
35
36
37
38
39
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 3072, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
    (224, 3072, 1024),
    (224, 3072, 1536),
40
41
42
43
    (32768, 1024, 1024),
    # These sizes trigger wrong answers.
    #(7232, 2048, 5120),
    #(40000, 2048, 5120),
44
45
]

46
47
48
49
50
vllm_config = VllmConfig(parallel_config=ParallelConfig(
    pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100

@dataclasses.dataclass
class MOETensors:
    a: torch.Tensor
    w1: torch.Tensor
    w2: torch.Tensor
    ab_strides1: torch.Tensor
    c_strides1: torch.Tensor
    ab_strides2: torch.Tensor
    c_strides2: torch.Tensor

    @staticmethod
    def make_moe_tensors(m: int, k: int, n: int, e: int,
                         dtype: torch.dtype) -> "MOETensors":
        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
        ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
        return MOETensors(a=a,
                          w1=w1,
                          w2=w2,
                          ab_strides1=ab_strides1,
                          c_strides1=c_strides1,
                          ab_strides2=ab_strides2,
                          c_strides2=c_strides2)


@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
    # quantized
    a_q: Optional[torch.Tensor] = None  # a -> a_q
    w1_q: Optional[torch.Tensor] = None  # w1 -> w1_q
    w2_q: Optional[torch.Tensor] = None  # w2 -> w2_q
    a_scale: Optional[torch.Tensor] = None
    w1_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None
    # dequantized
    a_d: Optional[torch.Tensor] = None  # a -> a_q -> a_d
    w1_d: Optional[torch.Tensor] = None  # w1 -> w1_q -> w1_d
    w2_d: Optional[torch.Tensor] = None  # w2 -> w2_q -> w2_d

    @staticmethod
    def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
                              per_act_token: bool,
                              per_out_channel: bool) -> "MOETensors8Bit":
        dtype = torch.half
        q_dtype = torch.float8_e4m3fn
101

102
103
104
105
106
107
        moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype)

        # a -> a_q, w1 -> w1_q, w2 -> w2_q
        n_b_scales = 2 * n if per_out_channel else 1
        k_b_scales = k if per_out_channel else 1
        # Get the right scale for tests.
bnellnm's avatar
bnellnm committed
108
109
110
        a_q, a_scale = ops.scaled_fp8_quant(
            moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)

111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)

        w1_scale = torch.empty((e, n_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)
        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                moe_tensors_fp16.w1[expert],
                use_per_token_if_dynamic=per_out_channel)
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                moe_tensors_fp16.w2[expert],
                use_per_token_if_dynamic=per_out_channel)

        # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
        a_d = a_q.float().mul(a_scale).to(dtype)
        w1_d = torch.empty_like(moe_tensors_fp16.w1)
        w2_d = torch.empty_like(moe_tensors_fp16.w2)
        for expert in range(e):
            w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
            w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()

        return MOETensors8Bit(a=moe_tensors_fp16.a,
                              w1=moe_tensors_fp16.w1,
                              w2=moe_tensors_fp16.w2,
                              ab_strides1=moe_tensors_fp16.ab_strides1,
                              c_strides1=moe_tensors_fp16.c_strides1,
                              ab_strides2=moe_tensors_fp16.ab_strides2,
                              c_strides2=moe_tensors_fp16.c_strides2,
                              a_q=a_q,
                              w1_q=w1_q,
                              w2_q=w2_q,
                              a_scale=a_scale,
                              w1_scale=w1_scale,
                              w2_scale=w2_scale,
                              a_d=a_d,
                              w1_d=w1_d,
                              w2_d=w2_d)


def run_with_expert_maps(num_experts: int, num_local_experts: int,
                         **cutlass_moe_kwargs):

    def slice_experts():
        slice_params = [
            "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
160
            "c_strides2"
161
162
163
164
165
166
167
        ]
        full_tensors = {
            k: v
            for k, v in cutlass_moe_kwargs.items()
            if k in slice_params and k in cutlass_moe_kwargs
        }

168
169
        quant_config = cutlass_moe_kwargs["quant_config"]

170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
        for i in range(0, num_experts, num_local_experts):
            s, e = i, i + num_local_experts

            # make expert map
            expert_map = [-1] * num_experts
            expert_map[s:e] = list(range(num_local_experts))
            expert_map = torch.tensor(expert_map,
                                      dtype=torch.int32,
                                      device="cuda")

            # update cutlass moe arg with expert_map
            cutlass_moe_kwargs["expert_map"] = expert_map
            # update cutlass moe arg tensors
            for k, t in full_tensors.items():
                cutlass_moe_kwargs[k] = t[s:e]

186
187
188
189
190
191
            new_quant_config = copy.deepcopy(quant_config)
            new_quant_config._w1.scale = quant_config.w1_scale[s:e]
            new_quant_config._w2.scale = quant_config.w2_scale[s:e]

            cutlass_moe_kwargs["quant_config"] = new_quant_config

192
193
194
195
196
197
198
199
200
201
202
203
            yield cutlass_moe_kwargs

    out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
    for kwargs in slice_experts():
        out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)

    return out_tensor


def run_8_bit(moe_tensors: MOETensors8Bit,
              topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
bnellnm's avatar
bnellnm committed
204
              per_act_token: bool,
205
              per_out_ch: bool,
206
207
208
209
210
211
212
213
              num_local_experts: Optional[int] = None) -> torch.Tensor:
    assert not any([
        t is None for t in [
            moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale,
            moe_tensors.w2_scale, moe_tensors.a_scale
        ]
    ])

214
215
216
217
218
219
220
221
222
223
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=moe_tensors.w1_scale,
        w2_scale=moe_tensors.w2_scale,
        per_act_token_quant=per_act_token,
        per_out_ch_quant=per_out_ch,
        # Set to moe_tensors.a_scale iff static scales + per tensor.
        # This is not currently being tested.
        a1_scale=None,
    )

224
225
    kwargs = {
        'a': moe_tensors.a,
226
227
        'w1_q': moe_tensors.w1_q,  # type: ignore[union-attr]
        'w2_q': moe_tensors.w2_q,  # type: ignore[union-attr]
228
        'topk_weights': topk_weights,
229
        'topk_ids': topk_ids,
230
231
232
233
        'ab_strides1': moe_tensors.ab_strides1,
        'ab_strides2': moe_tensors.ab_strides2,
        'c_strides1': moe_tensors.c_strides1,
        'c_strides2': moe_tensors.c_strides2,
234
        'quant_config': quant_config,
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
    }

    num_experts = moe_tensors.w1.size(0)
    with_ep = num_local_experts is not None or num_local_experts == num_experts
    if not with_ep:
        return cutlass_moe_fp8(**kwargs)

    assert num_local_experts is not None
    return run_with_expert_maps(
        num_experts,
        num_local_experts,  # type: ignore[arg-type]
        **kwargs)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
250
251
252
253
254
255
256
257
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
258
def test_cutlass_moe_8_bit_no_graph(
259
260
261
262
263
264
265
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
266
    monkeypatch,
267
    ep_size: Optional[int] = None,
268
269
):
    current_platform.seed_everything(7)
270
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
271
    with set_current_vllm_config(vllm_config):
272
273
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_ch)
274

275
        score = torch.randn((m, e), device="cuda", dtype=torch.half)
276
277
278
279
        topk_weights, topk_ids, _ = fused_topk(mt.a,
                                               score,
                                               topk,
                                               renormalize=False)
280

281
282
        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
283
284
285
286
287
288
289
290

        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
        triton_output = fused_experts(mt.a_d,
                                      mt.w1_d,
                                      mt.w2_d,
                                      topk_weights,
                                      topk_ids,
                                      quant_config=quant_config)
291

292
293
294
295
296
        if ep_size is not None:
            assert e % ep_size == 0, "Cannot distribute experts evenly"
            number_local_experts = e // ep_size
        else:
            number_local_experts = None
297

298
        cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
299
                                   per_out_ch, number_local_experts)
300

bnellnm's avatar
bnellnm committed
301
302
        # Note 5.5 only needed for larger problem sizes, 5 works ok for
        # the rest.
303
304
        torch.testing.assert_close(triton_output,
                                   cutlass_output,
bnellnm's avatar
bnellnm committed
305
                                   atol=5.5e-2,
306
307
308
                                   rtol=1e-2)


309
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
310
311
312
313
314
315
316
317
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
318
def test_cutlass_moe_8_bit_cuda_graph(
319
320
321
322
323
324
325
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
326
    monkeypatch,
327
328
):
    current_platform.seed_everything(7)
329
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
330
    with set_current_vllm_config(vllm_config):
331
332
        dtype = torch.half

333
334
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_ch)
335
336

        score = torch.randn((m, e), device="cuda", dtype=dtype)
337
338
339
340
        topk_weights, topk_ids, _ = fused_topk(mt.a,
                                               score,
                                               topk,
                                               renormalize=False)
341

342
343
        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
344
345
346
347
348
349
350
        quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
        triton_output = fused_experts(mt.a_d,
                                      mt.w1_d,
                                      mt.w2_d,
                                      topk_weights,
                                      topk_ids,
                                      quant_config=quant_config)
351
352
353
354

        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
bnellnm's avatar
bnellnm committed
355
            cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
356
                                       per_act_token, per_out_ch)
357

358
359
360
361
362
363
364
365
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=9e-2,
                                   rtol=1e-2)
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388


@pytest.mark.parametrize("m", [64])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [16])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_EP(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
389
    monkeypatch,
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
):
    test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
                                    per_out_channel, monkeypatch, ep_size)


LARGE_MNK_FACTORS = [
    (1, 8192, 5120, 31),
    (32768, 1024, 1024, 16),
    (65536, 512, 1024, 16),
]


@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS)
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_EP_large(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
    monkeypatch,
):
    test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token,
                                    per_out_channel, monkeypatch, ep_size)


@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)])
@pytest.mark.parametrize("e", [128])
@pytest.mark.parametrize("per_act_token", [False])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [8])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_run_cutlass_moe_fp8(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
444
445
):
    current_platform.seed_everything(7)
446
    with set_current_vllm_config(vllm_config):
447
448
449
450
        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_channel)

        score = torch.randn((m, e), device="cuda", dtype=torch.half)
451
452
453
454
        topk_weights, topk_ids, _ = fused_topk(mt.a,
                                               score,
                                               topk,
                                               renormalize=False)
455
456
457
458
459
460
461
        # we want to make sure there is at least one token that's generated in
        # this expert shard and at least one token that's NOT generated in this
        # expert shard
        topk_ids[0][0] = -1
        topk_ids[0][1] = 1

        workspace13_shape = (m * topk, max(2 * n, k))
462
463
        workspace2_shape = (m * topk, max(n, k))
        output_shape = (m, k)
464
465
466
467
468
469
470
471
472
473
474
475
476
477

        workspace13 = torch.empty(prod(workspace13_shape),
                                  device="cuda",
                                  dtype=mt.a.dtype)
        workspace2 = torch.empty(prod(workspace2_shape),
                                 device="cuda",
                                 dtype=mt.a.dtype)

        num_local_experts = e // ep_size
        start, end = 0, num_local_experts
        expert_map = [-1] * e
        expert_map[start:end] = list(range(num_local_experts))
        expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")

478
479
480
481
482
        ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
        ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
        c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
        c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)

483
484
485
486
487
488
489
490
        activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
        a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
                                                   torch.float8_e4m3fn,
                                                   per_act_token)
        global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0)
        func = lambda output: run_cutlass_moe_fp8(
            output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
            global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
491
492
493
            a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
            workspace13, workspace2, None, mt.a.dtype, per_act_token,
            per_out_channel, False, topk_weights)
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510

        workspace13.random_()
        output_random_workspace = torch.empty(output_shape,
                                              device="cuda",
                                              dtype=mt.a.dtype)
        func(output_random_workspace)

        workspace13.fill_(0)
        output_zero_workspace = torch.zeros(output_shape,
                                            device="cuda",
                                            dtype=mt.a.dtype)
        func(output_zero_workspace)

        torch.testing.assert_close(output_random_workspace,
                                   output_zero_workspace,
                                   atol=5e-3,
                                   rtol=1e-3)