"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""
import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
                                 torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

# Expert counts and top-k values used to parametrize the MoE tests.
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]

@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    """Check that `fused_moe` matches the `torch_moe` reference.

    Args:
        m: Number of rows in the activation tensor `a` and in `score`.
        n: Size parameter of the expert weights; `w1` is `(e, 2 * n, k)`
            and `w2` is `(e, k, n)`.
        k: Number of columns in `a` (and the last dimension of `w1`).
        e: Number of experts (leading dimension of `w1`/`w2`, columns of
            `score`).
        topk: Top-k value forwarded to both implementations.
        dtype: Data type of all randomly generated tensors.
    """
    # Random inputs are scaled down by 10 before being fed to both paths.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
    torch_output = torch_moe(a, w1, w2, score, topk)

    # Absolute tolerance only: rtol=0 disables the relative component.
    torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
    """Make sure our Mixtral MoE implementation agrees with the one from
    huggingface.

    Args:
        dtype: Data type used for both MoE blocks and for the input batch;
            it also selects the comparison tolerance.
    """

    # Instantiate our and huggingface's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    vllm_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    ).cuda()

    # Load the weights
    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
        # The vLLM block stores w1 and w3 concatenated along dim 0 in
        # `w13_weight`, so build that layout from the HF expert weights.
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
        vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
    # vLLM uses 1D query [num_tokens, hidden_dim]
    vllm_inputs = hf_inputs.flatten(0, 1)

    # Run forward passes for both MoE blocks
    hf_states, _ = hf_moe.forward(hf_inputs)
    vllm_states = vllm_moe.forward(vllm_inputs)

    # Per-dtype tolerance, used for both rtol and atol below.
    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    # Flatten the HF output the same way as the input so shapes line up.
    torch.testing.assert_close(hf_states.flatten(0, 1),
                               vllm_states,
                               rtol=mixtral_moe_tol[dtype],
                               atol=mixtral_moe_tol[dtype])


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("group_size", [-1, 32, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
    num_bits: int,
    is_k_full: bool,
):
    """Compare `fused_marlin_moe` against `fused_moe` on quantized weights.

    Both expert weight tensors are quantized per expert with
    `marlin_quantize`. The reference weights it returns are fed to
    `fused_moe`, and the quantized form is fed to
    `torch.ops.vllm.fused_marlin_moe`; the two outputs must agree within
    `compute_max_diff(...) < 4e-2`. When `ops.supports_moe_ops` is true, the
    `_moe_C.topk_softmax` and `_moe_C.marlin_gemm_moe` ops are additionally
    run through `opcheck`.

    Args:
        m: Number of rows in the activation tensor and in `score`.
        n: Size parameter of the expert weights; `w1` is `(e, 2 * n, k)`
            and `w2` is `(e, k, n)`.
        k: Number of columns in the activation tensor.
        e: Number of experts.
        topk: Top-k value used for routing.
        group_size: Group size forwarded to `marlin_quantize`.
        act_order: Activation-order flag forwarded to `marlin_quantize`.
        num_bits: 4 selects `scalar_types.uint4b8`; any other value selects
            `scalar_types.uint8b128`.
        is_k_full: Flag forwarded to `fused_marlin_moe`.
    """
    current_platform.seed_everything(7)

    # Filter act_order: parameter combinations that are not exercised return
    # early (and are therefore reported as passed, not skipped).
    if act_order:
        if group_size == -1:
            return
        if group_size in (k, n):
            return
    else:
        if not is_k_full:
            return

    quant_type = (scalar_types.uint4b8
                  if num_bits == 4 else scalar_types.uint8b128)
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    # Quantize the first expert weight, one expert at a time.
    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    g_idx1_l = []
    sort_indices1_l = []

    for i in range(w1.shape[0]):
        test_perm = torch.randperm(k)
        # The sixth return value of marlin_quantize is not needed here.
        w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size, act_order,
            test_perm)
        w_ref1_l.append(w_ref1)
        qweight1_l.append(qweight1)
        scales1_l.append(scales1)
        g_idx1_l.append(g_idx1)
        sort_indices1_l.append(sort_indices1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    g_idx1 = stack_and_dev(g_idx1_l)
    sort_indices1 = stack_and_dev(sort_indices1_l)

    # Same procedure for the second expert weight; the permutation has
    # length n here instead of k.
    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    g_idx2_l = []
    sort_indices2_l = []

    for i in range(w2.shape[0]):
        test_perm = torch.randperm(n)
        w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size, act_order,
            test_perm)
        w_ref2_l.append(w_ref2)
        qweight2_l.append(qweight2)
        scales2_l.append(scales2)
        g_idx2_l.append(g_idx2)
        sort_indices2_l.append(sort_indices2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    g_idx2 = stack_and_dev(g_idx2_l)
    sort_indices2 = stack_and_dev(sort_indices2_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    # Reference path: unquantized kernel on the reference weights, with the
    # last two dimensions swapped back before the call.
    triton_output = fused_moe(
        a,
        w_ref1.transpose(1, 2).contiguous(),
        w_ref2.transpose(1, 2).contiguous(),
        score,
        topk,
        renormalize=False,
    )
    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )

    assert compute_max_diff(marlin_output, triton_output) < 4e-2

    if ops.supports_moe_ops:
        token_expert_indices = torch.empty(m,
                                           topk,
                                           dtype=torch.int32,
                                           device=a.device)

        opcheck(torch.ops._moe_C.topk_softmax, (
            topk_weights,
            topk_ids,
            token_expert_indices,
            score.float(),
        ))

        block_size_m = 4

        sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m,
                                                      e)

        max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                device="cuda",
                                requires_grad=False)

        # Empty (0, 0) tensor passed in the zero-point argument slot.
        zp = torch.empty((0, 0),
                         dtype=dtype,
                         device="cuda",
                         requires_grad=False)
        opcheck(torch.ops._moe_C.marlin_gemm_moe,
                (a, qweight1, sorted_token_ids, topk_weights, topk_ids,
                 scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
                 m, 2 * n, k, True, e, topk, block_size_m, True, False))

