# SPDX-License-Identifier: Apache-2.0
import dataclasses
from typing import Optional

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                            fused_topk)
from vllm.platforms import current_platform

# Expert counts and top-k values swept by the no-graph and CUDA-graph tests.
NUM_EXPERTS = [40, 64]
TOP_KS = [6, 8]

# (m, n, k) triples; the tests unpack them via parametrize("m,n,k", ...).
# Every combination of m in {2, 64, 224}, n in {1024, 3072} and
# k in {1024, 1536} is listed explicitly.
MNK_FACTORS = [
    (2, 1024, 1024),
    (2, 1024, 1536),
    (2, 3072, 1024),
    (2, 3072, 1536),
    (64, 1024, 1024),
    (64, 1024, 1536),
    (64, 3072, 1024),
    (64, 3072, 1536),
    (224, 1024, 1024),
    (224, 1024, 1536),
    (224, 3072, 1024),
    (224, 3072, 1536),
]


@dataclasses.dataclass
class MOETensors:
    """Unquantized MoE inputs together with per-expert stride tensors."""
    # Shapes below are the ones make_moe_tensors produces.
    a: torch.Tensor  # (m, k)
    w1: torch.Tensor  # (e, 2 * n, k)
    w2: torch.Tensor  # (e, k, n)
    ab_strides1: torch.Tensor  # (e,) int64, every entry k
    c_strides1: torch.Tensor  # (e,) int64, every entry 2 * n
    ab_strides2: torch.Tensor  # (e,) int64, every entry n
    c_strides2: torch.Tensor  # (e,) int64, every entry k

    @staticmethod
    def make_moe_tensors(m: int, k: int, n: int, e: int,
                         dtype: torch.dtype) -> "MOETensors":
        """Build a random MOETensors instance on the CUDA device.

        Args:
            m: Number of rows of the activation tensor ``a``.
            k: Last dimension of ``a`` and ``w1``; middle dimension of ``w2``.
            n: Last dimension of ``w2``; ``w1`` has ``2 * n`` rows per expert.
            e: Number of experts (leading dimension of ``w1`` and ``w2``).
            dtype: Element type of ``a``, ``w1`` and ``w2``.

        Returns:
            A MOETensors whose ``a``/``w1``/``w2`` are ``randn`` samples
            divided by 10 and whose stride tensors are constant-filled.
        """

        def scaled_randn(*shape: int) -> torch.Tensor:
            return torch.randn(shape, device="cuda", dtype=dtype) / 10

        def per_expert(value: int) -> torch.Tensor:
            return torch.full((e, ), value, device="cuda", dtype=torch.int64)

        # Keyword arguments are evaluated left to right, so the random draws
        # happen in the order a, w1, w2.
        return MOETensors(
            a=scaled_randn(m, k),
            w1=scaled_randn(e, 2 * n, k),
            w2=scaled_randn(e, k, n),
            ab_strides1=per_expert(k),
            c_strides1=per_expert(2 * n),
            ab_strides2=per_expert(n),
            c_strides2=per_expert(k),
        )


@dataclasses.dataclass
class MOETensors8Bit(MOETensors):
    """MOETensors plus FP8-quantized and dequantized copies of a, w1, w2.

    Every added field defaults to None; make_moe_tensors_8bit populates all
    of them.
    """
    # quantized
    a_q: Optional[torch.Tensor] = None  # a -> a_q
    w1_q: Optional[torch.Tensor] = None  # w1 -> w1_q
    w2_q: Optional[torch.Tensor] = None  # w2 -> w2_q
    a_scale: Optional[torch.Tensor] = None
    w1_scale: Optional[torch.Tensor] = None
    w2_scale: Optional[torch.Tensor] = None
    # dequantized
    a_d: Optional[torch.Tensor] = None  # a -> a_q -> a_d
    w1_d: Optional[torch.Tensor] = None  # w1 -> w1_q -> w1_d
    w2_d: Optional[torch.Tensor] = None  # w2 -> w2_q -> w2_d

    @staticmethod
    def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
                              per_act_token: bool,
                              per_out_channel: bool) -> "MOETensors8Bit":
        """Create random fp16 MoE tensors and their FP8 round-trips.

        Args:
            m: Forwarded to MOETensors.make_moe_tensors.
            k: Forwarded to MOETensors.make_moe_tensors.
            n: Forwarded to MOETensors.make_moe_tensors.
            e: Number of experts; each expert's weights are quantized
                separately.
            per_act_token: Passed as ``use_per_token_if_dynamic`` when
                quantizing the activations.
            per_out_channel: Passed as ``use_per_token_if_dynamic`` when
                quantizing the weights. Also selects the scale buffer
                shapes: ``(e, 2 * n, 1)`` / ``(e, k, 1)`` when True,
                ``(e, 1, 1)`` otherwise.

        Returns:
            A MOETensors8Bit holding the fp16 tensors and strides, the
            float8_e4m3fn tensors with their float32 scales, and the
            dequantized fp16 tensors (quantized value times scale).
        """
        dtype = torch.half
        q_dtype = torch.float8_e4m3fn

        moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype)

        # a -> a_q, w1 -> w1_q, w2 -> w2_q
        n_b_scales = 2 * n if per_out_channel else 1
        k_b_scales = k if per_out_channel else 1
        # Get the right scale for tests.
        # The first call is used only for its scale; the second call
        # quantizes again with that scale passed in explicitly.
        _, a_scale = ops.scaled_fp8_quant(
            moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
        a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
                                      a_scale,
                                      use_per_token_if_dynamic=per_act_token)
        w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
        w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)

        w1_scale = torch.empty((e, n_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)
        w2_scale = torch.empty((e, k_b_scales, 1),
                               device="cuda",
                               dtype=torch.float32)
        # Each expert's weights are quantized on their own.
        for expert in range(e):
            w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
                moe_tensors_fp16.w1[expert],
                use_per_token_if_dynamic=per_out_channel)
            w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
                moe_tensors_fp16.w2[expert],
                use_per_token_if_dynamic=per_out_channel)

        # a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
        a_d = a_q.float().mul(a_scale).to(dtype)
        w1_d = torch.empty_like(moe_tensors_fp16.w1)
        w2_d = torch.empty_like(moe_tensors_fp16.w2)
        for expert in range(e):
            # Cast through `dtype` (torch.half) like the activation path
            # above, rather than hard-coding .half().
            w1_d[expert] = (w1_q[expert].float() *
                            w1_scale[expert]).to(dtype)
            w2_d[expert] = (w2_q[expert].float() *
                            w2_scale[expert]).to(dtype)

        return MOETensors8Bit(a=moe_tensors_fp16.a,
                              w1=moe_tensors_fp16.w1,
                              w2=moe_tensors_fp16.w2,
                              ab_strides1=moe_tensors_fp16.ab_strides1,
                              c_strides1=moe_tensors_fp16.c_strides1,
                              ab_strides2=moe_tensors_fp16.ab_strides2,
                              c_strides2=moe_tensors_fp16.c_strides2,
                              a_q=a_q,
                              w1_q=w1_q,
                              w2_q=w2_q,
                              a_scale=a_scale,
                              w1_scale=w1_scale,
                              w2_scale=w2_scale,
                              a_d=a_d,
                              w1_d=w1_d,
                              w2_d=w2_d)


def run_with_expert_maps(num_experts: int, num_local_experts: int,
                         **cutlass_moe_kwargs):
    """Run cutlass_moe_fp8 once per group of experts and sum the outputs.

    The experts are walked in consecutive groups of ``num_local_experts``.
    For each group the per-expert tensors are sliced down to that group and
    an ``expert_map`` is supplied that maps the group's global expert
    indices to ``0..num_local_experts-1`` and every other expert to ``-1``.

    Args:
        num_experts: Total number of experts.
        num_local_experts: Number of experts handled per call.
        **cutlass_moe_kwargs: Keyword arguments for cutlass_moe_fp8; must
            contain ``"a"``, which fixes the shape/dtype of the accumulator.

    Returns:
        The elementwise sum of the cutlass_moe_fp8 outputs over all groups.
    """
    per_expert_keys = (
        "w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
        "c_strides2", "w1_scale", "w2_scale"
    )
    # Snapshot the unsliced tensors up front: the kwargs dict is overwritten
    # in place with each group's slices below.
    unsliced = {
        name: tensor
        for name, tensor in cutlass_moe_kwargs.items()
        if name in per_expert_keys
    }

    total = torch.zeros_like(cutlass_moe_kwargs["a"])
    for start in range(0, num_experts, num_local_experts):
        stop = start + num_local_experts

        # Global expert index -> local index; -1 for experts outside the
        # current group.
        local_index = [-1] * num_experts
        local_index[start:stop] = list(range(num_local_experts))
        cutlass_moe_kwargs["expert_map"] = torch.tensor(local_index,
                                                        dtype=torch.int32,
                                                        device="cuda")

        # Restrict every per-expert tensor to the current group.
        for name, tensor in unsliced.items():
            cutlass_moe_kwargs[name] = tensor[start:stop]

        total = total + cutlass_moe_fp8(**cutlass_moe_kwargs)

    return total