@pytest.mark.skip(reason="This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_single_marlin_moe_multiply(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
    act_order: bool,
    num_bits: int,
    is_k_full: bool,
):
    """Compare `single_marlin_moe` against the `torch_moe_single` reference.

    A single expert weight tensor of shape `(e, n, k)` is quantized per
    expert with `marlin_quantize`. The quantized form is fed to
    `torch.ops.vllm.single_marlin_moe` and the reference weights to
    `torch_moe_single`; the outputs must agree within
    `compute_max_diff(...) < 1e-2`. The test is unconditionally skipped and
    kept for debugging only.

    Args:
        m: Number of rows in the activation tensor and in `score`.
        n: Middle dimension of the expert weight tensor.
        k: Number of columns in the activation tensor.
        e: Number of experts.
        topk: Top-k value forwarded to both implementations.
        group_size: Group size forwarded to `marlin_quantize`.
        act_order: Activation-order flag forwarded to `marlin_quantize`.
        num_bits: 4 selects `scalar_types.uint4b8`; any other value selects
            `scalar_types.uint8b128`.
        is_k_full: Flag forwarded to `single_marlin_moe`.
    """

    # Filter act_order: parameter combinations that are not exercised return
    # early.
    if act_order:
        if group_size == -1:
            return
        if group_size == k:
            return
    else:
        if not is_k_full:
            return

    quant_type = (scalar_types.uint4b8
                  if num_bits == 4 else scalar_types.uint8b128)
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    # Quantize one expert at a time and collect the pieces.
    w_ref_l = []
    qweights_l = []
    scales_l = []
    g_idx_l = []
    sort_indices_l = []

    for i in range(w.shape[0]):
        test_perm = torch.randperm(k)
        # The sixth return value of marlin_quantize is not needed here.
        w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
            w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
        w_ref_l.append(w_ref)
        qweights_l.append(qweight)
        scales_l.append(scales)
        g_idx_l.append(g_idx)
        sort_indices_l.append(sort_indices)

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweights_l).contiguous()
    scales = stack_and_dev(scales_l)
    g_idx = stack_and_dev(g_idx_l)
    sort_indices = stack_and_dev(sort_indices_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    marlin_output = torch.ops.vllm.single_marlin_moe(
        a,
        qweight,
        scales,
        score,
        topk,
        renormalize=False,
        g_idx=g_idx,
        sort_indices=sort_indices,
        num_bits=num_bits,
        is_k_full=is_k_full,
    )

    # Reference path uses the reference weights with the last two
    # dimensions swapped back.
    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2


def test_moe_align_block_size_opcheck():
    """Run `opcheck` on the `_moe_C.moe_align_block_size` custom op.

    Builds a small random `topk_ids` tensor of shape (3, 4) together with
    the output buffers the op takes, then passes them to `opcheck`.
    """
    num_experts = 4
    block_size = 4
    topk_ids = torch.randint(0,
                             num_experts, (3, 4),
                             dtype=torch.int32,
                             device='cuda')

    # Buffer length: every id plus up to (block_size - 1) extra slots per
    # expert.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    # Pre-fill with numel(), a value one past the largest flat index into
    # topk_ids.
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1, ),
                                      dtype=torch.int32,
                                      device=topk_ids.device)

    opcheck(torch.ops._moe_C.moe_align_block_size,
            (topk_ids, num_experts, block_size, sorted_ids, expert_ids,
             num_tokens_post_pad))