def run_8_bit(moe_tensors: MOETensors8Bit,
              topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              num_local_experts: Optional[int] = None) -> torch.Tensor:
    """Run cutlass_moe_fp8 on the quantized tensors in ``moe_tensors``.

    Args:
        moe_tensors: Tensors to feed to the kernel; the quantized weights
            and all three scales must be populated.
        topk_weights: Passed through as ``topk_weights``.
        topk_ids: Passed through as ``topk_ids_``.
        num_local_experts: When None, cutlass_moe_fp8 is called once with
            all experts. Otherwise (including when it equals the total
            number of experts) the call goes through run_with_expert_maps
            with groups of this many experts.

    Returns:
        The output of cutlass_moe_fp8, or the summed per-group outputs.
    """
    assert all(t is not None for t in (
        moe_tensors.w1_q,
        moe_tensors.w2_q,
        moe_tensors.w1_scale,
        moe_tensors.w2_scale,
        moe_tensors.a_scale,
    ))

    kwargs = {
        'a': moe_tensors.a,
        # The quantized weights are passed with their last two dims swapped.
        'w1_q': moe_tensors.w1_q.transpose(1, 2),  # type: ignore[union-attr]
        'w2_q': moe_tensors.w2_q.transpose(1, 2),  # type: ignore[union-attr]
        'topk_weights': topk_weights,
        'topk_ids_': topk_ids,
        'ab_strides1': moe_tensors.ab_strides1,
        'c_strides1': moe_tensors.c_strides1,
        'ab_strides2': moe_tensors.ab_strides2,
        'c_strides2': moe_tensors.c_strides2,
        'w1_scale': moe_tensors.w1_scale,
        'w2_scale': moe_tensors.w2_scale,
        'a1_scale': moe_tensors.a_scale
    }

    # The expert-map path is taken whenever num_local_experts is given.
    # The previous condition also or-ed in
    # `num_local_experts == num_experts`, but that operand was only ever
    # evaluated with num_local_experts being None, so it had no effect.
    if num_local_experts is None:
        return cutlass_moe_fp8(**kwargs)

    num_experts = moe_tensors.w1.size(0)
    return run_with_expert_maps(num_experts, num_local_experts, **kwargs)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
):
    """Compare eager cutlass_moe_fp8 against fused_experts.

    The reference is fused_experts evaluated on the dequantized tensors;
    the outputs must agree within atol=5e-2, rtol=1e-2.
    """
    current_platform.seed_everything(7)
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):

        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_ch)

        score = torch.randn((m, e), device="cuda", dtype=torch.half)
        topk_weights, topk_ids = fused_topk(mt.a,
                                            score,
                                            topk,
                                            renormalize=False)

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
        triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
                                      topk_ids)

        cutlass_output = run_8_bit(mt, topk_weights, topk_ids)

        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=5e-2,
                                   rtol=1e-2)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_cuda_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
):
    """Compare cutlass_moe_fp8 run through a CUDA graph against fused_experts.

    The CUTLASS call is captured into a CUDA graph, the graph is replayed,
    and the tensor produced by the captured call is compared with
    fused_experts evaluated on the dequantized tensors (atol=9e-2,
    rtol=1e-2).
    """
    current_platform.seed_everything(7)
    with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(
                pipeline_parallel_size=1))):

        dtype = torch.half

        mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
                                                  per_out_ch)

        score = torch.randn((m, e), device="cuda", dtype=dtype)
        topk_weights, topk_ids = fused_topk(mt.a,
                                            score,
                                            topk,
                                            renormalize=False)

        # Note that we are using the dequantized versions of the tensors.
        # Using a, w1 and w2 directly results in minor output differences.
        triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
                                      topk_ids)

        # Capture the CUTLASS path on a dedicated stream.
        stream = torch.cuda.Stream()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            cutlass_output = run_8_bit(mt, topk_weights, topk_ids)

        # Replay the captured graph and wait for it before comparing.
        torch.cuda.synchronize()
        graph.replay()
        torch.cuda.synchronize()

        torch.testing.assert_close(triton_output,
                                   cutlass_output,
                                   atol=9e-2,
                                   rtol=1e-2)


@pytest.mark.parametrize("m", [64])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [4096])
@pytest.mark.parametrize("e", [16])
@pytest.mark.parametrize("topk", [1, 8])
@pytest.mark.parametrize("per_act_token", [True])
@pytest.mark.parametrize("per_out_channel", [True])
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
@pytest.mark.skipif(
    (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
        current_platform.get_device_capability()),
    reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_moe_8_bit_EP(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    per_act_token: bool,
    per_out_channel: bool,
    ep_size: int,
):
    """Compare the expert-map CUTLASS path against fused_experts.

    The ``e`` experts are processed in groups of ``e // ep_size`` via
    run_8_bit(..., num_local_experts=...). The result must match
    fused_experts on the dequantized tensors within atol=5e-2, rtol=1e-2.
    """
    current_platform.seed_everything(7)
    config = VllmConfig(parallel_config=ParallelConfig(
        pipeline_parallel_size=1))
    with set_current_vllm_config(config):

        tensors = MOETensors8Bit.make_moe_tensors_8bit(
            m, k, n, e, per_act_token, per_out_channel)

        scores = torch.randn((m, e), device="cuda", dtype=torch.half)
        weights, expert_ids = fused_topk(tensors.a,
                                         scores,
                                         topk,
                                         renormalize=False)

        # The reference uses the dequantized tensors: feeding a, w1 and w2
        # directly results in minor output differences.
        expected = fused_experts(tensors.a_d, tensors.w1_d, tensors.w2_d,
                                 weights, expert_ids)

        assert e % ep_size == 0, "Cannot distribute experts evenly"
        local_expert_count = e // ep_size
        actual = run_8_bit(tensors,
                           weights,
                           expert_ids,
                           num_local_experts=local_expert_count)

        torch.testing.assert_close(expected, actual, atol=5e-2, rtol=1e-2)